提交文件

This commit is contained in:
林榕城最聪明啦~刚满18岁~ 2024-08-23 19:42:44 +08:00
parent faffbfd886
commit db2ff6a3ff
22 changed files with 2755 additions and 0 deletions

199
U-AE/train.py Normal file
View File

@ -0,0 +1,199 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm # 导入tqdm库
# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义卷积自编码器
class ConvAutoencoder(nn.Module):
def __init__(self):
super(ConvAutoencoder, self).__init__()
# 编码器
self.encoder1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), # output: 16 x 1692 x 855
nn.BatchNorm2d(16),
nn.ReLU(True)
)
self.encoder2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # output: 32 x 846 x 428
nn.BatchNorm2d(32),
nn.ReLU(True)
)
self.encoder3 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # output: 64 x 423 x 214
nn.BatchNorm2d(64),
nn.ReLU(True)
)
self.encoder4 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # output: 128 x 212 x 107
nn.BatchNorm2d(128),
nn.ReLU(True)
)
self.encoder5 = nn.Sequential(
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # output: 256 x 106 x 54
nn.BatchNorm2d(256),
nn.ReLU(True)
)
# 解码器
self.decoder5 = nn.Sequential(
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 128 x 212 x 107
nn.BatchNorm2d(128),
nn.ReLU(True)
)
self.decoder4 = nn.Sequential(
nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 64 x 423 x 214
nn.BatchNorm2d(64),
nn.ReLU(True)
)
self.decoder3 = nn.Sequential(
nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 32 x 846 x 428
nn.BatchNorm2d(32),
nn.ReLU(True)
)
self.decoder2 = nn.Sequential(
nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 16 x 1692 x 855
nn.BatchNorm2d(16),
nn.ReLU(True)
)
self.decoder1 = nn.Sequential(
nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 1 x 3384 x 1710
nn.Sigmoid() # 使用Sigmoid以确保输出在[0, 1]范围内
)
def forward(self, x):
# 编码器
enc1 = self.encoder1(x)
enc2 = self.encoder2(enc1)
enc3 = self.encoder3(enc2)
enc4 = self.encoder4(enc3)
enc5 = self.encoder5(enc4)
# 解码器
dec5 = self.decoder5(enc5)
dec4 = self.decoder4(torch.cat([dec5, enc4], dim=1))
dec3 = self.decoder3(torch.cat([dec4, enc3], dim=1))
dec2 = self.decoder2(torch.cat([dec3, enc2], dim=1))
dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1))
return dec1
# 自定义数据集加载器
class CustomImageDataset(Dataset):
def __init__(self, image_dir, label_dir, transform=None):
self.image_dir = image_dir
self.label_dir = label_dir
self.transform = transform
self.image_names = os.listdir(image_dir)
def __len__(self):
return len(self.image_names)
def __getitem__(self, idx):
img_name = self.image_names[idx]
img_path = os.path.join(self.image_dir, img_name)
image = Image.open(img_path).convert("L")
label_name = img_name # 假设图像名与标签名匹配
label_path = os.path.join(self.label_dir, label_name)
label_image = Image.open(label_path).convert("L")
if self.transform:
image = self.transform(image)
label_image = self.transform(label_image)
return image, label_image
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# IoU计算函数
def compute_iou(pred, target, threshold=0.5):
pred = (pred > threshold).float()
target = (target > threshold).float()
intersection = (pred * target).sum()
union = pred.sum() + target.sum() - intersection
if union == 0:
return 1.0 if intersection == 0 else 0.0
iou = intersection / union
return iou.item()
# 图像预处理和数据加载
transform = transforms.Compose([
transforms.Resize((1728, 3392)),
transforms.ToTensor()
])
def train():
for epoch in range(num_epochs):
running_loss = 0.0
total_iou = 0.0
progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
for data in progress_bar:
imgs, label_imgs = data
imgs, label_imgs = imgs.to(device), label_imgs.to(device) # 将数据移动到GPU
# 前向传播
output = model(imgs)
loss = criterion(output, label_imgs)
# 计算IoU
iou = compute_iou(output, label_imgs)
total_iou += iou
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
# 更新进度条描述
progress_bar.set_postfix(loss=loss.item(), iou=iou)
# 打印损失和IoU
epoch_loss = running_loss / len(data_loader)
epoch_iou = total_iou / len(data_loader)
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}, IoU: {epoch_iou}')
# 保存最佳IoU模型
if epoch_iou > best_iou:
best_iou = epoch_iou
torch.save(model.state_dict(), best_model_path)
print(f'Best model saved with IoU: {best_iou}')
# 每5个epoch保存一次模型权重
if (epoch + 1) % 5 == 0:
torch.save(model.state_dict(), f'out_weights/conv_autoencoder_epoch_{epoch+1}.pth')
print(f'Model weights saved at epoch {epoch+1}')
if __name__ == '__main__':
image_dir = './img' # 替换为你的图像文件夹路径
label_dir = './label' # 替换为你的标签文件夹路径
dataset = CustomImageDataset(image_dir, label_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=15, shuffle=True)
# 实例化模型、定义损失函数和优化器
model = ConvAutoencoder().to(device) # 将模型移动到GPU
print(f'The model has {count_parameters(model):,} trainable parameters')
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 保存最佳IoU模型
best_iou = 0.0
best_model_path = 'out_weights/best_conv_autoencoder.pth'
# 训练卷积自编码器
num_epochs = 100
train()

View File

@ -0,0 +1,107 @@
import os
import random
import uuid
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from tqdm import tqdm
def trans_255_1(image):
if len(image.shape) == 2: # 单通道图像
mask = image == 255
image[mask] = 1
elif len(image.shape) == 3 and image.shape[2] == 3: # 三通道图像
white = np.array([255, 255, 255])
mask = np.all(image == white, axis=-1)
image[mask] = [1, 1, 1]
else:
raise ValueError("Unsupported image format!")
return image
def resize_image_and_mask(image, mask, target_size):
resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
resized_mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
return resized_image, resized_mask
def process_image(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size):
image_path = os.path.join(image_dir, image_file)
mask_path = os.path.join(mask_dir, image_file.rsplit('.', 1)[0] + '_bin.png')
image = cv2.imread(image_path)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if image is None or mask is None:
raise ValueError(f"Image or mask not found for {image_file}")
resized_image, resized_mask = resize_image_and_mask(image, mask, target_size)
resized_mask = trans_255_1(resized_mask)
unique_id = str(uuid.uuid4())
mask_output_path = os.path.join(output_mask_dir, unique_id + '.png')
image_output_path = os.path.join(output_image_dir, unique_id + '.jpg')
cv2.imwrite(mask_output_path, resized_mask)
cv2.imwrite(image_output_path, resized_image)
return unique_id
def mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
num_images=2000):
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if len(image_files) < num_images:
raise ValueError(f"Not enough images in directory to sample {num_images} images.")
# 随机选择 num_images 个文件
image_files = random.sample(image_files, num_images)
args = [(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size) for image_file in
image_files]
results = []
with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
futures = [executor.submit(process_image, *arg) for arg in args]
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
try:
unique_id = future.result()
results.append(unique_id)
except Exception as e:
print(f"Error processing image: {e}")
random.shuffle(results)
split_point = int(len(results) * 0.8)
train_ids, val_ids = results[:split_point], results[split_point:]
with open(train_txt, 'a') as train_file:
for uid in train_ids:
train_file.write(f"{uid}\n")
with open(val_txt, 'a') as val_file:
for uid in val_ids:
val_file.write(f"{uid}\n")
print(f"训练集文件名已写入 {train_txt}")
print(f"验证集文件名已写入 {val_txt}")
if __name__ == "__main__":
target_size = (1696, 864)
image_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\img"
mask_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\label"
output_mask_dir = "SegmentationClass"
output_image_dir = "JPEGImages"
output_txt_dir = './ImageSets/Segmentation'
train_txt = os.path.join(output_txt_dir, 'train.txt')
val_txt = os.path.join(output_txt_dir, 'val.txt')
os.makedirs(output_mask_dir, exist_ok=True)
os.makedirs(output_image_dir, exist_ok=True)
os.makedirs(output_txt_dir, exist_ok=True)
mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
num_images=len(os.listdir(image_dir)))

75
nets/U_ConvAutoencoder.py Normal file
View File

@ -0,0 +1,75 @@
import torch
import torch.nn as nn
class U_ConvAutoencoder(nn.Module):
def __init__(self):
super(U_ConvAutoencoder, self).__init__()
# 编码器
self.encoder1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), # output: 16 x 1692 x 855
nn.BatchNorm2d(16),
nn.ReLU(True)
)
self.encoder2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # output: 32 x 846 x 428
nn.BatchNorm2d(32),
nn.ReLU(True)
)
self.encoder3 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # output: 64 x 423 x 214
nn.BatchNorm2d(64),
nn.ReLU(True)
)
self.encoder4 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # output: 128 x 212 x 107
nn.BatchNorm2d(128),
nn.ReLU(True)
)
self.encoder5 = nn.Sequential(
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # output: 256 x 106 x 54
nn.BatchNorm2d(256),
nn.ReLU(True)
)
# 解码器
self.decoder5 = nn.Sequential(
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
# output: 128 x 212 x 107
nn.BatchNorm2d(128),
nn.ReLU(True)
)
self.decoder4 = nn.Sequential(
nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 64 x 423 x 214
nn.BatchNorm2d(64),
nn.ReLU(True)
)
self.decoder3 = nn.Sequential(
nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 32 x 846 x 428
nn.BatchNorm2d(32),
nn.ReLU(True)
)
self.decoder2 = nn.Sequential(
nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 16 x 1692 x 855
nn.BatchNorm2d(16),
nn.ReLU(True)
)
self.decoder1 = nn.Sequential(
nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # output: 1 x 3384 x 1710
nn.Sigmoid() # 使用Sigmoid以确保输出在[0, 1]范围内
)
def forward(self, x):
# 编码器
enc1 = self.encoder1(x)
enc2 = self.encoder2(enc1)
enc3 = self.encoder3(enc2)
enc4 = self.encoder4(enc3)
enc5 = self.encoder5(enc4)
# 解码器
dec5 = self.decoder5(enc5)
dec4 = self.decoder4(torch.cat([dec5, enc4], dim=1))
dec3 = self.decoder3(torch.cat([dec4, enc3], dim=1))
dec2 = self.decoder2(torch.cat([dec3, enc2], dim=1))
dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1))
return dec1

1
nets/__init__.py Normal file
View File

@ -0,0 +1 @@
#

185
nets/resnet.py Normal file
View File

@ -0,0 +1,185 @@
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# 利用1x1卷积下降通道数
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
# 利用3x3卷积进行特征提取
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
# 利用1x1卷积上升通道数
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
#-----------------------------------------------------------#
# 假设输入图像为600,600,3
# 当我们使用resnet50的时候
#-----------------------------------------------------------#
self.inplanes = 64
super(ResNet, self).__init__()
# 600,600,3 -> 300,300,64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# 300,300,64 -> 150,150,64
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
# 150,150,64 -> 150,150,256
self.layer1 = self._make_layer(block, 64, layers[0])
# 150,150,256 -> 75,75,512
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
# 75,75,512 -> 38,38,1024
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
# 38,38,1024 -> 19,19,2048
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
# x = self.conv1(x)
# x = self.bn1(x)
# x = self.relu(x)
# x = self.maxpool(x)
# x = self.layer1(x)
# x = self.layer2(x)
# x = self.layer3(x)
# x = self.layer4(x)
# x = self.avgpool(x)
# x = x.view(x.size(0), -1)
# x = self.fc(x)
x = self.conv1(x)
x = self.bn1(x)
feat1 = self.relu(x)
x = self.maxpool(feat1)
feat2 = self.layer1(x)
feat3 = self.layer2(feat2)
feat4 = self.layer3(feat3)
feat5 = self.layer4(feat4)
return [feat1, feat2, feat3, feat4, feat5]
def resnet50(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', model_dir='model_data'), strict=False)
del model.avgpool
del model.fc
return model

94
nets/unet.py Normal file
View File

@ -0,0 +1,94 @@
import torch
import torch.nn as nn
from nets.resnet import resnet50
from nets.vgg import VGG16
class unetUp(nn.Module):
def __init__(self, in_size, out_size):
super(unetUp, self).__init__()
self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1)
self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1)
self.up = nn.UpsamplingBilinear2d(scale_factor = 2)
self.relu = nn.ReLU(inplace = True)
def forward(self, inputs1, inputs2):
outputs = torch.cat([inputs1, self.up(inputs2)], 1)
outputs = self.conv1(outputs)
outputs = self.relu(outputs)
outputs = self.conv2(outputs)
outputs = self.relu(outputs)
return outputs
class Unet(nn.Module):
def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'):
super(Unet, self).__init__()
if backbone == 'vgg':
self.vgg = VGG16(pretrained = pretrained)
in_filters = [192, 384, 768, 1024]
elif backbone == "resnet50":
self.resnet = resnet50(pretrained = pretrained)
in_filters = [192, 512, 1024, 3072]
else:
raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
out_filters = [64, 128, 256, 512]
# upsampling
# 64,64,512
self.up_concat4 = unetUp(in_filters[3], out_filters[3])
# 128,128,256
self.up_concat3 = unetUp(in_filters[2], out_filters[2])
# 256,256,128
self.up_concat2 = unetUp(in_filters[1], out_filters[1])
# 512,512,64
self.up_concat1 = unetUp(in_filters[0], out_filters[0])
if backbone == 'resnet50':
self.up_conv = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor = 2),
nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
nn.ReLU(),
nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
nn.ReLU(),
)
else:
self.up_conv = None
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
self.backbone = backbone
def forward(self, inputs):
if self.backbone == "vgg":
[feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
elif self.backbone == "resnet50":
[feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
up4 = self.up_concat4(feat4, feat5)
up3 = self.up_concat3(feat3, up4)
up2 = self.up_concat2(feat2, up3)
up1 = self.up_concat1(feat1, up2)
if self.up_conv != None:
up1 = self.up_conv(up1)
final = self.final(up1)
return final
def freeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = False
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = False
def unfreeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = True
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = True

113
nets/unet_training.py Normal file
View File

@ -0,0 +1,113 @@
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
def CE_Loss(inputs, target, cls_weights, num_classes=21):
n, c, h, w = inputs.size()
nt, ht, wt = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
temp_target = target.view(-1)
CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
return CE_loss
def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
n, c, h, w = inputs.size()
nt, ht, wt = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
temp_target = target.view(-1)
logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
pt = torch.exp(logpt)
if alpha is not None:
logpt *= alpha
loss = -((1 - pt) ** gamma) * logpt
loss = loss.mean()
return loss
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
n, c, h, w = inputs.size()
nt, ht, wt, ct = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
temp_target = target.view(n, -1, ct)
#--------------------------------------------#
# 计算dice loss
#--------------------------------------------#
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
fp = torch.sum(temp_inputs , axis=[0,1]) - tp
fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
dice_loss = 1 - torch.mean(score)
return dice_loss
def weights_init(net, init_type='normal', init_gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and classname.find('Conv') != -1:
if init_type == 'normal':
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
print('initialize network with %s type' % init_type)
net.apply(init_func)
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
if iters <= warmup_total_iters:
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
elif iters >= total_iters - no_aug_iter:
lr = min_lr
else:
lr = min_lr + 0.5 * (lr - min_lr) * (
1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
)
return lr
def step_lr(lr, decay_rate, step_size, iters):
if step_size < 1:
raise ValueError("step_size must above 1.")
n = iters // step_size
out_lr = lr * decay_rate ** n
return out_lr
if lr_decay_type == "cos":
warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
else:
decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
step_size = total_iters / step_num
func = partial(step_lr, lr, decay_rate, step_size)
return func
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
lr = lr_scheduler_func(epoch)
for param_group in optimizer.param_groups:
param_group['lr'] = lr

75
nets/vgg.py Normal file
View File

@ -0,0 +1,75 @@
import torch.nn as nn
from torch.hub import load_state_dict_from_url
class VGG(nn.Module):
def __init__(self, features, num_classes=1000):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
self._initialize_weights()
def forward(self, x):
# x = self.features(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.classifier(x)
feat1 = self.features[ :4 ](x)
feat2 = self.features[4 :9 ](feat1)
feat3 = self.features[9 :16](feat2)
feat4 = self.features[16:23](feat3)
feat5 = self.features[23:-1](feat4)
return [feat1, feat2, feat3, feat4, feat5]
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def make_layers(cfg, batch_norm=False, in_channels = 3):
layers = []
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
# 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256
# 64,64,512 -> 32,32,512 -> 32,32,512
cfgs = {
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
}
def VGG16(pretrained, in_channels = 3, **kwargs):
model = VGG(make_layers(cfgs["D"], batch_norm = False, in_channels = in_channels), **kwargs)
if pretrained:
state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir="./model_data")
model.load_state_dict(state_dict)
del model.avgpool
del model.classifier
return model

169
predicdt.py Normal file
View File

@ -0,0 +1,169 @@
import itertools
import torch
import numpy as np
from torchvision import transforms
from PIL import Image, ImageOps
import cv2
from unet import Unet
from nets.U_ConvAutoencoder import U_ConvAutoencoder
from typing import Tuple, List
# 定义卷积自编码器
class PreCA:
device: torch.device = None
model: U_ConvAutoencoder = None
transform: transforms.Compose = None
@classmethod
def initialize_model(cls, u_ca_path: str) -> None:
# 实例化模型并加载权重
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model = U_ConvAutoencoder().to(cls.device)
cls.model.load_state_dict(torch.load(u_ca_path, map_location=cls.device))
cls.model.eval()
# 图像预处理
cls.transform = transforms.Compose([
transforms.Resize((1728, 3392)),
transforms.ToTensor()
])
@classmethod
def load_image(cls, image: Image.Image) -> torch.Tensor:
image = image.convert("L")
image = cls.transform(image).unsqueeze(0) # 添加batch维度
return image.to(cls.device)
@staticmethod
def ca_smooth(image: Image.Image) -> Image.Image:
image_cv2 = np.array(image)
# 对图像进行闭运算
closed_image = cv2.morphologyEx(image_cv2, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)))
# Step 1: 使用高斯模糊来平滑图像边缘
blurred = cv2.GaussianBlur(closed_image, (1, 1), 0)
th = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)[1]
eroded_image_pil = Image.fromarray(th)
return eroded_image_pil
@classmethod
def infer(cls, image: Image.Image) -> Image.Image:
image = cls.load_image(image)
with torch.no_grad():
output = cls.model(image)
output = output.squeeze(0).cpu() # 去除batch维度并移动到CPU
output_image = transforms.ToPILImage()(output)
output_image = output_image.resize((3384, 1710), Image.NEAREST)
return output_image
class PreUnet:
@staticmethod
def blend_images_with_colorize(image1: Image.Image, image2: Image.Image, alpha: float = 0.5) -> None:
red_image1 = ImageOps.colorize(image1.convert("L"), (0, 0, 0), (255, 0, 0))
green_image2 = ImageOps.colorize(image2.convert("L"), (0, 0, 0), (0, 255, 0))
blended_image = Image.blend(red_image1, green_image2, alpha)
blended_image.show()
@staticmethod
def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
pred_gray = pred_image.convert('L')
true_gray = true_image.convert('L')
pred_binary = pred_gray.point(lambda x: 0 if x < threshold else 255)
true_binary = true_gray.point(lambda x: 0 if x < threshold else 255)
pred_array = np.array(pred_binary)
true_array = np.array(true_binary)
# Calculate TP, FP, FN
TP = np.sum((pred_array == 255) & (true_array == 255))
FP = np.sum((pred_array == 255) & (true_array == 0))
FN = np.sum((pred_array == 0) & (true_array == 255))
return TP, FP, FN
@staticmethod
def apply_mask(original_image, mask_imag):
# 打开原图和mask图片
original_image = original_image.convert("RGB")
mask_image = mask_imag.convert("RGB")
# 获取图片的像素数据
original_pixels = original_image.load()
mask_pixels = mask_image.load()
# 获取图片的尺寸
width, height = original_image.size
# 遍历每个像素
for y in range(height):
for x in range(width):
# 如果mask的像素是白色 (255, 255, 255)
if mask_pixels[x, y] == (255, 255, 255):
# 将原图中的对应像素改为绿色 (0, 255, 0)
original_pixels[x, y] = (0, 255, 0)
# 保存结果图片
return original_image
@classmethod
def main(cls, ca_path: str) -> None:
PreCA.initialize_model(ca_path)
import os
from tqdm import tqdm
ious: List[float] = []
img_names: List[str] = os.listdir(dir_origin_path)
for img_name in tqdm(img_names):
if img_name.lower().endswith(
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(dir_origin_path, img_name)
image = Image.open(image_path)
r_image = unet.detect_image(image)
r_image = PreCA.infer(r_image) # 自编码器
r_image = PreCA.ca_smooth(r_image)
if is_save:
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name.split('.')[0] + '_bin.png'))
if is_get_iou:
label_path = os.path.join(dir_label_path, img_name.split('.')[0] + '_bin.png')
label = Image.open(label_path)
TP, FP, FN = cls.calculate_metrics(r_image, label)
iou = TP / (TP + FP + FN)
ious.append(iou)
print(f"当前iou{iou}")
# cls.blend_images_with_colorize(label, r_image)
if is_get_iou: print(f"平均iou{np.mean(ious)}")
if __name__ == "__main__":
name_classes: List[str] = ["background", "lane"]
dir_origin_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\img"
# 是否计算IOU若为True必须填写dir_label_pathlabel的路径
is_get_iou: bool = True
dir_label_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\Label"
# 是否保存预测后的图像若为True必须填写dir_save_path保存路径的路径
is_save: bool = False
dir_save_path: str = "img_out/"
# 设置多尺度监督自编码器的权重路径
u_ca_path: str = 'weights/best_conv_autoencoder1.pth'
_defaults: dict = {
"model_path": 'model_data/best80.pth', # U-Net权重地址
"num_classes": 2, # 预测类别算上背景为2
"backbone": "vgg",
"input_shape": [1696, 864], # 图像大小
"mix_type": 1,
"cuda": True, # 是否启用cuda加速
}
unet: Unet = Unet(_defaults)
PreUnet.main(u_ca_path)

12
requirements.txt Normal file
View File

@ -0,0 +1,12 @@
matplotlib==3.1.2
numpy==1.21.6
opencv_python==4.1.2.30
Pillow==8.2.0
Pillow==10.4.0
scipy==1.2.1
streamlit==1.23.1
thop==0.1.1.post2209072238
torch
torchsummary
torchvision
tqdm==4.60.0

30
summary.py Normal file
View File

@ -0,0 +1,30 @@
#--------------------------------------------#
# 该部分代码用于看网络结构
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary
from nets.unet import Unet
if __name__ == "__main__":
input_shape = [1024, 1024]
num_classes = 2
backbone = 'resnet50'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Unet(num_classes = num_classes, backbone = backbone).to(device)
summary(model, (3, input_shape[0], input_shape[1]))
dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
#--------------------------------------------------------#
# flops * 2是因为profile没有将卷积作为两个operations
# 有些论文将卷积算乘法、加法两个operations。此时乘2
# 有些论文只考虑乘法的运算次数忽略加法。此时不乘2
# 本代码选择乘2参考YOLOX。
#--------------------------------------------------------#
flops = flops * 2
flops, params = clever_format([flops, params], "%.3f")
print('Total GFLOPS: %s' % (flops))
print('Total params: %s' % (params))

255
train.py Normal file
View File

@ -0,0 +1,255 @@
import datetime
import os
from functools import partial
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim as optim
from torch.utils.data import DataLoader
from nets.unet import Unet
from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init
from utils.callbacks import EvalCallback, LossHistory
from utils.dataloader import UnetDataset, unet_dataset_collate
from utils.utils import (download_weights, seed_everything, show_config,
worker_init_fn)
from utils.utils_fit import fit_one_epoch
def train():
if distributed:
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
device = torch.device("cuda", local_rank)
if local_rank == 0:
print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
print("Gpu Device Count : ", ngpus_per_node)
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
rank = 0
if pretrained:
if distributed:
if local_rank == 0:
download_weights(backbone)
dist.barrier()
else:
download_weights(backbone)
model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
if not pretrained:
weights_init(model)
if model_path != '':
if local_rank == 0:
print('Load weights {}.'.format(model_path))
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k)
else:
no_load_key.append(k)
model_dict.update(temp_dict)
model.load_state_dict(model_dict)
if local_rank == 0:
print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
if local_rank == 0:
time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
log_dir = os.path.join(save_dir, "loss_" + str(time_str))
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
else:
loss_history = None
if fp16:
from torch.cuda.amp import GradScaler as GradScaler
scaler = GradScaler()
else:
scaler = None
model_train = model.train()
if sync_bn and ngpus_per_node > 1 and distributed:
model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
elif sync_bn:
print("Sync_bn is not support in one gpu or not distributed.")
if Cuda:
if distributed:
model_train = model_train.cuda(local_rank)
model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank],
find_unused_parameters=True)
else:
model_train = torch.nn.DataParallel(model)
cudnn.benchmark = True
model_train = model_train.cuda()
with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/train.txt"), "r") as f:
train_lines = f.readlines()
with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"), "r") as f:
val_lines = f.readlines()
num_train = len(train_lines)
num_val = len(val_lines)
if True:
UnFreeze_flag = False
if Freeze_Train:
model.freeze_backbone()
batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
nbs = 16
lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
optimizer = {
'adam': optim.Adam(model.parameters(), Init_lr_fit, betas=(momentum, 0.999), weight_decay=weight_decay),
'sgd': optim.SGD(model.parameters(), Init_lr_fit, momentum=momentum, nesterov=True,
weight_decay=weight_decay)
}[optimizer_type]
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
epoch_step = num_train // batch_size
epoch_step_val = num_val // batch_size
if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)
if distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, )
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, )
batch_size = batch_size // ngpus_per_node
shuffle = False
else:
train_sampler = None
val_sampler = None
shuffle = True
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
pin_memory=True,
drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler,
worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
pin_memory=True,
drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler,
worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
if local_rank == 0:
eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \
eval_flag=eval_flag, period=eval_period)
else:
eval_callback = None
for epoch in range(Init_Epoch, UnFreeze_Epoch):
if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
batch_size = Unfreeze_batch_size
nbs = 16
lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
model.unfreeze_backbone()
epoch_step = num_train // batch_size
epoch_step_val = num_val // batch_size
if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
if distributed:
batch_size = batch_size // ngpus_per_node
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
pin_memory=True,
drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler,
worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
pin_memory=True,
drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler,
worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
UnFreeze_flag = True
if distributed:
train_sampler.set_epoch(epoch)
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch,
epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss,
cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)
if distributed:
dist.barrier()
if local_rank == 0:
loss_history.writer.close()
if __name__ == "__main__":
Cuda = True
seed = 11
# 是否开启多卡
distributed = False
sync_bn = False
# 是否使用半精度
fp16 = True
num_classes = 2
# 设置骨干网络
backbone = "vgg"
pretrained = False
model_path = "model_data/8414_8376.pth"
input_shape = [1696, 864]
# 冻结训练
Init_Epoch = 0
Freeze_Epoch = 10
Freeze_batch_size = 1
# 解冻训练
UnFreeze_Epoch = 70
Unfreeze_batch_size = 1
Freeze_Train = True
# 学习率设置
Init_lr = 1e-4
Min_lr = Init_lr * 0.01
# 优化器
optimizer_type = "adam"
momentum = 0.9
weight_decay = 0
lr_decay_type = 'cos'
save_period = 5
save_dir = 'logs'
eval_flag = True
eval_period = 5
# 数据集设置
VOCdevkit_path = 'VOCdevkit'
dice_loss = False
focal_loss = False
cls_weights = np.ones([num_classes], np.float32)
num_workers = 0
seed_everything(seed)
ngpus_per_node = torch.cuda.device_count()
train()

131
unet.py Normal file
View File

@ -0,0 +1,131 @@
import colorsys
import copy
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from nets.unet import Unet as unet
from utils.utils import cvtColor, preprocess_input, resize_image
class Unet(object):
_defaults = {
"model_path": None,
"num_classes": 2,
"backbone": "vgg",
"input_shape": [1696, 864],
"mix_type": 1,
"cuda": True,
}
def __init__(self, _defaults,**kwargs):
self._defaults = _defaults
self.__dict__.update(self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)
if self.num_classes <= 2:
self.colors = [(0, 0, 0), (255,255,255)]
else:
hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
self.generate()
def generate(self, onnx=False):
self.net = unet(num_classes=self.num_classes, backbone=self.backbone)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
self.net = self.net.eval()
print('{} model, and classes loaded.'.format(self.model_path))
if not onnx:
if self.cuda:
self.net = nn.DataParallel(self.net)
self.net = self.net.cuda()
def detect_image(self, image, count=False, name_classes=None):
image = cvtColor(image)
old_img = copy.deepcopy(image)
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
pr = self.net(images)[0]
pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
pr = pr[int((self.input_shape[0] - nh) // 2): int((self.input_shape[0] - nh) // 2 + nh), \
int((self.input_shape[1] - nw) // 2): int((self.input_shape[1] - nw) // 2 + nw)]
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
pr = pr.argmax(axis=-1)
if count:
classes_nums = np.zeros([self.num_classes])
total_points_num = orininal_h * orininal_w
print('-' * 63)
print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio"))
print('-' * 63)
for i in range(self.num_classes):
num = np.sum(pr == i)
ratio = num / total_points_num * 100
if num > 0:
print("|%25s | %15s | %14.2f%%|" % (str(name_classes[i]), str(num), ratio))
print('-' * 63)
classes_nums[i] = num
print("classes_nums:", classes_nums)
if self.mix_type == 0:
# seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
# for c in range(self.num_classes):
# seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
# seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
# seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
# ------------------------------------------------#
# 将新图片转换成Image的形式
# ------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
# ------------------------------------------------#
# 将新图与原图及进行混合
# ------------------------------------------------#
image = Image.blend(old_img, image, 0.7)
elif self.mix_type == 1:
# seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
# for c in range(self.num_classes):
# seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
# seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
# seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
# ------------------------------------------------#
# 将新图片转换成Image的形式
# ------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
elif self.mix_type == 2:
seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
# ------------------------------------------------#
# 将新图片转换成Image的形式
# ------------------------------------------------#
image = Image.fromarray(np.uint8(seg_img))
return image

1
utils/__init__.py Normal file
View File

@ -0,0 +1 @@
#

210
utils/callbacks.py Normal file
View File

@ -0,0 +1,210 @@
import os
import matplotlib
import torch
import torch.nn.functional as F
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal
import cv2
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from .utils import cvtColor, preprocess_input, resize_image
from .utils_metrics import compute_mIoU
class LossHistory():
def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
self.log_dir = log_dir
self.val_loss_flag = val_loss_flag
self.losses = []
if self.val_loss_flag:
self.val_loss = []
os.makedirs(self.log_dir)
self.writer = SummaryWriter(self.log_dir)
try:
dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
self.writer.add_graph(model, dummy_input)
except:
pass
def append_loss(self, epoch, loss, val_loss = None):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.losses.append(loss)
if self.val_loss_flag:
self.val_loss.append(val_loss)
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
if self.val_loss_flag:
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
self.writer.add_scalar('loss', loss, epoch)
if self.val_loss_flag:
self.writer.add_scalar('val_loss', val_loss, epoch)
self.loss_plot()
def loss_plot(self):
iters = range(len(self.losses))
plt.figure()
plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
if self.val_loss_flag:
plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
try:
if len(self.losses) < 25:
num = 5
else:
num = 15
plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
if self.val_loss_flag:
plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
except:
pass
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc="upper right")
plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
plt.cla()
plt.close("all")
class EvalCallback():
def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
miou_out_path=".temp_miou_out", eval_flag=True, period=1):
super(EvalCallback, self).__init__()
self.net = net
self.input_shape = input_shape
self.num_classes = num_classes
self.image_ids = image_ids
self.dataset_path = dataset_path
self.log_dir = log_dir
self.cuda = cuda
self.miou_out_path = miou_out_path
self.eval_flag = eval_flag
self.period = period
self.image_ids = [image_id.split()[0] for image_id in image_ids]
self.mious = [0]
self.epoches = [0]
if self.eval_flag:
with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
f.write(str(0))
f.write("\n")
def get_miou_png(self, image):
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
orininal_h = np.array(image).shape[0]
orininal_w = np.array(image).shape[1]
#---------------------------------------------------------#
# 给图像增加灰条实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]))
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------#
# 图片传入网络进行预测
#---------------------------------------------------#
pr = self.net(images)[0]
#---------------------------------------------------#
# 取出每一个像素点的种类
#---------------------------------------------------#
pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()
#--------------------------------------#
# 将灰条部分截取掉
#--------------------------------------#
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
#---------------------------------------------------#
# 进行图片的resize
#---------------------------------------------------#
pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
#---------------------------------------------------#
# 取出每一个像素点的种类
#---------------------------------------------------#
pr = pr.argmax(axis=-1)
image = Image.fromarray(np.uint8(pr))
return image
def on_epoch_end(self, epoch, model_eval):
if epoch % self.period == 0 and self.eval_flag:
self.net = model_eval
gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
pred_dir = os.path.join(self.miou_out_path, 'detection-results')
if not os.path.exists(self.miou_out_path):
os.makedirs(self.miou_out_path)
if not os.path.exists(pred_dir):
os.makedirs(pred_dir)
print("Get miou.")
for image_id in tqdm(self.image_ids):
#-------------------------------#
# 从文件中读取图像
#-------------------------------#
image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg")
image = Image.open(image_path)
#------------------------------#
# 获得预测txt
#------------------------------#
image = self.get_miou_png(image)
image.save(os.path.join(pred_dir, image_id + ".png"))
print("Calculate miou.")
_, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None) # 执行计算mIoU的函数
temp_miou = np.nanmean(IoUs) * 100
self.mious.append(temp_miou)
self.epoches.append(epoch)
with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
f.write(str(temp_miou))
f.write("\n")
plt.figure()
plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Miou')
plt.title('A Miou Curve')
plt.legend(loc="upper right")
plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
plt.cla()
plt.close("all")
print("Get miou done.")
shutil.rmtree(self.miou_out_path)

149
utils/dataloader.py Normal file
View File

@ -0,0 +1,149 @@
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class UnetDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
super(UnetDataset, self).__init__()
self.annotation_lines = annotation_lines
self.length = len(annotation_lines)
self.input_shape = input_shape
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
def __len__(self):
return self.length
def __getitem__(self, index):
annotation_line = self.annotation_lines[index]
name = annotation_line.split()[0]
#-------------------------------#
# 从文件中读取图像
#-------------------------------#
jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))
png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))
#-------------------------------#
# 数据增强
#-------------------------------#
jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
png = np.array(png)
png[png >= self.num_classes] = self.num_classes
#-------------------------------------------------------#
# 转化成one_hot的形式
# 在这里需要+1是因为voc数据集有些标签具有白边部分
# 我们需要将白边部分进行忽略,+1的目的是方便忽略。
#-------------------------------------------------------#
seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
return jpg, png, seg_labels
def rand(self, a=0, b=1):
return np.random.rand() * (b - a) + a
def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
image = cvtColor(image)
label = Image.fromarray(np.array(label))
#------------------------------#
# 获得图像的高宽与目标高宽
#------------------------------#
iw, ih = image.size
h, w = input_shape
if not random:
iw, ih = image.size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', [w, h], (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
label = label.resize((nw,nh), Image.NEAREST)
new_label = Image.new('L', [w, h], (0))
new_label.paste(label, ((w-nw)//2, (h-nh)//2))
return new_image, new_label
#------------------------------------------#
# 对图像进行缩放并且进行长和宽的扭曲
#------------------------------------------#
new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
scale = self.rand(0.25, 2)
if new_ar < 1:
nh = int(scale*h)
nw = int(nh*new_ar)
else:
nw = int(scale*w)
nh = int(nw/new_ar)
image = image.resize((nw,nh), Image.BICUBIC)
label = label.resize((nw,nh), Image.NEAREST)
#------------------------------------------#
# 翻转图像
#------------------------------------------#
flip = self.rand()<.5
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
label = label.transpose(Image.FLIP_LEFT_RIGHT)
#------------------------------------------#
# 将图像多余的部分加上灰条
#------------------------------------------#
dx = int(self.rand(0, w-nw))
dy = int(self.rand(0, h-nh))
new_image = Image.new('RGB', (w,h), (128,128,128))
new_label = Image.new('L', (w,h), (0))
new_image.paste(image, (dx, dy))
new_label.paste(label, (dx, dy))
image = new_image
label = new_label
image_data = np.array(image, np.uint8)
#---------------------------------#
# 对图像进行色域变换
# 计算色域变换的参数
#---------------------------------#
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
#---------------------------------#
# 将图像转到HSV上
#---------------------------------#
hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
dtype = image_data.dtype
#---------------------------------#
# 应用变换
#---------------------------------#
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
return image_data, label
# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
images = []
pngs = []
seg_labels = []
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
return images, pngs, seg_labels

150
utils/dataloader_medical.py Normal file
View File

@ -0,0 +1,150 @@
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class UnetDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
super(UnetDataset, self).__init__()
self.annotation_lines = annotation_lines
self.length = len(annotation_lines)
self.input_shape = input_shape
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
def __len__(self):
return self.length
def __getitem__(self, index):
annotation_line = self.annotation_lines[index]
name = annotation_line.split()[0]
#-------------------------------#
# 从文件中读取图像
#-------------------------------#
jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".png"))
png = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + ".png"))
#-------------------------------#
# 数据增强
#-------------------------------#
jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
png = np.array(png)
#-------------------------------------------------------#
# 这里的标签处理方式和普通voc的处理方式不同
# 将小于127.5的像素点设置为目标像素点。
#-------------------------------------------------------#
modify_png = np.zeros_like(png)
modify_png[png <= 127.5] = 1
seg_labels = modify_png
seg_labels = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
return jpg, modify_png, seg_labels
def rand(self, a=0, b=1):
return np.random.rand() * (b - a) + a
def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
image = cvtColor(image)
label = Image.fromarray(np.array(label))
#------------------------------#
# 获得图像的高宽与目标高宽
#------------------------------#
iw, ih = image.size
h, w = input_shape
if not random:
iw, ih = image.size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', [w, h], (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
label = label.resize((nw,nh), Image.NEAREST)
new_label = Image.new('L', [w, h], (0))
new_label.paste(label, ((w-nw)//2, (h-nh)//2))
return new_image, new_label
#------------------------------------------#
# 对图像进行缩放并且进行长和宽的扭曲
#------------------------------------------#
new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
scale = self.rand(0.25, 2)
if new_ar < 1:
nh = int(scale*h)
nw = int(nh*new_ar)
else:
nw = int(scale*w)
nh = int(nw/new_ar)
image = image.resize((nw,nh), Image.BICUBIC)
label = label.resize((nw,nh), Image.NEAREST)
#------------------------------------------#
# 翻转图像
#------------------------------------------#
flip = self.rand()<.5
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
label = label.transpose(Image.FLIP_LEFT_RIGHT)
#------------------------------------------#
# 将图像多余的部分加上灰条
#------------------------------------------#
dx = int(self.rand(0, w-nw))
dy = int(self.rand(0, h-nh))
new_image = Image.new('RGB', (w,h), (128,128,128))
new_label = Image.new('L', (w,h), (0))
new_image.paste(image, (dx, dy))
new_label.paste(label, (dx, dy))
image = new_image
label = new_label
image_data = np.array(image, np.uint8)
#---------------------------------#
# 对图像进行色域变换
# 计算色域变换的参数
#---------------------------------#
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
#---------------------------------#
# 将图像转到HSV上
#---------------------------------#
hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
dtype = image_data.dtype
#---------------------------------#
# 应用变换
#---------------------------------#
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
return image_data, label
# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
images = []
pngs = []
seg_labels = []
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
return images, pngs, seg_labels

76
utils/utils.py Normal file
View File

@ -0,0 +1,76 @@
import random
import numpy as np
import torch
from PIL import Image
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
def resize_image(image, size):
iw, ih = image.size
w, h = size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
return new_image, nw, nh
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
def seed_everything(seed=11):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def worker_init_fn(worker_id, rank, seed):
worker_seed = rank + seed
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
def preprocess_input(image):
image /= 255.0
return image
def show_config(**kwargs):
print('Configurations:')
print('-' * 70)
print('|%25s | %40s|' % ('keys', 'values'))
print('-' * 70)
for key, value in kwargs.items():
print('|%25s | %40s|' % (str(key), str(value)))
print('-' * 70)
def download_weights(backbone, model_dir="./model_data"):
import os
from torch.hub import load_state_dict_from_url
download_urls = {
'vgg' : 'https://download.pytorch.org/models/vgg16-397923af.pth',
'resnet50' : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth'
}
url = download_urls[backbone]
if not os.path.exists(model_dir):
os.makedirs(model_dir)
load_state_dict_from_url(url, model_dir)

272
utils/utils_fit.py Normal file
View File

@ -0,0 +1,272 @@
import os
import torch
from nets.unet_training import CE_Loss, Dice_loss, Focal_Loss
from tqdm import tqdm
from utils.utils import get_lr
from utils.utils_metrics import f_score
def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
total_loss = 0
total_f_score = 0
val_loss = 0
val_f_score = 0
if local_rank == 0:
print('Start Train')
pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
model_train.train()
for iteration, batch in enumerate(gen):
if iteration >= epoch_step:
break
imgs, pngs, labels = batch
with torch.no_grad():
weights = torch.from_numpy(cls_weights)
if cuda:
imgs = imgs.cuda(local_rank)
pngs = pngs.cuda(local_rank)
labels = labels.cuda(local_rank)
weights = weights.cuda(local_rank)
optimizer.zero_grad()
if not fp16:
#----------------------#
# 前向传播
#----------------------#
outputs = model_train(imgs)
#----------------------#
# 损失计算
#----------------------#
if focal_loss:
loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
else:
loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
if dice_loss:
main_dice = Dice_loss(outputs, labels)
loss = loss + main_dice
with torch.no_grad():
#-------------------------------#
# 计算f_score
#-------------------------------#
_f_score = f_score(outputs, labels)
loss.backward()
optimizer.step()
else:
from torch.cuda.amp import autocast
with autocast():
#----------------------#
# 前向传播
#----------------------#
outputs = model_train(imgs)
#----------------------#
# 损失计算
#----------------------#
if focal_loss:
loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
else:
loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
if dice_loss:
main_dice = Dice_loss(outputs, labels)
loss = loss + main_dice
with torch.no_grad():
#-------------------------------#
# 计算f_score
#-------------------------------#
_f_score = f_score(outputs, labels)
#----------------------#
# 反向传播
#----------------------#
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
total_f_score += _f_score.item()
if local_rank == 0:
pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
'f_score' : total_f_score / (iteration + 1),
'lr' : get_lr(optimizer)})
pbar.update(1)
if local_rank == 0:
pbar.close()
print('Finish Train')
print('Start Validation')
pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
model_train.eval()
for iteration, batch in enumerate(gen_val):
if iteration >= epoch_step_val:
break
imgs, pngs, labels = batch
with torch.no_grad():
weights = torch.from_numpy(cls_weights)
if cuda:
imgs = imgs.cuda(local_rank)
pngs = pngs.cuda(local_rank)
labels = labels.cuda(local_rank)
weights = weights.cuda(local_rank)
#----------------------#
# 前向传播
#----------------------#
outputs = model_train(imgs)
#----------------------#
# 损失计算
#----------------------#
if focal_loss:
loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
else:
loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
if dice_loss:
main_dice = Dice_loss(outputs, labels)
loss = loss + main_dice
#-------------------------------#
# 计算f_score
#-------------------------------#
_f_score = f_score(outputs, labels)
val_loss += loss.item()
val_f_score += _f_score.item()
if local_rank == 0:
pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1),
'f_score' : val_f_score / (iteration + 1),
'lr' : get_lr(optimizer)})
pbar.update(1)
if local_rank == 0:
pbar.close()
print('Finish Validation')
loss_history.append_loss(epoch + 1, total_loss/ epoch_step, val_loss/ epoch_step_val)
eval_callback.on_epoch_end(epoch + 1, model_train)
print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
#-----------------------------------------------#
# 保存权值
#-----------------------------------------------#
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val)))
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
print('Save best model to best_epoch_weights.pth')
torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
def fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
total_loss = 0
total_f_score = 0
if local_rank == 0:
print('Start Train')
pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
model_train.train()
for iteration, batch in enumerate(gen):
if iteration >= epoch_step:
break
imgs, pngs, labels = batch
with torch.no_grad():
weights = torch.from_numpy(cls_weights)
if cuda:
imgs = imgs.cuda(local_rank)
pngs = pngs.cuda(local_rank)
labels = labels.cuda(local_rank)
weights = weights.cuda(local_rank)
optimizer.zero_grad()
if not fp16:
#----------------------#
# 前向传播
#----------------------#
outputs = model_train(imgs)
#----------------------#
# 损失计算
#----------------------#
if focal_loss:
loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
else:
loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
if dice_loss:
main_dice = Dice_loss(outputs, labels)
loss = loss + main_dice
with torch.no_grad():
#-------------------------------#
# 计算f_score
#-------------------------------#
_f_score = f_score(outputs, labels)
loss.backward()
optimizer.step()
else:
from torch.cuda.amp import autocast
with autocast():
#----------------------#
# 前向传播
#----------------------#
outputs = model_train(imgs)
#----------------------#
# 损失计算
#----------------------#
if focal_loss:
loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
else:
loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
if dice_loss:
main_dice = Dice_loss(outputs, labels)
loss = loss + main_dice
with torch.no_grad():
#-------------------------------#
# 计算f_score
#-------------------------------#
_f_score = f_score(outputs, labels)
#----------------------#
# 反向传播
#----------------------#
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
total_f_score += _f_score.item()
if local_rank == 0:
pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
'f_score' : total_f_score / (iteration + 1),
'lr' : get_lr(optimizer)})
pbar.update(1)
if local_rank == 0:
pbar.close()
loss_history.append_loss(epoch + 1, total_loss/ epoch_step)
print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
print('Total Loss: %.3f' % (total_loss / epoch_step))
#-----------------------------------------------#
# 保存权值
#-----------------------------------------------#
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f.pth'%((epoch + 1), total_loss / epoch_step)))
if len(loss_history.losses) <= 1 or (total_loss / epoch_step) <= min(loss_history.losses):
print('Save best model to best_epoch_weights.pth')
torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))

182
utils/utils_metrics.py Normal file
View File

@ -0,0 +1,182 @@
import csv
import os
from os.path import join
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
n, c, h, w = inputs.size()
nt, ht, wt, ct = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
temp_target = target.view(n, -1, ct)
#--------------------------------------------#
# 计算dice系数
#--------------------------------------------#
temp_inputs = torch.gt(temp_inputs, threhold).float()
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
fp = torch.sum(temp_inputs , axis=[0,1]) - tp
fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
score = torch.mean(score)
return score
# 设标签宽W长H
def fast_hist(a, b, n):
#--------------------------------------------------------------------------------#
# a是转化成一维数组的标签形状(H×W,)b是转化成一维数组的预测结果形状(H×W,)
#--------------------------------------------------------------------------------#
k = (a >= 0) & (a < n)
#--------------------------------------------------------------------------------#
# np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数返回值形状(n, n)
# 返回中,写对角线上的为分类正确的像素点
#--------------------------------------------------------------------------------#
return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
def per_class_iu(hist):
return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
def per_class_PA_Recall(hist):
return np.diag(hist) / np.maximum(hist.sum(1), 1)
def per_class_Precision(hist):
return np.diag(hist) / np.maximum(hist.sum(0), 1)
def per_Accuracy(hist):
return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
print('Num classes', num_classes)
#-----------------------------------------#
# 创建一个全是0的矩阵是一个混淆矩阵
#-----------------------------------------#
hist = np.zeros((num_classes, num_classes))
#------------------------------------------------#
# 获得验证集标签路径列表,方便直接读取
# 获得验证集图像分割结果路径列表,方便直接读取
#------------------------------------------------#
gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]
#------------------------------------------------#
# 读取每一个(图片-标签)对
#------------------------------------------------#
for ind in range(len(gt_imgs)):
#------------------------------------------------#
# 读取一张图像分割结果转化成numpy数组
#------------------------------------------------#
pred = np.array(Image.open(pred_imgs[ind]))
#------------------------------------------------#
# 读取一张对应的标签转化成numpy数组
#------------------------------------------------#
label = np.array(Image.open(gt_imgs[ind]))
# 如果图像分割结果与标签的大小不一样,这张图片就不计算
if len(label.flatten()) != len(pred.flatten()):
print(
'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
pred_imgs[ind]))
continue
#------------------------------------------------#
# 对一张图片计算21×21的hist矩阵并累加
#------------------------------------------------#
hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
# 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值
if name_classes is not None and ind > 0 and ind % 10 == 0:
print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
ind,
len(gt_imgs),
100 * np.nanmean(per_class_iu(hist)),
100 * np.nanmean(per_class_PA_Recall(hist)),
100 * per_Accuracy(hist)
)
)
#------------------------------------------------#
# 计算所有验证集图片的逐类别mIoU值
#------------------------------------------------#
IoUs = per_class_iu(hist)
PA_Recall = per_class_PA_Recall(hist)
Precision = per_class_Precision(hist)
#------------------------------------------------#
# 逐类别输出一下mIoU值
#------------------------------------------------#
if name_classes is not None:
for ind_class in range(num_classes):
print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
+ '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))
#-----------------------------------------------------------------#
# 在所有验证集图像上求所有类别平均的mIoU值计算时忽略NaN值
#-----------------------------------------------------------------#
print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
return np.array(hist, np.int), IoUs, PA_Recall, Precision
def adjust_axes(r, t, fig, axes):
bb = t.get_window_extent(renderer=r)
text_width_inches = bb.width / fig.dpi
current_fig_width = fig.get_figwidth()
new_fig_width = current_fig_width + text_width_inches
propotion = new_fig_width / current_fig_width
x_lim = axes.get_xlim()
axes.set_xlim([x_lim[0], x_lim[1] * propotion])
def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
fig = plt.gcf()
axes = plt.gca()
plt.barh(range(len(values)), values, color='royalblue')
plt.title(plot_title, fontsize=tick_font_size + 2)
plt.xlabel(x_label, fontsize=tick_font_size)
plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
r = fig.canvas.get_renderer()
for i, val in enumerate(values):
str_val = " " + str(val)
if val < 1.0:
str_val = " {0:.2f}".format(val)
t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
if i == (len(values)-1):
adjust_axes(r, t, fig, axes)
fig.tight_layout()
fig.savefig(output_path)
if plt_show:
plt.show()
plt.close()
def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \
os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True)
print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))
draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \
os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))
draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \
os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))
draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \
os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))
with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
writer = csv.writer(f)
writer_list = []
writer_list.append([' '] + [str(c) for c in name_classes])
for i in range(len(hist)):
writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
writer.writerows(writer_list)
print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))

191
web.py Normal file
View File

@ -0,0 +1,191 @@
import os
import streamlit as st
import cv2
import tempfile
import torch
import numpy as np
from PIL.Image import Image
from torchvision import transforms
from PIL import Image
from unet import Unet
from nets.U_ConvAutoencoder import U_ConvAutoencoder
from typing import Tuple, List
# Constants and configuration
DEFAULTS = {
"model_path": 'model_data/8414_8376.pth',
"num_classes": 2,
"backbone": "vgg",
"input_shape": [1696, 864],
"mix_type": 1,
"cuda": torch.cuda.is_available(),
}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TRANSFORM = transforms.Compose([
transforms.Resize((1728, 3392)),
transforms.ToTensor()
])
class PreCA:
model: U_ConvAutoencoder = None
@classmethod
def initialize_model(cls, u_ca_path: str) -> None:
cls.model = U_ConvAutoencoder().to(DEVICE)
cls.model.load_state_dict(torch.load(u_ca_path, map_location=DEVICE))
cls.model.eval()
@classmethod
def unload_model(cls) -> None:
cls.model = None
torch.cuda.empty_cache()
@classmethod
def load_image(cls, image: Image.Image) -> torch.Tensor:
image = image.convert("L")
image = TRANSFORM(image).unsqueeze(0)
return image.to(DEVICE)
@staticmethod
def ca_smooth(image: Image.Image) -> Image.Image:
image_cv2 = np.array(image)
closed_image = cv2.morphologyEx(image_cv2, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)))
blurred = cv2.GaussianBlur(closed_image, (1, 1), 0)
th = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)[1]
return Image.fromarray(th)
@classmethod
def infer(cls, image: Image.Image) -> Image.Image:
image_tensor = cls.load_image(image)
with torch.no_grad():
output = cls.model(image_tensor)
output = output.squeeze(0).cpu()
output_image = transforms.ToPILImage()(output)
return output_image.resize((3384, 1710), Image.NEAREST)
class PreUnet:
@staticmethod
def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
pred_binary = pred_image.convert('L').point(lambda x: 0 if x < threshold else 255)
true_binary = true_image.convert('L').point(lambda x: 0 if x < threshold else 255)
pred_array = np.array(pred_binary)
true_array = np.array(true_binary)
TP = np.sum((pred_array == 255) & (true_array == 255))
FP = np.sum((pred_array == 255) & (true_array == 0))
FN = np.sum((pred_array == 0) & (true_array == 255))
return TP, FP, FN
@staticmethod
def apply_mask(original_image: Image.Image, mask_image: Image.Image) -> Image.Image:
original_image = original_image.convert("RGB").resize((3384, 1710), Image.NEAREST)
mask_image = mask_image.convert("RGB").resize((3384, 1710), Image.NEAREST)
original_array = np.array(original_image)
mask_array = np.array(mask_image)
mask = np.all(mask_array == [255, 255, 255], axis=-1)
original_array[mask] = [0, 255, 0]
return Image.fromarray(original_array)
@classmethod
def process_image(cls, image: Image.Image, unet):
detected_image = unet.detect_image(image)
inferred_image = PreCA.infer(detected_image)
smoothed_image = PreCA.ca_smooth(inferred_image)
return cls.apply_mask(image, smoothed_image),smoothed_image
def main_page():
st.title('自动驾驶车道线自动检测与增强')
stframe = st.empty()
st.sidebar.subheader("参数设置")
is_pre = st.sidebar.checkbox('开启预测')
unet = Unet(DEFAULTS) if is_pre else None
if is_pre:
u_ca_path = 'weights/best_conv_autoencoder1.pth'
PreCA.initialize_model(u_ca_path)
else:
PreCA.unload_model()
st.sidebar.subheader("图像检测")
image_dir_path = st.sidebar.text_input('请输入图像文件夹路径:')
is_get_iou = st.sidebar.checkbox('开启计算IOU')
label_dir_path = st.sidebar.text_input('请输入标签文件夹路径:') if is_get_iou else None
btn_click = st.sidebar.button("开始预测")
if btn_click:
process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe)
st.sidebar.subheader("视频检测")
uploaded_video = st.sidebar.file_uploader("上传视频:", type=['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'm4v'])
if uploaded_video is not None:
process_video(uploaded_video, unet, is_pre, stframe)
def process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe):
ious = []
img_names = os.listdir(image_dir_path)
iou_text = st.empty()
for img_name in img_names:
if img_name.lower().endswith(
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(image_dir_path, img_name)
image = Image.open(image_path)
if is_pre:
result_image,smoothed_image = PreUnet.process_image(image, unet)
stframe.image([image, result_image], width=640)
if is_get_iou and label_dir_path:
label_path = os.path.join(label_dir_path, f"{os.path.splitext(img_name)[0]}_bin.png")
label = Image.open(label_path)
TP, FP, FN = PreUnet.calculate_metrics(smoothed_image, label)
iou = TP / (TP + FP + FN)
# ious.append(iou)
iou_text.text(f'当前IOU: {iou}')
else:
stframe.image(image, width=1024)
def process_video(uploaded_video, unet, is_pre, stframe):
tfile = tempfile.NamedTemporaryFile(delete=False)
tfile.write(uploaded_video.read())
tfile.close()
cap = cv2.VideoCapture(tfile.name)
if 'frame_pos' not in st.session_state:
st.session_state.frame_pos = 0
cap.set(cv2.CAP_PROP_POS_FRAMES, st.session_state.frame_pos)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
st.session_state.frame_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
if is_pre:
processed_frame,smoothed_image = PreUnet.process_image(frame, unet)
stframe.image(processed_frame, width=1024,use_column_width=False)
else:
stframe.image(frame, width=1024,use_column_width=False)
cap.release()
if __name__ == '__main__':
main_page()

78
图片修改.py Normal file
View File

@ -0,0 +1,78 @@
import os
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
def remove_foreground_and_fill(image_path, mask_path, crop_size):
# 读取图像和掩模
image = cv2.imread(image_path)
mask = cv2.imread(mask_path, 0) # 假设掩模是灰度图
# 确保掩模是二值化的
_, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
# 随机生成裁剪区域的起始点
h, w = mask.shape
start_x = np.random.randint(0, w - crop_size[1] + 1)
start_y = np.random.randint(0, h - crop_size[0] + 1)
# 裁剪掩模和图像
cropped_mask = mask[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]]
cropped_image = image[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]]
# 使用inpaint方法进行前景消除并填补
result = cv2.inpaint(cropped_image, cropped_mask, 3, cv2.INPAINT_TELEA)
# 将填补后的图像放回原图的相应位置
result_image = image.copy()
result_image[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]] = result
return result_image
def process_image(img_name, img_folder, label_folder, save_dir):
if img_name.lower().endswith('.jpg'):
label_name = img_name[:-4] + '_bin.png' # 构造标签文件名
if label_name in os.listdir(label_folder): # 确保标签文件存在
# 调用remove_foreground_and_fill函数处理图像和掩模
result_image = remove_foreground_and_fill(
os.path.join(img_folder, img_name),
os.path.join(label_folder, label_name),
(500, 1710) # 定义裁剪尺寸为500x1710像素
)
# 保存处理后的图像到save_files文件夹
save_path = os.path.join(save_dir, img_name)
cv2.imwrite(save_path, result_image)
def write_img_label_txt(base_dir, dataset_type):
# 创建保存txt文件的目录
save_dir = os.path.join(base_dir, 'save_files')
os.makedirs(save_dir, exist_ok=True)
# 获取img和label文件夹的路径
img_folder = os.path.join(base_dir, dataset_type, 'img')
label_folder = os.path.join(base_dir, dataset_type, 'label')
img_names = [img_name for img_name in os.listdir(img_folder) if img_name.lower().endswith('.jpg')]
with ThreadPoolExecutor() as executor:
futures = []
for img_name in img_names:
futures.append(executor.submit(process_image, img_name, img_folder, label_folder, save_dir))
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
future.result()
# 基础目录
base_dir = r'E:\git\unet_seg\unet\original_data\dataset_A'
# 处理test和train文件夹
for dataset_type in ['train']:
write_img_label_txt(base_dir, dataset_type)
print('All images have been processed and saved.')