From db2ff6a3ffdba72d64361c05d0c6f565335fe062 Mon Sep 17 00:00:00 2001
From: lrczcm <2508606977@qq.com>
Date: Fri, 23 Aug 2024 19:42:44 +0800
Subject: [PATCH] Commit files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 U-AE/train.py                  | 199 ++++++++++++++++++++++
 VOCdevkit/VOC2007/transform.py | 107 +++++++++++++
 nets/U_ConvAutoencoder.py      |  75 +++++++++
 nets/__init__.py               |   1 +
 nets/resnet.py                 | 185 ++++++++++++++++++++++
 nets/unet.py                   |  94 ++++++++++++
 nets/unet_training.py          | 113 ++++++++++++++
 nets/vgg.py                    |  75 +++++++++
 predicdt.py                    | 169 ++++++++++++++++++++
 requirements.txt               |  12 ++
 summary.py                     |  30 ++++
 train.py                       | 255 +++++++++++++++++++++++++++++++
 unet.py                        | 131 ++++++++++++++++
 utils/__init__.py              |   1 +
 utils/callbacks.py             | 210 +++++++++++++++++++++++++
 utils/dataloader.py            | 149 ++++++++++++++++++
 utils/dataloader_medical.py    | 150 ++++++++++++++++++
 utils/utils.py                 |  76 +++++++++
 utils/utils_fit.py             | 272 +++++++++++++++++++++++++++++++
 utils/utils_metrics.py         | 182 ++++++++++++++++++++++
 web.py                         | 191 +++++++++++++++++++++++
 图片修改.py                    |  78 ++++++++++
 22 files changed, 2755 insertions(+)
 create mode 100644 U-AE/train.py
 create mode 100644 VOCdevkit/VOC2007/transform.py
 create mode 100644 nets/U_ConvAutoencoder.py
 create mode 100644 nets/__init__.py
 create mode 100644 nets/resnet.py
 create mode 100644 nets/unet.py
 create mode 100644 nets/unet_training.py
 create mode 100644 nets/vgg.py
 create mode 100644 predicdt.py
 create mode 100644 requirements.txt
 create mode 100644 summary.py
 create mode 100644 train.py
 create mode 100644 unet.py
 create mode 100644 utils/__init__.py
 create mode 100644 utils/callbacks.py
 create mode 100644 utils/dataloader.py
 create mode 100644 utils/dataloader_medical.py
 create mode 100644 utils/utils.py
 create mode 100644 utils/utils_fit.py
 create mode 100644 utils/utils_metrics.py
 create mode 100644 web.py
 create mode 100644 图片修改.py
diff --git a/U-AE/train.py b/U-AE/train.py
new file mode 100644
index 0000000..b8c3869
--- /dev/null
+++ b/U-AE/train.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import transforms
+from torch.utils.data import DataLoader, Dataset
+from PIL import Image
+import os
+from tqdm import tqdm  # progress bars
+
+# Use the GPU if one is available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Convolutional autoencoder with U-Net-style skip connections
+class ConvAutoencoder(nn.Module):
+    def __init__(self):
+        super(ConvAutoencoder, self).__init__()
+        # Encoder (shapes below assume the 3392 x 1728 input produced by the Resize transform)
+        self.encoder1 = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # output: 16 x 1696 x 864
+            nn.BatchNorm2d(16),
+            nn.ReLU(True)
+        )
+        self.encoder2 = nn.Sequential(
+            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # output: 32 x 848 x 432
+            nn.BatchNorm2d(32),
+            nn.ReLU(True)
+        )
+        self.encoder3 = nn.Sequential(
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # output: 64 x 424 x 216
+            nn.BatchNorm2d(64),
+            nn.ReLU(True)
+        )
+        self.encoder4 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # output: 128 x 212 x 108
+            nn.BatchNorm2d(128),
+            nn.ReLU(True)
+        )
+        self.encoder5 = nn.Sequential(
+            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # output: 256 x 106 x 54
+            nn.BatchNorm2d(256),
+            nn.ReLU(True)
+        )
+
+        # Decoder (each stage receives the previous output concatenated with the matching encoder feature)
+        self.decoder5 = nn.Sequential(
+            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 128 x 212 x 108
+            nn.BatchNorm2d(128),
+            nn.ReLU(True)
+        )
+        self.decoder4 = nn.Sequential(
+            nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 64 x 424 x 216
+            nn.BatchNorm2d(64),
+            nn.ReLU(True)
+        )
+        self.decoder3 = nn.Sequential(
+            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 32 x 848 x 432
+            nn.BatchNorm2d(32),
+            nn.ReLU(True)
+        )
+        self.decoder2 = nn.Sequential(
+            nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 16 x 1696 x 864
+            nn.BatchNorm2d(16),
+            nn.ReLU(True)
+        )
+        self.decoder1 = nn.Sequential(
+            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 1 x 3392 x 1728
+            nn.Sigmoid()  # Sigmoid keeps the output in [0, 1]
+        )
+
+    def forward(self, x):
+        # Encoder
+        enc1 = self.encoder1(x)
+        enc2 = self.encoder2(enc1)
+        enc3 = self.encoder3(enc2)
+        enc4 = self.encoder4(enc3)
+        enc5 = self.encoder5(enc4)
+
+        # Decoder
+        dec5 = self.decoder5(enc5)
+        dec4 = self.decoder4(torch.cat([dec5, enc4], dim=1))
+        dec3 = self.decoder3(torch.cat([dec4, enc3], dim=1))
+        dec2 = self.decoder2(torch.cat([dec3, enc2], dim=1))
+        dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1))
+
+        return dec1
+
+
+# Custom dataset loader
+class CustomImageDataset(Dataset):
+    def __init__(self, image_dir, label_dir, transform=None):
+        self.image_dir = image_dir
+        self.label_dir = label_dir
+        self.transform = transform
+        self.image_names = os.listdir(image_dir)
+
+    def __len__(self):
+        return len(self.image_names)
+
+    def __getitem__(self, idx):
+        img_name = self.image_names[idx]
+        img_path = os.path.join(self.image_dir, img_name)
+        image = Image.open(img_path).convert("L")
+
+        label_name = img_name  # assumes image and label files share the same name
+        label_path = os.path.join(self.label_dir, label_name)
+        label_image = Image.open(label_path).convert("L")
+
+        if self.transform:
+            image = self.transform(image)
+            label_image = self.transform(label_image)
+
+        return image, label_image
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+# IoU computation
+def compute_iou(pred, target, threshold=0.5):
+    pred = (pred > threshold).float()
+    target = (target > threshold).float()
+
+    intersection = (pred * target).sum()
+    union = pred.sum() + target.sum() - intersection
+
+    if union == 0:
+        return 1.0 if intersection == 0 else 0.0
+
+    iou = intersection / union
+    return iou.item()
+
+
+# Image preprocessing and data loading
+transform = transforms.Compose([
+    transforms.Resize((1728, 3392)),
+    transforms.ToTensor()
+])
+
+
+def train():
+    global best_iou  # best_iou is defined and updated at module level
+    for epoch in range(num_epochs):
+        running_loss = 0.0
+        total_iou = 0.0
+        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
+        for data in progress_bar:
+            imgs, label_imgs = data
+            imgs, label_imgs = imgs.to(device), label_imgs.to(device)  # move the batch to the GPU
+
+            # Forward pass
+            output = model(imgs)
+            loss = criterion(output, label_imgs)
+
+            # Batch IoU
+            iou = compute_iou(output, label_imgs)
+            total_iou += iou
+
+            # Backward pass and optimisation
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+
+            # Update the progress-bar readout
+            progress_bar.set_postfix(loss=loss.item(), iou=iou)
+
+        # Report epoch loss and IoU
+        epoch_loss = running_loss / len(data_loader)
+        epoch_iou = total_iou / len(data_loader)
+        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}, IoU: {epoch_iou}')
+
+        # Keep the checkpoint with the best IoU
+        if epoch_iou > best_iou:
+            best_iou = epoch_iou
+            torch.save(model.state_dict(), best_model_path)
+            print(f'Best model saved with IoU: {best_iou}')
+
+        # Save the weights every 5 epochs
+        if (epoch + 1) % 5 == 0:
+            torch.save(model.state_dict(),
+                       f'out_weights/conv_autoencoder_epoch_{epoch+1}.pth')
+            print(f'Model weights saved at epoch {epoch+1}')
+
+
+if __name__ == '__main__':
+    image_dir = './img'  # replace with your image folder
+    label_dir = './label'  # replace with your label folder
+
+    dataset = CustomImageDataset(image_dir, label_dir, transform=transform)
+    data_loader = DataLoader(dataset, batch_size=15, shuffle=True)
+
+    # Build the model, loss function and optimizer
+    model = ConvAutoencoder().to(device)  # move the model to the GPU
+    print(f'The model has {count_parameters(model):,} trainable parameters')
+    criterion = nn.BCELoss()
+    optimizer = optim.Adam(model.parameters(), lr=0.01)
+
+    # Bookkeeping for the best-IoU checkpoint
+    best_iou = 0.0
+    best_model_path = 'out_weights/best_conv_autoencoder.pth'
+    # Train the convolutional autoencoder
+    num_epochs = 100
+    train()
\ No newline at end of file
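For reference, a minimal sketch of running the saved autoencoder on a single image (not part of the patch). The weight and image paths are assumptions to adjust for your layout, and the 0.5 threshold for binarising the sigmoid output is an arbitrary choice:

import torch
from PIL import Image
from torchvision import transforms

from train import ConvAutoencoder  # assumes U-AE/train.py is on the import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvAutoencoder().to(device)
model.load_state_dict(torch.load("out_weights/best_conv_autoencoder.pth", map_location=device))
model.eval()

# Same preprocessing as in training
transform = transforms.Compose([transforms.Resize((1728, 3392)), transforms.ToTensor()])
image = transform(Image.open("./img/example.png").convert("L")).unsqueeze(0).to(device)

with torch.no_grad():
    mask = (model(image) > 0.5).float()  # binarise the sigmoid output
transforms.ToPILImage()(mask.squeeze(0).cpu()).save("example_mask.png")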
diff --git a/VOCdevkit/VOC2007/transform.py b/VOCdevkit/VOC2007/transform.py
new file mode 100644
index 0000000..c380ca2
--- /dev/null
+++ b/VOCdevkit/VOC2007/transform.py
@@ -0,0 +1,107 @@
+import os
+import random
+import uuid
+import cv2
+import numpy as np
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from multiprocessing import cpu_count
+from tqdm import tqdm
+
+
+def trans_255_1(image):
+    if len(image.shape) == 2:  # single-channel image
+        mask = image == 255
+        image[mask] = 1
+    elif len(image.shape) == 3 and image.shape[2] == 3:  # three-channel image
+        white = np.array([255, 255, 255])
+        mask = np.all(image == white, axis=-1)
+        image[mask] = [1, 1, 1]
+    else:
+        raise ValueError("Unsupported image format!")
+    return image
+
+
+def resize_image_and_mask(image, mask, target_size):
+    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
+    resized_mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
+    return resized_image, resized_mask
+
+
+def process_image(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size):
+    image_path = os.path.join(image_dir, image_file)
+    mask_path = os.path.join(mask_dir, image_file.rsplit('.', 1)[0] + '_bin.png')
+    image = cv2.imread(image_path)
+    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+
+    if image is None or mask is None:
+        raise ValueError(f"Image or mask not found for {image_file}")
+
+    resized_image, resized_mask = resize_image_and_mask(image, mask, target_size)
+    resized_mask = trans_255_1(resized_mask)
+
+    unique_id = str(uuid.uuid4())
+    mask_output_path = os.path.join(output_mask_dir, unique_id + '.png')
+    image_output_path = os.path.join(output_image_dir, unique_id + '.jpg')
+
+    cv2.imwrite(mask_output_path, resized_mask)
+    cv2.imwrite(image_output_path, resized_image)
+
+    return unique_id
+
+
+def mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
+                 num_images=2000):
+    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+
+    if len(image_files) < num_images:
+        raise ValueError(f"Not enough images in directory to sample {num_images} images.")
+
+    # Randomly sample num_images files
+    image_files = random.sample(image_files, num_images)
+
+    args = [(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size) for image_file in
+            image_files]
+
+    results = []
+    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
+        futures = [executor.submit(process_image, *arg) for arg in args]
+        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
+            try:
+                unique_id = future.result()
+                results.append(unique_id)
+            except Exception as e:
+                print(f"Error processing image: {e}")
+
+    random.shuffle(results)
+    split_point = int(len(results) * 0.8)
+    train_ids, val_ids = results[:split_point], results[split_point:]
+
+    with open(train_txt, 'a') as train_file:
+        for uid in train_ids:
+            train_file.write(f"{uid}\n")
+
+    with open(val_txt, 'a') as val_file:
+        for uid in val_ids:
+            val_file.write(f"{uid}\n")
+
+    print(f"Training set file names written to {train_txt}")
+    print(f"Validation set file names written to {val_txt}")
+
+
+if __name__ == "__main__":
+    target_size = (1696, 864)
+    image_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\img"
+    mask_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\label"
+    output_mask_dir = "SegmentationClass"
+    output_image_dir = "JPEGImages"
+    output_txt_dir = './ImageSets/Segmentation'
+
+    train_txt = os.path.join(output_txt_dir, 'train.txt')
+    val_txt = os.path.join(output_txt_dir, 'val.txt')
+
+    os.makedirs(output_mask_dir, exist_ok=True)
+    os.makedirs(output_image_dir, exist_ok=True)
+    os.makedirs(output_txt_dir, exist_ok=True)
+
+    mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
+                 num_images=len(os.listdir(image_dir)))
\ No newline at end of file
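Because trans_255_1() stores foreground pixels as value 1, the PNGs written to SegmentationClass/ look almost black in an image viewer. A quick sanity-check sketch (the file name is a hypothetical uuid; substitute one the script actually produced):

import cv2
import numpy as np

mask = cv2.imread("SegmentationClass/example-uuid.png", cv2.IMREAD_GRAYSCALE)
print(np.unique(mask))                       # expect [0 1] for a converted label
cv2.imwrite("mask_preview.png", mask * 255)  # rescale so the mask is visible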
diff --git a/nets/U_ConvAutoencoder.py b/nets/U_ConvAutoencoder.py
new file mode 100644
index 0000000..6117e71
--- /dev/null
+++ b/nets/U_ConvAutoencoder.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+class U_ConvAutoencoder(nn.Module):
+    def __init__(self):
+        super(U_ConvAutoencoder, self).__init__()
+        # Encoder (shapes below assume a 3392 x 1728 input)
+        self.encoder1 = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # output: 16 x 1696 x 864
+            nn.BatchNorm2d(16),
+            nn.ReLU(True)
+        )
+        self.encoder2 = nn.Sequential(
+            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # output: 32 x 848 x 432
+            nn.BatchNorm2d(32),
+            nn.ReLU(True)
+        )
+        self.encoder3 = nn.Sequential(
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # output: 64 x 424 x 216
+            nn.BatchNorm2d(64),
+            nn.ReLU(True)
+        )
+        self.encoder4 = nn.Sequential(
+            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # output: 128 x 212 x 108
+            nn.BatchNorm2d(128),
+            nn.ReLU(True)
+        )
+        self.encoder5 = nn.Sequential(
+            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # output: 256 x 106 x 54
+            nn.BatchNorm2d(256),
+            nn.ReLU(True)
+        )
+
+        # Decoder
+        self.decoder5 = nn.Sequential(
+            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+            # output: 128 x 212 x 108
+            nn.BatchNorm2d(128),
+            nn.ReLU(True)
+        )
+        self.decoder4 = nn.Sequential(
+            nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 64 x 424 x 216
+            nn.BatchNorm2d(64),
+            nn.ReLU(True)
+        )
+        self.decoder3 = nn.Sequential(
+            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 32 x 848 x 432
+            nn.BatchNorm2d(32),
+            nn.ReLU(True)
+        )
+        self.decoder2 = nn.Sequential(
+            nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 16 x 1696 x 864
+            nn.BatchNorm2d(16),
+            nn.ReLU(True)
+        )
+        self.decoder1 = nn.Sequential(
+            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 1 x 3392 x 1728
+            nn.Sigmoid()  # Sigmoid keeps the output in [0, 1]
+        )
+
+    def forward(self, x):
+        # Encoder
+        enc1 = self.encoder1(x)
+        enc2 = self.encoder2(enc1)
+        enc3 = self.encoder3(enc2)
+        enc4 = self.encoder4(enc3)
+        enc5 = self.encoder5(enc4)
+
+        # Decoder
+        dec5 = self.decoder5(enc5)
+        dec4 = self.decoder4(torch.cat([dec5, enc4], dim=1))
+        dec3 = self.decoder3(torch.cat([dec4, enc3], dim=1))
+        dec2 = 
self.decoder2(torch.cat([dec3, enc2], dim=1)) + dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1)) + + return dec1 \ No newline at end of file diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/nets/resnet.py b/nets/resnet.py new file mode 100644 index 0000000..21ef95a --- /dev/null +++ b/nets/resnet.py @@ -0,0 +1,185 @@ +import math + +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # 利用1x1卷积下降通道数 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + # 利用3x3卷积进行特征提取 + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + # 利用1x1卷积上升通道数 + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000): + #-----------------------------------------------------------# + # 假设输入图像为600,600,3 + # 当我们使用resnet50的时候 + #-----------------------------------------------------------# + self.inplanes = 64 + super(ResNet, self).__init__() + # 600,600,3 -> 300,300,64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + # 300,300,64 -> 150,150,64 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, 
ceil_mode=True) # change + # 150,150,64 -> 150,150,256 + self.layer1 = self._make_layer(block, 64, layers[0]) + # 150,150,256 -> 75,75,512 + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + # 75,75,512 -> 38,38,1024 + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + # 38,38,1024 -> 19,19,2048 + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + # x = self.conv1(x) + # x = self.bn1(x) + # x = self.relu(x) + # x = self.maxpool(x) + + # x = self.layer1(x) + # x = self.layer2(x) + # x = self.layer3(x) + # x = self.layer4(x) + + # x = self.avgpool(x) + # x = x.view(x.size(0), -1) + # x = self.fc(x) + + x = self.conv1(x) + x = self.bn1(x) + feat1 = self.relu(x) + + x = self.maxpool(feat1) + feat2 = self.layer1(x) + + feat3 = self.layer2(feat2) + feat4 = self.layer3(feat3) + feat5 = self.layer4(feat4) + return [feat1, feat2, feat3, feat4, feat5] + +def resnet50(pretrained=False, **kwargs): + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', model_dir='model_data'), strict=False) + + del model.avgpool + del model.fc + return model diff --git a/nets/unet.py b/nets/unet.py new file mode 100644 index 0000000..71417ee --- /dev/null +++ b/nets/unet.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn + +from nets.resnet import resnet50 +from nets.vgg import VGG16 + + +class unetUp(nn.Module): + def __init__(self, in_size, out_size): + super(unetUp, self).__init__() + self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1) + self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1) + self.up = nn.UpsamplingBilinear2d(scale_factor = 2) + self.relu = nn.ReLU(inplace = True) + + def forward(self, inputs1, inputs2): + outputs = torch.cat([inputs1, self.up(inputs2)], 1) + outputs = self.conv1(outputs) + outputs = self.relu(outputs) + outputs = self.conv2(outputs) + outputs = self.relu(outputs) + return outputs + +class Unet(nn.Module): + def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'): + super(Unet, self).__init__() + if backbone == 'vgg': + self.vgg = VGG16(pretrained = pretrained) + in_filters = [192, 384, 768, 1024] + elif backbone == "resnet50": + self.resnet = resnet50(pretrained = pretrained) + in_filters = [192, 512, 1024, 3072] + else: + raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone)) + out_filters = [64, 128, 256, 512] + + # upsampling + # 64,64,512 + self.up_concat4 = unetUp(in_filters[3], out_filters[3]) + # 128,128,256 + 
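+        # (the in_filters values are concatenation channel counts: skip feature +
+        #  upsampled decoder feature. For vgg: 192 = 64 + 128, 384 = 128 + 256,
+        #  768 = 256 + 512, 1024 = 512 + 512; for resnet50 the deepest stage is
+        #  3072 = 1024 + 2048.)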
self.up_concat3 = unetUp(in_filters[2], out_filters[2]) + # 256,256,128 + self.up_concat2 = unetUp(in_filters[1], out_filters[1]) + # 512,512,64 + self.up_concat1 = unetUp(in_filters[0], out_filters[0]) + + if backbone == 'resnet50': + self.up_conv = nn.Sequential( + nn.UpsamplingBilinear2d(scale_factor = 2), + nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), + nn.ReLU(), + nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), + nn.ReLU(), + ) + else: + self.up_conv = None + + self.final = nn.Conv2d(out_filters[0], num_classes, 1) + + self.backbone = backbone + + def forward(self, inputs): + if self.backbone == "vgg": + [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs) + elif self.backbone == "resnet50": + [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs) + + up4 = self.up_concat4(feat4, feat5) + up3 = self.up_concat3(feat3, up4) + up2 = self.up_concat2(feat2, up3) + up1 = self.up_concat1(feat1, up2) + + if self.up_conv != None: + up1 = self.up_conv(up1) + + final = self.final(up1) + + return final + + def freeze_backbone(self): + if self.backbone == "vgg": + for param in self.vgg.parameters(): + param.requires_grad = False + elif self.backbone == "resnet50": + for param in self.resnet.parameters(): + param.requires_grad = False + + def unfreeze_backbone(self): + if self.backbone == "vgg": + for param in self.vgg.parameters(): + param.requires_grad = True + elif self.backbone == "resnet50": + for param in self.resnet.parameters(): + param.requires_grad = True diff --git a/nets/unet_training.py b/nets/unet_training.py new file mode 100644 index 0000000..7210e56 --- /dev/null +++ b/nets/unet_training.py @@ -0,0 +1,113 @@ +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def CE_Loss(inputs, target, cls_weights, num_classes=21): + n, c, h, w = inputs.size() + nt, ht, wt = target.size() + if h != ht and w != wt: + inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) + + temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) + temp_target = target.view(-1) + + CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target) + return CE_loss + +def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2): + n, c, h, w = inputs.size() + nt, ht, wt = target.size() + if h != ht and w != wt: + inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) + + temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) + temp_target = target.view(-1) + + logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target) + pt = torch.exp(logpt) + if alpha is not None: + logpt *= alpha + loss = -((1 - pt) ** gamma) * logpt + loss = loss.mean() + return loss + +def Dice_loss(inputs, target, beta=1, smooth = 1e-5): + n, c, h, w = inputs.size() + nt, ht, wt, ct = target.size() + if h != ht and w != wt: + inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) + + temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1) + temp_target = target.view(n, -1, ct) + + #--------------------------------------------# + # 计算dice loss + #--------------------------------------------# + tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) + fp = torch.sum(temp_inputs , axis=[0,1]) - tp + fn = 
torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp + + score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) + dice_loss = 1 - torch.mean(score) + return dice_loss + +def weights_init(net, init_type='normal', init_gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and classname.find('Conv') != -1: + if init_type == 'normal': + torch.nn.init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + elif classname.find('BatchNorm2d') != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + torch.nn.init.constant_(m.bias.data, 0.0) + print('initialize network with %s type' % init_type) + net.apply(init_func) + +def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): + def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): + if iters <= warmup_total_iters: + # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start + lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start + elif iters >= total_iters - no_aug_iter: + lr = min_lr + else: + lr = min_lr + 0.5 * (lr - min_lr) * ( + 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) + ) + return lr + + def step_lr(lr, decay_rate, step_size, iters): + if step_size < 1: + raise ValueError("step_size must above 1.") + n = iters // step_size + out_lr = lr * decay_rate ** n + return out_lr + + if lr_decay_type == "cos": + warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) + warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) + no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) + func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) + else: + decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) + step_size = total_iters / step_num + func = partial(step_lr, lr, decay_rate, step_size) + + return func + +def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): + lr = lr_scheduler_func(epoch) + for param_group in optimizer.param_groups: + param_group['lr'] = lr diff --git a/nets/vgg.py b/nets/vgg.py new file mode 100644 index 0000000..f44ffd8 --- /dev/null +++ b/nets/vgg.py @@ -0,0 +1,75 @@ +import torch.nn as nn +from torch.hub import load_state_dict_from_url + + +class VGG(nn.Module): + def __init__(self, features, num_classes=1000): + super(VGG, self).__init__() + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + self._initialize_weights() + + def forward(self, x): + # x = self.features(x) + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # x = self.classifier(x) + feat1 = self.features[ :4 ](x) + feat2 = self.features[4 :9 ](feat1) + feat3 = self.features[9 :16](feat2) + feat4 = self.features[16:23](feat3) + feat5 = self.features[23:-1](feat4) + return [feat1, 
feat2, feat3, feat4, feat5] + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def make_layers(cfg, batch_norm=False, in_channels = 3): + layers = [] + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) +# 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256 +# 64,64,512 -> 32,32,512 -> 32,32,512 +cfgs = { + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] +} + + +def VGG16(pretrained, in_channels = 3, **kwargs): + model = VGG(make_layers(cfgs["D"], batch_norm = False, in_channels = in_channels), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir="./model_data") + model.load_state_dict(state_dict) + + del model.avgpool + del model.classifier + return model diff --git a/predicdt.py b/predicdt.py new file mode 100644 index 0000000..df5918f --- /dev/null +++ b/predicdt.py @@ -0,0 +1,169 @@ +import itertools +import torch +import numpy as np +from torchvision import transforms +from PIL import Image, ImageOps +import cv2 +from unet import Unet +from nets.U_ConvAutoencoder import U_ConvAutoencoder +from typing import Tuple, List + + +# 定义卷积自编码器 + +class PreCA: + device: torch.device = None + model: U_ConvAutoencoder = None + transform: transforms.Compose = None + + @classmethod + def initialize_model(cls, u_ca_path: str) -> None: + # 实例化模型并加载权重 + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model = U_ConvAutoencoder().to(cls.device) + cls.model.load_state_dict(torch.load(u_ca_path, map_location=cls.device)) + cls.model.eval() + # 图像预处理 + cls.transform = transforms.Compose([ + transforms.Resize((1728, 3392)), + transforms.ToTensor() + ]) + + @classmethod + def load_image(cls, image: Image.Image) -> torch.Tensor: + image = image.convert("L") + image = cls.transform(image).unsqueeze(0) # 添加batch维度 + return image.to(cls.device) + + @staticmethod + def ca_smooth(image: Image.Image) -> Image.Image: + image_cv2 = np.array(image) + # 对图像进行闭运算 + closed_image = cv2.morphologyEx(image_cv2, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))) + # Step 1: 使用高斯模糊来平滑图像边缘 + blurred = cv2.GaussianBlur(closed_image, (1, 1), 0) + th = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)[1] + + eroded_image_pil = Image.fromarray(th) + return eroded_image_pil + + @classmethod + def infer(cls, image: Image.Image) -> Image.Image: + image = cls.load_image(image) + with torch.no_grad(): + output = cls.model(image) + output = output.squeeze(0).cpu() # 去除batch维度并移动到CPU + output_image = transforms.ToPILImage()(output) + output_image = output_image.resize((3384, 1710), Image.NEAREST) + return output_image + + +class PreUnet: + + @staticmethod + def blend_images_with_colorize(image1: Image.Image, image2: Image.Image, alpha: float = 0.5) -> None: + red_image1 
= ImageOps.colorize(image1.convert("L"), (0, 0, 0), (255, 0, 0))
+        green_image2 = ImageOps.colorize(image2.convert("L"), (0, 0, 0), (0, 255, 0))
+        blended_image = Image.blend(red_image1, green_image2, alpha)
+        blended_image.show()
+
+    @staticmethod
+    def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
+        pred_gray = pred_image.convert('L')
+        true_gray = true_image.convert('L')
+
+        pred_binary = pred_gray.point(lambda x: 0 if x < threshold else 255)
+        true_binary = true_gray.point(lambda x: 0 if x < threshold else 255)
+
+        pred_array = np.array(pred_binary)
+        true_array = np.array(true_binary)
+
+        # Calculate TP, FP, FN
+        TP = np.sum((pred_array == 255) & (true_array == 255))
+        FP = np.sum((pred_array == 255) & (true_array == 0))
+        FN = np.sum((pred_array == 0) & (true_array == 255))
+
+        return TP, FP, FN
+
+    @staticmethod
+    def apply_mask(original_image, mask_image):
+        # Convert the original image and the mask to RGB
+        original_image = original_image.convert("RGB")
+        mask_image = mask_image.convert("RGB")
+
+        # Pixel access objects
+        original_pixels = original_image.load()
+        mask_pixels = mask_image.load()
+
+        # Image dimensions
+        width, height = original_image.size
+
+        # Walk over every pixel
+        for y in range(height):
+            for x in range(width):
+                # Wherever the mask pixel is white (255, 255, 255) ...
+                if mask_pixels[x, y] == (255, 255, 255):
+                    # ... paint the corresponding original pixel green (0, 255, 0)
+                    original_pixels[x, y] = (0, 255, 0)
+
+        # Return the composited image
+        return original_image
+
+    @classmethod
+    def main(cls, ca_path: str) -> None:
+        PreCA.initialize_model(ca_path)
+        import os
+        from tqdm import tqdm
+        ious: List[float] = []
+        img_names: List[str] = os.listdir(dir_origin_path)
+        for img_name in tqdm(img_names):
+            if img_name.lower().endswith(
+                    ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+                image_path = os.path.join(dir_origin_path, img_name)
+                image = Image.open(image_path)
+
+                r_image = unet.detect_image(image)
+                r_image = PreCA.infer(r_image)  # refine with the autoencoder
+                r_image = PreCA.ca_smooth(r_image)
+
+                if is_save:
+                    if not os.path.exists(dir_save_path):
+                        os.makedirs(dir_save_path)
+                    r_image.save(os.path.join(dir_save_path, img_name.split('.')[0] + '_bin.png'))
+                if is_get_iou:
+                    label_path = os.path.join(dir_label_path, img_name.split('.')[0] + '_bin.png')
+                    label = Image.open(label_path)
+                    TP, FP, FN = cls.calculate_metrics(r_image, label)
+                    iou = TP / (TP + FP + FN)
+                    ious.append(iou)
+                    print(f"current IoU: {iou}")
+
+                # cls.blend_images_with_colorize(label, r_image)
+
+        if is_get_iou: print(f"mean IoU: {np.mean(ious)}")
+
+
+if __name__ == "__main__":
+    name_classes: List[str] = ["background", "lane"]
+    dir_origin_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\img"
+    # Whether to compute IoU; if True, dir_label_path (the label directory) must be set
+    is_get_iou: bool = True
+    dir_label_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\Label"
+    # Whether to save the predicted images; if True, dir_save_path (the output directory) must be set
+    is_save: bool = False
+    dir_save_path: str = "img_out/"
+    # Path to the weights of the multi-scale supervised autoencoder
+    u_ca_path: str = 'weights/best_conv_autoencoder1.pth'
+    _defaults: dict = {
+        "model_path": 'model_data/best80.pth',  # U-Net weight file
+        "num_classes": 2,  # number of classes, background included
+        "backbone": "vgg",
+        "input_shape": [1696, 864],  # input size
+        "mix_type": 1,
+        "cuda": True,  # whether to enable CUDA acceleration
+    }
+    unet: Unet = Unet(_defaults)
+    PreUnet.main(u_ca_path)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7da77c1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+matplotlib==3.1.2
+numpy==1.21.6
+opencv_python==4.1.2.30
+Pillow==8.2.0
+scipy==1.2.1 +streamlit==1.23.1 +thop==0.1.1.post2209072238 +torch +torchsummary +torchvision +tqdm==4.60.0 diff --git a/summary.py b/summary.py new file mode 100644 index 0000000..1ab0f91 --- /dev/null +++ b/summary.py @@ -0,0 +1,30 @@ +#--------------------------------------------# +# 该部分代码用于看网络结构 +#--------------------------------------------# +import torch +from thop import clever_format, profile +from torchsummary import summary + +from nets.unet import Unet + +if __name__ == "__main__": + input_shape = [1024, 1024] + num_classes = 2 + backbone = 'resnet50' + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = Unet(num_classes = num_classes, backbone = backbone).to(device) + summary(model, (3, input_shape[0], input_shape[1])) + + dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) + flops, params = profile(model.to(device), (dummy_input, ), verbose=False) + #--------------------------------------------------------# + # flops * 2是因为profile没有将卷积作为两个operations + # 有些论文将卷积算乘法、加法两个operations。此时乘2 + # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 + # 本代码选择乘2,参考YOLOX。 + #--------------------------------------------------------# + flops = flops * 2 + flops, params = clever_format([flops, params], "%.3f") + print('Total GFLOPS: %s' % (flops)) + print('Total params: %s' % (params)) diff --git a/train.py b/train.py new file mode 100644 index 0000000..33696f6 --- /dev/null +++ b/train.py @@ -0,0 +1,255 @@ +import datetime +import os +from functools import partial + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim as optim +from torch.utils.data import DataLoader + +from nets.unet import Unet +from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init +from utils.callbacks import EvalCallback, LossHistory +from utils.dataloader import UnetDataset, unet_dataset_collate +from utils.utils import (download_weights, seed_everything, show_config, + worker_init_fn) +from utils.utils_fit import fit_one_epoch + + +def train(): + if distributed: + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + device = torch.device("cuda", local_rank) + if local_rank == 0: + print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") + print("Gpu Device Count : ", ngpus_per_node) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + local_rank = 0 + rank = 0 + + if pretrained: + if distributed: + if local_rank == 0: + download_weights(backbone) + dist.barrier() + else: + download_weights(backbone) + + model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train() + if not pretrained: + weights_init(model) + if model_path != '': + + if local_rank == 0: + print('Load weights {}.'.format(model_path)) + + model_dict = model.state_dict() + pretrained_dict = torch.load(model_path, map_location=device) + load_key, no_load_key, temp_dict = [], [], {} + for k, v in pretrained_dict.items(): + if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): + temp_dict[k] = v + load_key.append(k) + else: + no_load_key.append(k) + model_dict.update(temp_dict) + model.load_state_dict(model_dict) + + if local_rank == 0: + print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) + print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) + + if local_rank == 0: + time_str = 
datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S') + log_dir = os.path.join(save_dir, "loss_" + str(time_str)) + loss_history = LossHistory(log_dir, model, input_shape=input_shape) + else: + loss_history = None + + if fp16: + from torch.cuda.amp import GradScaler as GradScaler + + scaler = GradScaler() + else: + scaler = None + + model_train = model.train() + + if sync_bn and ngpus_per_node > 1 and distributed: + model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) + elif sync_bn: + print("Sync_bn is not support in one gpu or not distributed.") + + if Cuda: + if distributed: + + model_train = model_train.cuda(local_rank) + model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], + find_unused_parameters=True) + else: + model_train = torch.nn.DataParallel(model) + cudnn.benchmark = True + model_train = model_train.cuda() + + with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/train.txt"), "r") as f: + train_lines = f.readlines() + with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"), "r") as f: + val_lines = f.readlines() + num_train = len(train_lines) + num_val = len(val_lines) + + if True: + UnFreeze_flag = False + + if Freeze_Train: + model.freeze_backbone() + + batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size + + nbs = 16 + lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + + optimizer = { + 'adam': optim.Adam(model.parameters(), Init_lr_fit, betas=(momentum, 0.999), weight_decay=weight_decay), + 'sgd': optim.SGD(model.parameters(), Init_lr_fit, momentum=momentum, nesterov=True, + weight_decay=weight_decay) + }[optimizer_type] + + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path) + val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path) + + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, ) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, ) + batch_size = batch_size // ngpus_per_node + shuffle = False + else: + train_sampler = None + val_sampler = None + shuffle = True + + gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, + pin_memory=True, + drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler, + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) + gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, + pin_memory=True, + drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler, + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) + + if local_rank == 0: + eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \ + eval_flag=eval_flag, period=eval_period) + else: + eval_callback = None + + for epoch in range(Init_Epoch, UnFreeze_Epoch): + + if epoch >= Freeze_Epoch and not 
UnFreeze_flag and Freeze_Train: + batch_size = Unfreeze_batch_size + + nbs = 16 + lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1 + lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4 + Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) + Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) + + lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch) + + model.unfreeze_backbone() + + epoch_step = num_train // batch_size + epoch_step_val = num_val // batch_size + + if epoch_step == 0 or epoch_step_val == 0: + raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") + + if distributed: + batch_size = batch_size // ngpus_per_node + + gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, + pin_memory=True, + drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler, + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) + gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, + pin_memory=True, + drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler, + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) + + UnFreeze_flag = True + + if distributed: + train_sampler.set_epoch(epoch) + + set_optimizer_lr(optimizer, lr_scheduler_func, epoch) + + fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, + epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss, + cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank) + + if distributed: + dist.barrier() + + if local_rank == 0: + loss_history.writer.close() + + +if __name__ == "__main__": + Cuda = True + seed = 11 + # 是否开启多卡 + distributed = False + sync_bn = False + # 是否使用半精度 + fp16 = True + num_classes = 2 + # 设置骨干网络 + backbone = "vgg" + pretrained = False + model_path = "model_data/8414_8376.pth" + input_shape = [1696, 864] + # 冻结训练 + Init_Epoch = 0 + Freeze_Epoch = 10 + Freeze_batch_size = 1 + # 解冻训练 + UnFreeze_Epoch = 70 + Unfreeze_batch_size = 1 + Freeze_Train = True + # 学习率设置 + Init_lr = 1e-4 + Min_lr = Init_lr * 0.01 + # 优化器 + optimizer_type = "adam" + momentum = 0.9 + weight_decay = 0 + lr_decay_type = 'cos' + save_period = 5 + save_dir = 'logs' + eval_flag = True + eval_period = 5 + # 数据集设置 + VOCdevkit_path = 'VOCdevkit' + dice_loss = False + focal_loss = False + cls_weights = np.ones([num_classes], np.float32) + num_workers = 0 + seed_everything(seed) + ngpus_per_node = torch.cuda.device_count() + train() diff --git a/unet.py b/unet.py new file mode 100644 index 0000000..7920231 --- /dev/null +++ b/unet.py @@ -0,0 +1,131 @@ +import colorsys +import copy + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn + +from nets.unet import Unet as unet +from utils.utils import cvtColor, preprocess_input, resize_image + + +class Unet(object): + _defaults = { + "model_path": None, + "num_classes": 2, + "backbone": "vgg", + "input_shape": [1696, 864], + "mix_type": 1, + "cuda": True, + } + + def __init__(self, _defaults,**kwargs): + self._defaults = _defaults + self.__dict__.update(self._defaults) + for name, value in kwargs.items(): + setattr(self, name, value) + + if self.num_classes <= 2: + self.colors = [(0, 0, 0), (255,255,255)] + else: + hsv_tuples = [(x / self.num_classes, 1., 1.) 
for x in range(self.num_classes)] + self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) + self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors)) + + self.generate() + + + def generate(self, onnx=False): + self.net = unet(num_classes=self.num_classes, backbone=self.backbone) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.net.load_state_dict(torch.load(self.model_path, map_location=device)) + self.net = self.net.eval() + print('{} model, and classes loaded.'.format(self.model_path)) + if not onnx: + if self.cuda: + self.net = nn.DataParallel(self.net) + self.net = self.net.cuda() + + def detect_image(self, image, count=False, name_classes=None): + + image = cvtColor(image) + + old_img = copy.deepcopy(image) + orininal_h = np.array(image).shape[0] + orininal_w = np.array(image).shape[1] + + image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0])) + + image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) + + with torch.no_grad(): + images = torch.from_numpy(image_data) + if self.cuda: + images = images.cuda() + + pr = self.net(images)[0] + + pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy() + + pr = pr[int((self.input_shape[0] - nh) // 2): int((self.input_shape[0] - nh) // 2 + nh), \ + int((self.input_shape[1] - nw) // 2): int((self.input_shape[1] - nw) // 2 + nw)] + + pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR) + + pr = pr.argmax(axis=-1) + + if count: + classes_nums = np.zeros([self.num_classes]) + total_points_num = orininal_h * orininal_w + print('-' * 63) + print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio")) + print('-' * 63) + for i in range(self.num_classes): + num = np.sum(pr == i) + ratio = num / total_points_num * 100 + if num > 0: + print("|%25s | %15s | %14.2f%%|" % (str(name_classes[i]), str(num), ratio)) + print('-' * 63) + classes_nums[i] = num + print("classes_nums:", classes_nums) + + if self.mix_type == 0: + # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3)) + # for c in range(self.num_classes): + # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8') + # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8') + # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8') + seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1]) + # ------------------------------------------------# + # 将新图片转换成Image的形式 + # ------------------------------------------------# + image = Image.fromarray(np.uint8(seg_img)) + # ------------------------------------------------# + # 将新图与原图及进行混合 + # ------------------------------------------------# + image = Image.blend(old_img, image, 0.7) + + elif self.mix_type == 1: + # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3)) + # for c in range(self.num_classes): + # seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8') + # seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8') + # seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8') + seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1]) + # ------------------------------------------------# + # 将新图片转换成Image的形式 + # ------------------------------------------------# + image = Image.fromarray(np.uint8(seg_img)) + + elif self.mix_type == 2: + seg_img = 
(np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8') + # ------------------------------------------------# + # 将新图片转换成Image的形式 + # ------------------------------------------------# + image = Image.fromarray(np.uint8(seg_img)) + + return image \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 0000000..73e7b31 --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,210 @@ +import os + +import matplotlib +import torch +import torch.nn.functional as F + +matplotlib.use('Agg') +from matplotlib import pyplot as plt +import scipy.signal + +import cv2 +import shutil +import numpy as np + +from PIL import Image +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +from .utils import cvtColor, preprocess_input, resize_image +from .utils_metrics import compute_mIoU + + +class LossHistory(): + def __init__(self, log_dir, model, input_shape, val_loss_flag=True): + self.log_dir = log_dir + self.val_loss_flag = val_loss_flag + + self.losses = [] + if self.val_loss_flag: + self.val_loss = [] + + os.makedirs(self.log_dir) + self.writer = SummaryWriter(self.log_dir) + try: + dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) + self.writer.add_graph(model, dummy_input) + except: + pass + + def append_loss(self, epoch, loss, val_loss = None): + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + self.losses.append(loss) + if self.val_loss_flag: + self.val_loss.append(val_loss) + + with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: + f.write(str(loss)) + f.write("\n") + if self.val_loss_flag: + with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: + f.write(str(val_loss)) + f.write("\n") + + self.writer.add_scalar('loss', loss, epoch) + if self.val_loss_flag: + self.writer.add_scalar('val_loss', val_loss, epoch) + + self.loss_plot() + + def loss_plot(self): + iters = range(len(self.losses)) + + plt.figure() + plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') + if self.val_loss_flag: + plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') + + try: + if len(self.losses) < 25: + num = 5 + else: + num = 15 + + plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') + if self.val_loss_flag: + plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') + except: + pass + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) + + plt.cla() + plt.close("all") + +class EvalCallback(): + def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \ + miou_out_path=".temp_miou_out", eval_flag=True, period=1): + super(EvalCallback, self).__init__() + + self.net = net + self.input_shape = input_shape + self.num_classes = num_classes + self.image_ids = image_ids + self.dataset_path = dataset_path + self.log_dir = log_dir + self.cuda = cuda + self.miou_out_path = miou_out_path + self.eval_flag = eval_flag + self.period = period + + self.image_ids = [image_id.split()[0] for image_id in image_ids] + self.mious = [0] + self.epoches = [0] + if self.eval_flag: + with 
open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: + f.write(str(0)) + f.write("\n") + + def get_miou_png(self, image): + #---------------------------------------------------------# + # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 + # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB + #---------------------------------------------------------# + image = cvtColor(image) + orininal_h = np.array(image).shape[0] + orininal_w = np.array(image).shape[1] + #---------------------------------------------------------# + # 给图像增加灰条,实现不失真的resize + # 也可以直接resize进行识别 + #---------------------------------------------------------# + image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0])) + #---------------------------------------------------------# + # 添加上batch_size维度 + #---------------------------------------------------------# + image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) + + with torch.no_grad(): + images = torch.from_numpy(image_data) + if self.cuda: + images = images.cuda() + + #---------------------------------------------------# + # 图片传入网络进行预测 + #---------------------------------------------------# + pr = self.net(images)[0] + #---------------------------------------------------# + # 取出每一个像素点的种类 + #---------------------------------------------------# + pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy() + #--------------------------------------# + # 将灰条部分截取掉 + #--------------------------------------# + pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \ + int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)] + #---------------------------------------------------# + # 进行图片的resize + #---------------------------------------------------# + pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR) + #---------------------------------------------------# + # 取出每一个像素点的种类 + #---------------------------------------------------# + pr = pr.argmax(axis=-1) + + image = Image.fromarray(np.uint8(pr)) + return image + + def on_epoch_end(self, epoch, model_eval): + if epoch % self.period == 0 and self.eval_flag: + self.net = model_eval + gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/") + pred_dir = os.path.join(self.miou_out_path, 'detection-results') + if not os.path.exists(self.miou_out_path): + os.makedirs(self.miou_out_path) + if not os.path.exists(pred_dir): + os.makedirs(pred_dir) + print("Get miou.") + for image_id in tqdm(self.image_ids): + #-------------------------------# + # 从文件中读取图像 + #-------------------------------# + image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg") + image = Image.open(image_path) + #------------------------------# + # 获得预测txt + #------------------------------# + image = self.get_miou_png(image) + image.save(os.path.join(pred_dir, image_id + ".png")) + + print("Calculate miou.") + _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None) # 执行计算mIoU的函数 + temp_miou = np.nanmean(IoUs) * 100 + + self.mious.append(temp_miou) + self.epoches.append(epoch) + + with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: + f.write(str(temp_miou)) + f.write("\n") + + plt.figure() + plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou') + + plt.grid(True) + plt.xlabel('Epoch') + plt.ylabel('Miou') + plt.title('A Miou Curve') + plt.legend(loc="upper right") + + plt.savefig(os.path.join(self.log_dir, "epoch_miou.png")) + 
plt.cla() + plt.close("all") + + print("Get miou done.") + shutil.rmtree(self.miou_out_path) diff --git a/utils/dataloader.py b/utils/dataloader.py new file mode 100644 index 0000000..9294056 --- /dev/null +++ b/utils/dataloader.py @@ -0,0 +1,149 @@ +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data.dataset import Dataset + +from utils.utils import cvtColor, preprocess_input + + +class UnetDataset(Dataset): + def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path): + super(UnetDataset, self).__init__() + self.annotation_lines = annotation_lines + self.length = len(annotation_lines) + self.input_shape = input_shape + self.num_classes = num_classes + self.train = train + self.dataset_path = dataset_path + + def __len__(self): + return self.length + + def __getitem__(self, index): + annotation_line = self.annotation_lines[index] + name = annotation_line.split()[0] + + #-------------------------------# + # 从文件中读取图像 + #-------------------------------# + jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg")) + png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png")) + #-------------------------------# + # 数据增强 + #-------------------------------# + jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) + + jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) + png = np.array(png) + png[png >= self.num_classes] = self.num_classes + #-------------------------------------------------------# + # 转化成one_hot的形式 + # 在这里需要+1是因为voc数据集有些标签具有白边部分 + # 我们需要将白边部分进行忽略,+1的目的是方便忽略。 + #-------------------------------------------------------# + seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])] + seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) + + return jpg, png, seg_labels + + def rand(self, a=0, b=1): + return np.random.rand() * (b - a) + a + + def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): + image = cvtColor(image) + label = Image.fromarray(np.array(label)) + #------------------------------# + # 获得图像的高宽与目标高宽 + #------------------------------# + iw, ih = image.size + h, w = input_shape + + if not random: + iw, ih = image.size + scale = min(w/iw, h/ih) + nw = int(iw*scale) + nh = int(ih*scale) + + image = image.resize((nw,nh), Image.BICUBIC) + new_image = Image.new('RGB', [w, h], (128,128,128)) + new_image.paste(image, ((w-nw)//2, (h-nh)//2)) + + label = label.resize((nw,nh), Image.NEAREST) + new_label = Image.new('L', [w, h], (0)) + new_label.paste(label, ((w-nw)//2, (h-nh)//2)) + return new_image, new_label + + #------------------------------------------# + # 对图像进行缩放并且进行长和宽的扭曲 + #------------------------------------------# + new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) + scale = self.rand(0.25, 2) + if new_ar < 1: + nh = int(scale*h) + nw = int(nh*new_ar) + else: + nw = int(scale*w) + nh = int(nw/new_ar) + image = image.resize((nw,nh), Image.BICUBIC) + label = label.resize((nw,nh), Image.NEAREST) + + #------------------------------------------# + # 翻转图像 + #------------------------------------------# + flip = self.rand()<.5 + if flip: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + label = label.transpose(Image.FLIP_LEFT_RIGHT) + + #------------------------------------------# + # 将图像多余的部分加上灰条 + 
+        #------------------------------------------#
+        dx = int(self.rand(0, w-nw))
+        dy = int(self.rand(0, h-nh))
+        new_image = Image.new('RGB', (w, h), (128, 128, 128))
+        new_label = Image.new('L', (w, h), (0))
+        new_image.paste(image, (dx, dy))
+        new_label.paste(label, (dx, dy))
+        image = new_image
+        label = new_label
+
+        image_data = np.array(image, np.uint8)
+        #---------------------------------#
+        #   Colour-space augmentation:
+        #   compute the gain for each HSV channel
+        #---------------------------------#
+        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+        #---------------------------------#
+        #   Convert the image to HSV
+        #---------------------------------#
+        hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+        dtype = image_data.dtype
+        #---------------------------------#
+        #   Apply the transform
+        #---------------------------------#
+        x = np.arange(0, 256, dtype=r.dtype)
+        lut_hue = ((x * r[0]) % 180).astype(dtype)
+        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+        return image_data, label
+
+# Used as collate_fn in the DataLoader
+def unet_dataset_collate(batch):
+    images = []
+    pngs = []
+    seg_labels = []
+    for img, png, labels in batch:
+        images.append(img)
+        pngs.append(png)
+        seg_labels.append(labels)
+    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
+    pngs = torch.from_numpy(np.array(pngs)).long()
+    seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
+    return images, pngs, seg_labels
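For reference, a sketch of how this dataset and collate function would typically be wired together. The annotation-file path and the shapes are assumptions based on the usual VOC layout, not taken from the training script:

```python
# Hypothetical usage sketch for UnetDataset + unet_dataset_collate.
from torch.utils.data import DataLoader

from utils.dataloader import UnetDataset, unet_dataset_collate

# Assumed VOC-style split file; adjust to the actual dataset layout.
with open("VOCdevkit/VOC2007/ImageSets/Segmentation/train.txt") as f:
    train_lines = f.readlines()

train_dataset = UnetDataset(train_lines, input_shape=[512, 512],
                            num_classes=2, train=True, dataset_path="VOCdevkit")
gen = DataLoader(train_dataset, batch_size=4, shuffle=True,
                 num_workers=2, drop_last=True, collate_fn=unet_dataset_collate)

imgs, pngs, seg_labels = next(iter(gen))
# e.g. [4, 3, 512, 512], [4, 512, 512], [4, 512, 512, 3] for num_classes=2
print(imgs.shape, pngs.shape, seg_labels.shape)
```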
diff --git a/utils/dataloader_medical.py b/utils/dataloader_medical.py
new file mode 100644
index 0000000..50a3dea
--- /dev/null
+++ b/utils/dataloader_medical.py
@@ -0,0 +1,150 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+from utils.utils import cvtColor, preprocess_input
+
+
+class UnetDataset(Dataset):
+    def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
+        super(UnetDataset, self).__init__()
+        self.annotation_lines = annotation_lines
+        self.length = len(annotation_lines)
+        self.input_shape = input_shape
+        self.num_classes = num_classes
+        self.train = train
+        self.dataset_path = dataset_path
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        annotation_line = self.annotation_lines[index]
+        name = annotation_line.split()[0]
+
+        #-------------------------------#
+        #   Read the images from file
+        #-------------------------------#
+        jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "Images"), name + ".png"))
+        png = Image.open(os.path.join(os.path.join(self.dataset_path, "Labels"), name + ".png"))
+        #-------------------------------#
+        #   Data augmentation
+        #-------------------------------#
+        jpg, png = self.get_random_data(jpg, png, self.input_shape, random=self.train)
+
+        jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2, 0, 1])
+        png = np.array(png)
+        #-------------------------------------------------------#
+        #   The labels here are handled differently from plain VOC:
+        #   pixels at or below 127.5 are treated as target pixels.
+        #-------------------------------------------------------#
+        modify_png = np.zeros_like(png)
+        modify_png[png <= 127.5] = 1
+        seg_labels = modify_png
+        seg_labels = np.eye(self.num_classes + 1)[seg_labels.reshape([-1])]
+        seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
+
+        return jpg, modify_png, seg_labels
+
+    def rand(self, a=0, b=1):
+        return np.random.rand() * (b - a) + a
+
+    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
+        image = cvtColor(image)
+        label = Image.fromarray(np.array(label))
+        #------------------------------#
+        #   Get the image size and the target size
+        #------------------------------#
+        iw, ih = image.size
+        h, w = input_shape
+
+        if not random:
+            iw, ih = image.size
+            scale = min(w/iw, h/ih)
+            nw = int(iw*scale)
+            nh = int(ih*scale)
+
+            image = image.resize((nw, nh), Image.BICUBIC)
+            new_image = Image.new('RGB', [w, h], (128, 128, 128))
+            new_image.paste(image, ((w-nw)//2, (h-nh)//2))
+
+            label = label.resize((nw, nh), Image.NEAREST)
+            new_label = Image.new('L', [w, h], (0))
+            new_label.paste(label, ((w-nw)//2, (h-nh)//2))
+            return new_image, new_label
+
+        #------------------------------------------#
+        #   Rescale the image and distort its aspect ratio
+        #------------------------------------------#
+        new_ar = iw/ih * self.rand(1-jitter, 1+jitter) / self.rand(1-jitter, 1+jitter)
+        scale = self.rand(0.25, 2)
+        if new_ar < 1:
+            nh = int(scale*h)
+            nw = int(nh*new_ar)
+        else:
+            nw = int(scale*w)
+            nh = int(nw/new_ar)
+        image = image.resize((nw, nh), Image.BICUBIC)
+        label = label.resize((nw, nh), Image.NEAREST)
+
+        #------------------------------------------#
+        #   Flip the image
+        #------------------------------------------#
+        flip = self.rand() < .5
+        if flip:
+            image = image.transpose(Image.FLIP_LEFT_RIGHT)
+            label = label.transpose(Image.FLIP_LEFT_RIGHT)
+
+        #------------------------------------------#
+        #   Pad the leftover area with grey bars
+        #------------------------------------------#
+        dx = int(self.rand(0, w-nw))
+        dy = int(self.rand(0, h-nh))
+        new_image = Image.new('RGB', (w, h), (128, 128, 128))
+        new_label = Image.new('L', (w, h), (0))
+        new_image.paste(image, (dx, dy))
+        new_label.paste(label, (dx, dy))
+        image = new_image
+        label = new_label
+
+        image_data = np.array(image, np.uint8)
+        #---------------------------------#
+        #   Colour-space augmentation:
+        #   compute the gain for each HSV channel
+        #---------------------------------#
+        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+        #---------------------------------#
+        #   Convert the image to HSV
+        #---------------------------------#
+        hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+        dtype = image_data.dtype
+        #---------------------------------#
+        #   Apply the transform
+        #---------------------------------#
+        x = np.arange(0, 256, dtype=r.dtype)
+        lut_hue = ((x * r[0]) % 180).astype(dtype)
+        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+        return image_data, label
+
+# Used as collate_fn in the DataLoader
+def unet_dataset_collate(batch):
+    images = []
+    pngs = []
+    seg_labels = []
+    for img, png, labels in batch:
+        images.append(img)
+        pngs.append(png)
+        seg_labels.append(labels)
+    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
+    pngs = torch.from_numpy(np.array(pngs)).long()
+    seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
+    return images, pngs, seg_labels
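A quick numpy illustration of the medical-label convention above, with invented values: pixels at or below 127.5 become foreground class 1, everything brighter stays background 0.

```python
import numpy as np

# Toy 2x3 grayscale label; values made up for illustration.
png = np.array([[0, 200, 100],
                [255, 50, 128]], dtype=np.uint8)

modify_png = np.zeros_like(png)
modify_png[png <= 127.5] = 1   # dark pixels -> target class
print(modify_png)
# [[1 0 1]
#  [0 1 0]]
```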
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..80a60cb
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,76 @@
+import random
+
+import numpy as np
+import torch
+from PIL import Image
+
+
+def cvtColor(image):
+    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+        return image
+    else:
+        image = image.convert('RGB')
+        return image
+
+def resize_image(image, size):
+    iw, ih = image.size
+    w, h = size
+
+    scale = min(w/iw, h/ih)
+    nw = int(iw*scale)
+    nh = int(ih*scale)
+
+    image = image.resize((nw, nh), Image.BICUBIC)
+    new_image = Image.new('RGB', size, (128, 128, 128))
+    new_image.paste(image, ((w-nw)//2, (h-nh)//2))
+
+    return new_image, nw, nh
+
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+
+def seed_everything(seed=11):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def worker_init_fn(worker_id, rank, seed):
+    worker_seed = rank + seed
+    random.seed(worker_seed)
+    np.random.seed(worker_seed)
+    torch.manual_seed(worker_seed)
+
+def preprocess_input(image):
+    image /= 255.0
+    return image
+
+def show_config(**kwargs):
+    print('Configurations:')
+    print('-' * 70)
+    print('|%25s | %40s|' % ('keys', 'values'))
+    print('-' * 70)
+    for key, value in kwargs.items():
+        print('|%25s | %40s|' % (str(key), str(value)))
+    print('-' * 70)
+
+def download_weights(backbone, model_dir="./model_data"):
+    import os
+    from torch.hub import load_state_dict_from_url
+
+    download_urls = {
+        'vgg'       : 'https://download.pytorch.org/models/vgg16-397923af.pth',
+        'resnet50'  : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth'
+    }
+    url = download_urls[backbone]
+
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    load_state_dict_from_url(url, model_dir)
\ No newline at end of file
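seed_everything pins the Python, NumPy and PyTorch RNGs (and makes cuDNN deterministic), while worker_init_fn re-seeds each DataLoader worker from rank + seed. A self-contained usage sketch, assuming single-process training with rank 0 and a stand-in dataset:

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.utils import seed_everything, worker_init_fn

seed = 11
seed_everything(seed)

# Stand-in dataset; in train.py this would be the UnetDataset.
dataset = TensorDataset(torch.randn(16, 3, 64, 64))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2,
                    worker_init_fn=partial(worker_init_fn, rank=0, seed=seed))

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 3, 64, 64])
    break
```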
diff --git a/utils/utils_fit.py b/utils/utils_fit.py
new file mode 100644
index 0000000..2e9c885
--- /dev/null
+++ b/utils/utils_fit.py
@@ -0,0 +1,272 @@
+import os
+
+import torch
+from nets.unet_training import CE_Loss, Dice_loss, Focal_Loss
+from tqdm import tqdm
+
+from utils.utils import get_lr
+from utils.utils_metrics import f_score
+
+
+def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
+    total_loss = 0
+    total_f_score = 0
+
+    val_loss = 0
+    val_f_score = 0
+
+    if local_rank == 0:
+        print('Start Train')
+        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
+    model_train.train()
+    for iteration, batch in enumerate(gen):
+        if iteration >= epoch_step:
+            break
+        imgs, pngs, labels = batch
+        with torch.no_grad():
+            weights = torch.from_numpy(cls_weights)
+            if cuda:
+                imgs = imgs.cuda(local_rank)
+                pngs = pngs.cuda(local_rank)
+                labels = labels.cuda(local_rank)
+                weights = weights.cuda(local_rank)
+
+        optimizer.zero_grad()
+        if not fp16:
+            #----------------------#
+            #   Forward pass
+            #----------------------#
+            outputs = model_train(imgs)
+            #----------------------#
+            #   Loss computation
+            #----------------------#
+            if focal_loss:
+                loss = Focal_Loss(outputs, pngs, weights, num_classes=num_classes)
+            else:
+                loss = CE_Loss(outputs, pngs, weights, num_classes=num_classes)
+
+            if dice_loss:
+                main_dice = Dice_loss(outputs, labels)
+                loss = loss + main_dice
+
+            with torch.no_grad():
+                #-------------------------------#
+                #   Compute the f_score
+                #-------------------------------#
+                _f_score = f_score(outputs, labels)
+
+            loss.backward()
+            optimizer.step()
+        else:
+            from torch.cuda.amp import autocast
+            with autocast():
+                #----------------------#
+                #   Forward pass
+                #----------------------#
+                outputs = model_train(imgs)
+                #----------------------#
+                #   Loss computation
+                #----------------------#
+                if focal_loss:
+                    loss = Focal_Loss(outputs, pngs, weights, num_classes=num_classes)
+                else:
+                    loss = CE_Loss(outputs, pngs, weights, num_classes=num_classes)
+
+                if dice_loss:
+                    main_dice = Dice_loss(outputs, labels)
+                    loss = loss + main_dice
+
+                with torch.no_grad():
+                    #-------------------------------#
+                    #   Compute the f_score
+                    #-------------------------------#
+                    _f_score = f_score(outputs, labels)
+
+            #----------------------#
+            #   Backward pass
+            #----------------------#
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+        total_loss += loss.item()
+        total_f_score += _f_score.item()
+
+        if local_rank == 0:
+            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
+                                'f_score'   : total_f_score / (iteration + 1),
+                                'lr'        : get_lr(optimizer)})
+            pbar.update(1)
+
+    if local_rank == 0:
+        pbar.close()
+        print('Finish Train')
+        print('Start Validation')
+        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
+
+    model_train.eval()
+    for iteration, batch in enumerate(gen_val):
+        if iteration >= epoch_step_val:
+            break
+        imgs, pngs, labels = batch
+        with torch.no_grad():
+            weights = torch.from_numpy(cls_weights)
+            if cuda:
+                imgs = imgs.cuda(local_rank)
+                pngs = pngs.cuda(local_rank)
+                labels = labels.cuda(local_rank)
+                weights = weights.cuda(local_rank)
+
+            #----------------------#
+            #   Forward pass
+            #----------------------#
+            outputs = model_train(imgs)
+            #----------------------#
+            #   Loss computation
+            #----------------------#
+            if focal_loss:
+                loss = Focal_Loss(outputs, pngs, weights, num_classes=num_classes)
+            else:
+                loss = CE_Loss(outputs, pngs, weights, num_classes=num_classes)
+
+            if dice_loss:
+                main_dice = Dice_loss(outputs, labels)
+                loss = loss + main_dice
+            #-------------------------------#
+            #   Compute the f_score
+            #-------------------------------#
+            _f_score = f_score(outputs, labels)
+
+            val_loss += loss.item()
+            val_f_score += _f_score.item()
+
+        if local_rank == 0:
+            pbar.set_postfix(**{'val_loss'  : val_loss / (iteration + 1),
+                                'f_score'   : val_f_score / (iteration + 1),
+                                'lr'        : get_lr(optimizer)})
+            pbar.update(1)
+
+    if local_rank == 0:
+        pbar.close()
+        print('Finish Validation')
+        loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
+        eval_callback.on_epoch_end(epoch + 1, model_train)
+        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
+        print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
+
+        #-----------------------------------------------#
+        #   Save the weights
+        #-----------------------------------------------#
+        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+            torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % ((epoch + 1), total_loss / epoch_step, val_loss / epoch_step_val)))
+
+        if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
+            print('Save best model to best_epoch_weights.pth')
+            torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
+
+        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
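The fp16 branch above follows the standard torch.cuda.amp pattern. A minimal self-contained sketch of one scaled training step, using a toy model and random data purely for illustration:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Conv2d(3, 2, 3, padding=1).to(device)   # toy "network"
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=(device == "cuda"))

imgs = torch.randn(2, 3, 64, 64, device=device)
target = torch.randint(0, 2, (2, 64, 64), device=device)

optimizer.zero_grad()
with autocast(enabled=(device == "cuda")):
    outputs = model(imgs)                          # forward in fp16 where safe
    loss = nn.functional.cross_entropy(outputs, target)
scaler.scale(loss).backward()                      # scaled backward
scaler.step(optimizer)                             # unscale gradients + step
scaler.update()
```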
+def fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
+    total_loss = 0
+    total_f_score = 0
+
+    if local_rank == 0:
+        print('Start Train')
+        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
+    model_train.train()
+    for iteration, batch in enumerate(gen):
+        if iteration >= epoch_step:
+            break
+        imgs, pngs, labels = batch
+        with torch.no_grad():
+            weights = torch.from_numpy(cls_weights)
+            if cuda:
+                imgs = imgs.cuda(local_rank)
+                pngs = pngs.cuda(local_rank)
+                labels = labels.cuda(local_rank)
+                weights = weights.cuda(local_rank)
+
+        optimizer.zero_grad()
+        if not fp16:
+            #----------------------#
+            #   Forward pass
+            #----------------------#
+            outputs = model_train(imgs)
+            #----------------------#
+            #   Loss computation
+            #----------------------#
+            if focal_loss:
+                loss = Focal_Loss(outputs, pngs, weights, num_classes=num_classes)
+            else:
+                loss = CE_Loss(outputs, pngs, weights, num_classes=num_classes)
+
+            if dice_loss:
+                main_dice = Dice_loss(outputs, labels)
+                loss = loss + main_dice
+
+            with torch.no_grad():
+                #-------------------------------#
+                #   Compute the f_score
+                #-------------------------------#
+                _f_score = f_score(outputs, labels)
+
+            loss.backward()
+            optimizer.step()
+        else:
+            from torch.cuda.amp import autocast
+            with autocast():
+                #----------------------#
+                #   Forward pass
+                #----------------------#
+                outputs = model_train(imgs)
+                #----------------------#
+                #   Loss computation
+                #----------------------#
+                if focal_loss:
+                    loss = Focal_Loss(outputs, pngs, weights, num_classes=num_classes)
+                else:
+                    loss = CE_Loss(outputs, pngs, weights, num_classes=num_classes)
+
+                if dice_loss:
+                    main_dice = Dice_loss(outputs, labels)
+                    loss = loss + main_dice
+
+                with torch.no_grad():
+                    #-------------------------------#
+                    #   Compute the f_score
+                    #-------------------------------#
+                    _f_score = f_score(outputs, labels)
+
+            #----------------------#
+            #   Backward pass
+            #----------------------#
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+        total_loss += loss.item()
+        total_f_score += _f_score.item()
+
+        if local_rank == 0:
+            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
+                                'f_score'   : total_f_score / (iteration + 1),
+                                'lr'        : get_lr(optimizer)})
+            pbar.update(1)
+
+    if local_rank == 0:
+        pbar.close()
+        loss_history.append_loss(epoch + 1, total_loss / epoch_step)
+        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
+        print('Total Loss: %.3f' % (total_loss / epoch_step))
+
+        #-----------------------------------------------#
+        #   Save the weights
+        #-----------------------------------------------#
+        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+            torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f.pth' % ((epoch + 1), total_loss / epoch_step)))
+
+        if len(loss_history.losses) <= 1 or (total_loss / epoch_step) <= min(loss_history.losses):
+            print('Save best model to best_epoch_weights.pth')
+            torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
+
+        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
\ No newline at end of file
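The epoch loop always refreshes last_epoch_weights.pth and keeps the lowest-loss checkpoint as best_epoch_weights.pth. A sketch of reloading that checkpoint for inference; the Unet constructor arguments and the logs/ path are assumptions, since nets/unet.py defines its own signature:

```python
import torch

from nets.unet import Unet  # model class from nets/unet.py

# Hypothetical constructor arguments; match them to nets/unet.py.
model = Unet(num_classes=2)
state = torch.load("logs/best_epoch_weights.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()
```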
diff --git a/utils/utils_metrics.py b/utils/utils_metrics.py
new file mode 100644
index 0000000..e8ddd68
--- /dev/null
+++ b/utils/utils_metrics.py
@@ -0,0 +1,182 @@
+import csv
+import os
+from os.path import join
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+def f_score(inputs, target, beta=1, smooth=1e-5, threshold=0.5):
+    n, c, h, w = inputs.size()
+    nt, ht, wt, ct = target.size()
+    if h != ht or w != wt:
+        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
+
+    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1)
+    temp_target = target.view(n, -1, ct)
+
+    #--------------------------------------------#
+    #   Compute the dice coefficient
+    #--------------------------------------------#
+    temp_inputs = torch.gt(temp_inputs, threshold).float()
+    tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1])
+    fp = torch.sum(temp_inputs, axis=[0, 1]) - tp
+    fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp
+
+    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
+    score = torch.mean(score)
+    return score
+
+# Let the label be W wide and H high
+def fast_hist(a, b, n):
+    #--------------------------------------------------------------------------------#
+    #   a: the label flattened to a 1-D array, shape (H×W,)
+    #   b: the prediction flattened to a 1-D array, shape (H×W,)
+    #--------------------------------------------------------------------------------#
+    k = (a >= 0) & (a < n)
+    #--------------------------------------------------------------------------------#
+    #   np.bincount counts how often each of the n**2 values in [0, n**2) occurs,
+    #   returning shape (n, n) after the reshape; the diagonal holds the
+    #   correctly classified pixels.
+    #--------------------------------------------------------------------------------#
+    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
+
+def per_class_iu(hist):
+    return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
+
+def per_class_PA_Recall(hist):
+    return np.diag(hist) / np.maximum(hist.sum(1), 1)
+
+def per_class_Precision(hist):
+    return np.diag(hist) / np.maximum(hist.sum(0), 1)
+
+def per_Accuracy(hist):
+    return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
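To make the confusion-matrix bookkeeping concrete, a tiny worked example with invented 2-class arrays:

```python
import numpy as np

from utils.utils_metrics import fast_hist, per_class_iu

label = np.array([0, 0, 1, 1])   # ground truth, flattened
pred  = np.array([0, 1, 1, 1])   # prediction, flattened

hist = fast_hist(label, pred, n=2)
# hist = [[1, 1],
#         [0, 2]]  -> rows: ground truth, columns: prediction
print(per_class_iu(hist))
# class 0: 1 / (2 + 1 - 1) = 0.5, class 1: 2 / (2 + 3 - 2) = 0.666...
```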
+def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
+    print('Num classes', num_classes)
+    #-----------------------------------------#
+    #   Create an all-zero confusion matrix
+    #-----------------------------------------#
+    hist = np.zeros((num_classes, num_classes))
+
+    #------------------------------------------------#
+    #   Build the list of ground-truth label paths and
+    #   the list of prediction paths for easy reading
+    #------------------------------------------------#
+    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
+    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]
+
+    #------------------------------------------------#
+    #   Read every (prediction, label) pair
+    #------------------------------------------------#
+    for ind in range(len(gt_imgs)):
+        #------------------------------------------------#
+        #   Read one segmentation result as a numpy array
+        #------------------------------------------------#
+        pred = np.array(Image.open(pred_imgs[ind]))
+        #------------------------------------------------#
+        #   Read the matching label as a numpy array
+        #------------------------------------------------#
+        label = np.array(Image.open(gt_imgs[ind]))
+
+        # Skip this image if the prediction and the label differ in size
+        if len(label.flatten()) != len(pred.flatten()):
+            print(
+                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
+                    len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
+                    pred_imgs[ind]))
+            continue
+
+        #------------------------------------------------#
+        #   Compute the num_classes x num_classes hist
+        #   matrix for this image and accumulate it
+        #------------------------------------------------#
+        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
+        # Every 10 images, print the running class-averaged mIoU
+        if name_classes is not None and ind > 0 and ind % 10 == 0:
+            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
+                    ind,
+                    len(gt_imgs),
+                    100 * np.nanmean(per_class_iu(hist)),
+                    100 * np.nanmean(per_class_PA_Recall(hist)),
+                    100 * per_Accuracy(hist)
+                )
+            )
+    #------------------------------------------------#
+    #   Compute the per-class IoU over the whole
+    #   validation set
+    #------------------------------------------------#
+    IoUs = per_class_iu(hist)
+    PA_Recall = per_class_PA_Recall(hist)
+    Precision = per_class_Precision(hist)
+    #------------------------------------------------#
+    #   Print the per-class IoU values
+    #------------------------------------------------#
+    if name_classes is not None:
+        for ind_class in range(num_classes):
+            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
+                + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2)) + '; Precision-' + str(round(Precision[ind_class] * 100, 2)))
+
+    #-----------------------------------------------------------------#
+    #   Average the IoU over all classes on the whole validation set,
+    #   ignoring NaN values
+    #-----------------------------------------------------------------#
+    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
+    return np.array(hist, np.int64), IoUs, PA_Recall, Precision
+
+def adjust_axes(r, t, fig, axes):
+    bb = t.get_window_extent(renderer=r)
+    text_width_inches = bb.width / fig.dpi
+    current_fig_width = fig.get_figwidth()
+    new_fig_width = current_fig_width + text_width_inches
+    proportion = new_fig_width / current_fig_width
+    x_lim = axes.get_xlim()
+    axes.set_xlim([x_lim[0], x_lim[1] * proportion])
+
+def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True):
+    fig = plt.gcf()
+    axes = plt.gca()
+    plt.barh(range(len(values)), values, color='royalblue')
+    plt.title(plot_title, fontsize=tick_font_size + 2)
+    plt.xlabel(x_label, fontsize=tick_font_size)
+    plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
+    r = fig.canvas.get_renderer()
+    for i, val in enumerate(values):
+        str_val = " " + str(val)
+        if val < 1.0:
+            str_val = " {0:.2f}".format(val)
+        t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
+        if i == (len(values) - 1):
+            adjust_axes(r, t, fig, axes)
+
+    fig.tight_layout()
+    fig.savefig(output_path)
+    if plt_show:
+        plt.show()
+    plt.close()
+
+def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12):
+    draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs) * 100), "Intersection over Union", \
+        os.path.join(miou_out_path, "mIoU.png"), tick_font_size=tick_font_size, plt_show=True)
+    print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))
+
+    draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Pixel Accuracy", \
+        os.path.join(miou_out_path, "mPA.png"), tick_font_size=tick_font_size, plt_show=False)
+    print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))
+
+    draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Recall", \
+        os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False)
+    print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))
+
+    draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision", \
+        os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False)
+    print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))
+
+    with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
+        writer = csv.writer(f)
+        writer_list = []
+        writer_list.append([' '] + [str(c) for c in name_classes])
+        for i in range(len(hist)):
+            writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
+        writer.writerows(writer_list)
+    print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))
\ No newline at end of file
diff --git a/web.py b/web.py
new file mode 100644
index 0000000..31f5f9f
--- /dev/null
+++ b/web.py
@@ -0,0 +1,191 @@
+import os
+import streamlit as st
+import cv2
+import tempfile
+import torch
+import numpy as np
+from torchvision import transforms
+from PIL import Image
+from unet import Unet
+from nets.U_ConvAutoencoder import U_ConvAutoencoder
+from typing import Tuple, List
+
+# Constants and configuration
+DEFAULTS = {
+    "model_path": 'model_data/8414_8376.pth',
+    "num_classes": 2,
+    "backbone": "vgg",
+    "input_shape": [1696, 864],
+    "mix_type": 1,
+    "cuda": torch.cuda.is_available(),
+}
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+TRANSFORM = transforms.Compose([
+    transforms.Resize((1728, 3392)),
+    transforms.ToTensor()
+])
+
+
+class PreCA:
+    model: U_ConvAutoencoder = None
+
+    @classmethod
+    def initialize_model(cls, u_ca_path: str) -> None:
+        cls.model = U_ConvAutoencoder().to(DEVICE)
+        cls.model.load_state_dict(torch.load(u_ca_path, map_location=DEVICE))
+        cls.model.eval()
+
+    @classmethod
+    def unload_model(cls) -> None:
+        cls.model = None
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def load_image(cls, image: Image.Image) -> torch.Tensor:
+        image = image.convert("L")
+        image = TRANSFORM(image).unsqueeze(0)
+        return image.to(DEVICE)
+
+    @staticmethod
+    def ca_smooth(image: Image.Image) -> Image.Image:
+        image_cv2 = np.array(image)
+        closed_image = cv2.morphologyEx(image_cv2, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)))
+        blurred = cv2.GaussianBlur(closed_image, (1, 1), 0)
+        th = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)[1]
+        return Image.fromarray(th)
+
+    @classmethod
+    def infer(cls, image: Image.Image) -> Image.Image:
+        image_tensor = cls.load_image(image)
+        with torch.no_grad():
+            output = cls.model(image_tensor)
+            output = output.squeeze(0).cpu()
+        output_image = transforms.ToPILImage()(output)
+        return output_image.resize((3384, 1710), Image.NEAREST)
+
+
+class PreUnet:
+
+    @staticmethod
+    def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
+        pred_binary = pred_image.convert('L').point(lambda x: 0 if x < threshold else 255)
+        true_binary = true_image.convert('L').point(lambda x: 0 if x < threshold else 255)
+
+        pred_array = np.array(pred_binary)
+        true_array = np.array(true_binary)
+
+        TP = np.sum((pred_array == 255) & (true_array == 255))
+        FP = np.sum((pred_array == 255) & (true_array == 0))
+        FN = np.sum((pred_array == 0) & (true_array == 255))
+
+        return TP, FP, FN
+
+    @staticmethod
+    def apply_mask(original_image: Image.Image, mask_image: Image.Image) -> Image.Image:
+        original_image = original_image.convert("RGB").resize((3384, 1710), Image.NEAREST)
+        mask_image = mask_image.convert("RGB").resize((3384, 1710), Image.NEAREST)
+
+        original_array = np.array(original_image)
+        mask_array = np.array(mask_image)
+
+        mask = np.all(mask_array == [255, 255, 255], axis=-1)
+        original_array[mask] = [0, 255, 0]
+
+        return Image.fromarray(original_array)
+
+    @classmethod
+    def process_image(cls, image: Image.Image, unet):
+        detected_image = unet.detect_image(image)
+        inferred_image = PreCA.infer(detected_image)
+        smoothed_image = PreCA.ca_smooth(inferred_image)
+        return cls.apply_mask(image, smoothed_image), smoothed_image
+
+
+def main_page():
+    st.title('自动驾驶车道线自动检测与增强')
+    stframe = st.empty()
+    st.sidebar.subheader("参数设置")
+
+    is_pre = st.sidebar.checkbox('开启预测')
+    unet = Unet(DEFAULTS) if is_pre else None
+
+    if is_pre:
+        u_ca_path = 'weights/best_conv_autoencoder1.pth'
+        PreCA.initialize_model(u_ca_path)
+    else:
+        PreCA.unload_model()
+
+    st.sidebar.subheader("图像检测")
+    image_dir_path = st.sidebar.text_input('请输入图像文件夹路径:')
+    is_get_iou = st.sidebar.checkbox('开启计算IOU')
+    label_dir_path = st.sidebar.text_input('请输入标签文件夹路径:') if is_get_iou else None
+    btn_click = st.sidebar.button("开始预测")
+
+    if btn_click:
+        process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe)
+
+    st.sidebar.subheader("视频检测")
+    uploaded_video = st.sidebar.file_uploader("上传视频:", type=['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'm4v'])
+
+    if uploaded_video is not None:
+        process_video(uploaded_video, unet, is_pre, stframe)
+
+
+def process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe):
+    ious = []
+    img_names = os.listdir(image_dir_path)
+    iou_text = st.empty()
+
+    for img_name in img_names:
+        if img_name.lower().endswith(
+                ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+            image_path = os.path.join(image_dir_path, img_name)
+            image = Image.open(image_path)
+            if is_pre:
+                result_image, smoothed_image = PreUnet.process_image(image, unet)
+                stframe.image([image, result_image], width=640)
+                if is_get_iou and label_dir_path:
+                    label_path = os.path.join(label_dir_path, f"{os.path.splitext(img_name)[0]}_bin.png")
+                    label = Image.open(label_path)
+                    TP, FP, FN = PreUnet.calculate_metrics(smoothed_image, label)
+                    iou = TP / (TP + FP + FN)
+                    # ious.append(iou)
+                    iou_text.text(f'当前IOU: {iou}')
+            else:
+                stframe.image(image, width=1024)
+
+
+def process_video(uploaded_video, unet, is_pre, stframe):
+    tfile = tempfile.NamedTemporaryFile(delete=False)
+    tfile.write(uploaded_video.read())
+    tfile.close()
+
+    cap = cv2.VideoCapture(tfile.name)
+
+    if 'frame_pos' not in st.session_state:
+        st.session_state.frame_pos = 0
+
+    cap.set(cv2.CAP_PROP_POS_FRAMES, st.session_state.frame_pos)
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        st.session_state.frame_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frame = Image.fromarray(frame)
+
+        if is_pre:
+            processed_frame, smoothed_image = PreUnet.process_image(frame, unet)
+            stframe.image(processed_frame, width=1024, use_column_width=False)
+        else:
+            stframe.image(frame, width=1024, use_column_width=False)
+
+    cap.release()
+
+
+if __name__ == '__main__':
+    main_page()
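One caveat in process_images: TP + FP + FN can be zero when both the prediction and the label are empty, which would raise a ZeroDivisionError. A small guarded variant, illustrative only:

```python
def safe_iou(TP: int, FP: int, FN: int) -> float:
    # Convention: if prediction and label are both empty, count it as perfect.
    denom = TP + FP + FN
    return 1.0 if denom == 0 else TP / denom

print(safe_iou(50, 10, 20))  # 0.625
print(safe_iou(0, 0, 0))     # 1.0
```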
diff --git a/图片修改.py b/图片修改.py
new file mode 100644
index 0000000..04bc911
--- /dev/null
+++ b/图片修改.py
@@ -0,0 +1,78 @@
+import os
+import cv2
+import numpy as np
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+
+
+def remove_foreground_and_fill(image_path, mask_path, crop_size):
+    # Read the image and the mask
+    image = cv2.imread(image_path)
+    mask = cv2.imread(mask_path, 0)  # the mask is assumed to be grayscale
+
+    # Make sure the mask is binary
+    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
+
+    # Randomly pick the top-left corner of the crop region
+    h, w = mask.shape
+    start_x = np.random.randint(0, w - crop_size[1] + 1)
+    start_y = np.random.randint(0, h - crop_size[0] + 1)
+
+    # Crop the mask and the image
+    cropped_mask = mask[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]]
+    cropped_image = image[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]]
+
+    # Remove the foreground and fill the hole with inpainting
+    result = cv2.inpaint(cropped_image, cropped_mask, 3, cv2.INPAINT_TELEA)
+
+    # Paste the filled crop back into its position in the original image
+    result_image = image.copy()
+    result_image[start_y:start_y + crop_size[0], start_x:start_x + crop_size[1]] = result
+
+    return result_image
+
+
+def process_image(img_name, img_folder, label_folder, save_dir):
+    if img_name.lower().endswith('.jpg'):
+        label_name = img_name[:-4] + '_bin.png'  # build the label file name
+        if os.path.exists(os.path.join(label_folder, label_name)):  # make sure the label file exists
+            # Remove the foreground and fill it in for this image/mask pair
+            result_image = remove_foreground_and_fill(
+                os.path.join(img_folder, img_name),
+                os.path.join(label_folder, label_name),
+                (500, 1710)  # crop size of 500 x 1710 pixels
+            )
+
+            # Save the processed image into the save_files folder
+            save_path = os.path.join(save_dir, img_name)
+            cv2.imwrite(save_path, result_image)
+
+
+def write_img_label_txt(base_dir, dataset_type):
+    # Create the output directory (despite its name, this function
+    # processes and saves images rather than writing txt files)
+    save_dir = os.path.join(base_dir, 'save_files')
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Paths of the img and label folders
+    img_folder = os.path.join(base_dir, dataset_type, 'img')
+    label_folder = os.path.join(base_dir, dataset_type, 'label')
+
+    img_names = [img_name for img_name in os.listdir(img_folder) if img_name.lower().endswith('.jpg')]
+
+    with ThreadPoolExecutor() as executor:
+        futures = []
+        for img_name in img_names:
+            futures.append(executor.submit(process_image, img_name, img_folder, label_folder, save_dir))
+
+        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
+            future.result()
+
+
+# Base directory
+base_dir = r'E:\git\unet_seg\unet\original_data\dataset_A'
+
+# Process the train folder (add 'test' to the list to process it as well)
+for dataset_type in ['train']:
+    write_img_label_txt(base_dir, dataset_type)
+
+print('All images have been processed and saved.')
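For reference, a minimal self-contained demonstration of the cv2.inpaint call that does the heavy lifting above, using a synthetic image and mask with invented values:

```python
import cv2
import numpy as np

# Synthetic 100x100 image with a white "foreground" stripe to remove.
image = np.full((100, 100, 3), 80, dtype=np.uint8)
image[40:60, :] = 255

# Mask: non-zero pixels mark the region to be filled in.
mask = np.zeros((100, 100), dtype=np.uint8)
mask[40:60, :] = 255

# Radius 3 with the TELEA algorithm -- the same parameters used above.
result = cv2.inpaint(image, mask, 3, cv2.INPAINT_TELEA)
cv2.imwrite("inpaint_demo.png", result)
```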