提交文件
This commit is contained in:
parent
faffbfd886
commit
db2ff6a3ff
|
@ -0,0 +1,199 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm # 导入tqdm库
|
||||||
|
|
||||||
|
# Select the compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Convolutional autoencoder with U-Net style skip connections.
class ConvAutoencoder(nn.Module):
    """Five-stage encoder/decoder for single-channel images.

    Each encoder stage is a stride-2 3x3 convolution followed by batch
    normalisation and ReLU; each decoder stage is a stride-2 transposed
    convolution that doubles the spatial size.  Decoder stages 4..1 receive
    the previous decoder output concatenated, along the channel axis, with
    the encoder output of matching resolution.  When the input height and
    width are multiples of 32 every concatenation lines up and the output
    has the same spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out):
            # Stride-2 convolution: halves height and width.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        def up(c_in, c_out):
            # Stride-2 transposed convolution: doubles height and width.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        # Encoder: 1 -> 16 -> 32 -> 64 -> 128 -> 256 channels.
        self.encoder1 = down(1, 16)
        self.encoder2 = down(16, 32)
        self.encoder3 = down(32, 64)
        self.encoder4 = down(64, 128)
        self.encoder5 = down(128, 256)

        # Decoder: stages 4..1 take twice the channels of the previous
        # decoder output because the matching encoder output is concatenated.
        self.decoder5 = up(256, 128)
        self.decoder4 = up(256, 64)
        self.decoder3 = up(128, 32)
        self.decoder2 = up(64, 16)
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # keeps the output inside [0, 1]
        )

    def forward(self, x):
        """Return the reconstruction of ``x`` with values in [0, 1]."""
        # Encoder: remember the first four stage outputs for the skips.
        skips = []
        feat = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            feat = stage(feat)
            skips.append(feat)

        # Decoder: the deepest stage has no skip input.
        out = self.decoder5(self.encoder5(feat))
        for stage in (self.decoder4, self.decoder3, self.decoder2, self.decoder1):
            out = stage(torch.cat([out, skips.pop()], dim=1))

        return out
|
||||||
|
|
||||||
|
|
||||||
|
# Dataset pairing every input image with its same-named label image.
class CustomImageDataset(Dataset):
    """Yield ``(image, label_image)`` pairs loaded as 8-bit grayscale.

    Args:
        image_dir: Directory whose entries define the samples.
        label_dir: Directory holding a label file for every entry of
            ``image_dir`` under the same file name.
        transform: Optional callable applied to the image and then to the
            label image.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir, self.label_dir = image_dir, label_dir
        self.transform = transform
        self.image_names = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _load_gray(folder, file_name):
        # Open ``folder/file_name`` and convert it to mode "L" (grayscale).
        return Image.open(os.path.join(folder, file_name)).convert("L")

    def __getitem__(self, idx):
        file_name = self.image_names[idx]
        image = self._load_gray(self.image_dir, file_name)
        # The label is assumed to share the image's file name.
        label_image = self._load_gray(self.label_dir, file_name)

        if not self.transform:
            return image, label_image
        return self.transform(image), self.transform(label_image)
|
||||||
|
def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||||
|
|
||||||
|
# Intersection-over-union metric.
def compute_iou(pred, target, threshold=0.5):
    """Return the IoU of two tensors after binarising both at ``threshold``.

    Args:
        pred: Predicted values; entries strictly above ``threshold`` count
            as foreground.
        target: Reference values, binarised the same way.
        threshold: Cut-off used for both tensors.

    Returns:
        A Python float.  When neither tensor has any foreground the union is
        empty and the result is 1.0.
    """
    pred_mask = (pred > threshold).float()
    target_mask = (target > threshold).float()

    overlap = (pred_mask * target_mask).sum()
    combined = pred_mask.sum() + target_mask.sum() - overlap

    if combined != 0:
        return (overlap / combined).item()
    # Empty union: treat two empty masks as a perfect match.
    return 0.0 if overlap != 0 else 1.0
|
||||||
|
|
||||||
|
# Image preprocessing; CustomImageDataset applies the same transform to the
# input image and to its label image.
transform = transforms.Compose([
    # Resize to a fixed size; both values are multiples of 32.
    transforms.Resize((1728, 3392)),
    # Convert the PIL image to a tensor.
    transforms.ToTensor()
])
|
||||||
|
|
||||||
|
|
||||||
|
def train():
    """Train the module-level ``model`` and checkpoint its weights.

    The function takes no arguments.  It reads the module globals
    ``model``, ``data_loader``, ``criterion``, ``optimizer``, ``device``,
    ``num_epochs`` and ``best_model_path`` and updates the global
    ``best_iou``; all of them must be defined before the call.

    Per epoch it prints the mean batch loss and mean batch IoU, saves the
    weights to ``best_model_path`` whenever the epoch IoU beats ``best_iou``,
    and additionally saves a numbered checkpoint under ``out_weights/``
    every fifth epoch.
    """
    # ``best_iou`` is rebound below.  Without this declaration Python treats
    # it as a local name and the comparison raises UnboundLocalError.
    global best_iou

    # torch.save does not create missing parent directories.
    os.makedirs('out_weights', exist_ok=True)
    best_dir = os.path.dirname(best_model_path)
    if best_dir:
        os.makedirs(best_dir, exist_ok=True)

    for epoch in range(num_epochs):
        running_loss = 0.0
        total_iou = 0.0
        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        for imgs, label_imgs in progress_bar:
            # Move the batch to the selected device.
            imgs, label_imgs = imgs.to(device), label_imgs.to(device)

            # Forward pass.
            output = model(imgs)
            loss = criterion(output, label_imgs)

            # Batch IoU, accumulated for the epoch average.
            iou = compute_iou(output, label_imgs)
            total_iou += iou

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Show the current batch metrics on the progress bar.
            progress_bar.set_postfix(loss=loss.item(), iou=iou)

        # Report the epoch averages (per batch, not per sample).
        epoch_loss = running_loss / len(data_loader)
        epoch_iou = total_iou / len(data_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}, IoU: {epoch_iou}')

        # Keep the weights with the best epoch IoU seen so far.
        if epoch_iou > best_iou:
            best_iou = epoch_iou
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved with IoU: {best_iou}')

        # Save a numbered checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'out_weights/conv_autoencoder_epoch_{epoch+1}.pth')
            print(f'Model weights saved at epoch {epoch+1}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Script entry point: everything assigned here becomes a module-level
    # name, which is how train() gets access to it.
    image_dir = './img' # replace with the path to your image folder
    label_dir = './label' # replace with the path to your label folder

    dataset = CustomImageDataset(image_dir, label_dir, transform=transform)
    data_loader = DataLoader(dataset, batch_size=15, shuffle=True)

    # Instantiate the model, the loss function and the optimizer.
    model = ConvAutoencoder().to(device) # move the model to the selected device
    print(f'The model has {count_parameters(model):,} trainable parameters')
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Best epoch IoU seen so far and the file that stores its weights.
    best_iou = 0.0
    best_model_path = 'out_weights/best_conv_autoencoder.pth'
    # Train the convolutional autoencoder.
    num_epochs = 100
    train()
|
|
@ -0,0 +1,107 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def trans_255_1(image):
    """Rewrite pure-white pixels to 1, in place, and return ``image``.

    Args:
        image: A 2-D array (single channel) or a 3-D array whose last axis
            has length 3.  For 3-channel input only pixels whose three
            channels are all 255 are rewritten.

    Returns:
        The same array object that was passed in, after modification.

    Raises:
        ValueError: If the array is neither 2-D nor 3-D with three channels.
    """
    dims = image.ndim
    if dims == 2:
        # Single-channel image.
        image[image == 255] = 1
        return image
    if dims == 3 and image.shape[2] == 3:
        # Three-channel image: a pixel matches only if every channel is 255.
        all_white = (image == np.array([255, 255, 255])).all(axis=-1)
        image[all_white] = [1, 1, 1]
        return image
    raise ValueError("Unsupported image format!")
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image_and_mask(image, mask, target_size):
    """Resize an image and its mask to ``target_size``.

    The image is resampled with ``cv2.INTER_AREA``; the mask uses
    ``cv2.INTER_NEAREST`` so that no new label values are interpolated.

    Returns:
        A ``(resized_image, resized_mask)`` tuple.
    """
    return (
        cv2.resize(image, target_size, interpolation=cv2.INTER_AREA),
        cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def process_image(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size):
    """Resize one image/mask pair and store both under a fresh UUID name.

    The mask is looked up in ``mask_dir`` as ``<image stem>_bin.png`` and is
    read as grayscale.  After resizing, mask pixels equal to 255 are
    rewritten to 1.  The mask is written as ``<uuid>.png`` and the image as
    ``<uuid>.jpg``.

    Args:
        image_file: File name of the image inside ``image_dir``.
        image_dir: Directory containing the source image.
        mask_dir: Directory containing the source mask.
        output_mask_dir: Destination directory for the processed mask.
        output_image_dir: Destination directory for the processed image.
        target_size: Size forwarded to ``resize_image_and_mask``.

    Returns:
        The UUID string used as the stem of both output files.

    Raises:
        ValueError: If the image or the mask cannot be read.
        IOError: If either output file cannot be written.
    """
    image_path = os.path.join(image_dir, image_file)
    mask_path = os.path.join(mask_dir, image_file.rsplit('.', 1)[0] + '_bin.png')
    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None instead of raising when a file is unreadable.
    if image is None or mask is None:
        raise ValueError(f"Image or mask not found for {image_file}")

    resized_image, resized_mask = resize_image_and_mask(image, mask, target_size)
    resized_mask = trans_255_1(resized_mask)

    unique_id = str(uuid.uuid4())
    mask_output_path = os.path.join(output_mask_dir, unique_id + '.png')
    image_output_path = os.path.join(output_image_dir, unique_id + '.jpg')

    # cv2.imwrite reports failure through its boolean result.  Checking it
    # keeps ids of samples that were never written out of the split lists.
    for out_path, pixels in ((mask_output_path, resized_mask), (image_output_path, resized_image)):
        if not cv2.imwrite(out_path, pixels):
            raise IOError(f"Failed to write {out_path}")

    return unique_id
|
||||||
|
|
||||||
|
|
||||||
|
def mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
                 num_images=2000):
    """Convert a random sample of image/mask pairs and write train/val id lists.

    ``num_images`` files with a ``.png``/``.jpg``/``.jpeg`` extension are
    sampled from ``image_dir`` and converted by ``process_image`` on a
    thread pool.  Failures are printed and skipped.  The ids of the
    successful conversions are shuffled, split 80/20, and appended (one id
    per line) to ``train_txt`` and ``val_txt``.

    Raises:
        ValueError: If ``image_dir`` holds fewer than ``num_images`` images.
    """
    candidates = []
    for entry in os.listdir(image_dir):
        if entry.lower().endswith(('.png', '.jpg', '.jpeg')):
            candidates.append(entry)

    if num_images > len(candidates):
        raise ValueError(f"Not enough images in directory to sample {num_images} images.")

    # Randomly pick num_images files.
    chosen = random.sample(candidates, num_images)

    results = []
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = [
            executor.submit(process_image, name, image_dir, mask_dir,
                            output_mask_dir, output_image_dir, target_size)
            for name in chosen
        ]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
            try:
                results.append(future.result())
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f"Error processing image: {e}")

    random.shuffle(results)
    split_point = int(len(results) * 0.8)

    # Append mode: existing list files are extended, not overwritten.
    for txt_path, ids in ((train_txt, results[:split_point]), (val_txt, results[split_point:])):
        with open(txt_path, 'a') as handle:
            handle.writelines(f"{uid}\n" for uid in ids)

    print(f"训练集文件名已写入 {train_txt}")
    print(f"验证集文件名已写入 {val_txt}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: convert every image of the source dataset into the
    # VOC-style layout (JPEGImages / SegmentationClass / ImageSets).
    target_size = (1696, 864)
    image_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\img"
    mask_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\label"
    output_mask_dir = "SegmentationClass"
    output_image_dir = "JPEGImages"
    output_txt_dir = './ImageSets/Segmentation'

    train_txt = os.path.join(output_txt_dir, 'train.txt')
    val_txt = os.path.join(output_txt_dir, 'val.txt')

    os.makedirs(output_mask_dir, exist_ok=True)
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_txt_dir, exist_ok=True)

    # Count only the entries mask_to_unet will consider.  A raw
    # len(os.listdir(...)) also counts non-image files and would make
    # mask_to_unet raise "Not enough images" when any are present.
    image_count = sum(
        1 for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
                 num_images=image_count)
|
|
@ -0,0 +1,75 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
class U_ConvAutoencoder(nn.Module):
    """Convolutional autoencoder with U-Net style skip connections.

    Five stride-2 convolution stages encode a single-channel image; five
    stride-2 transposed-convolution stages decode it.  Decoder stages 4..1
    take the previous decoder output concatenated along the channel axis
    with the encoder output of the same resolution.  The final stage ends
    in a sigmoid, so outputs lie in [0, 1].  When height and width are
    multiples of 32 the output has the input's spatial size.
    """

    @staticmethod
    def _encoder_stage(c_in, c_out):
        # Conv (halves H and W) -> BatchNorm -> ReLU.
        conv = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
        return nn.Sequential(conv, nn.BatchNorm2d(c_out), nn.ReLU(True))

    @staticmethod
    def _decoder_stage(c_in, c_out):
        # Transposed conv (doubles H and W) -> BatchNorm -> ReLU.
        deconv = nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)
        return nn.Sequential(deconv, nn.BatchNorm2d(c_out), nn.ReLU(True))

    def __init__(self):
        super().__init__()
        # Encoder.
        self.encoder1 = self._encoder_stage(1, 16)
        self.encoder2 = self._encoder_stage(16, 32)
        self.encoder3 = self._encoder_stage(32, 64)
        self.encoder4 = self._encoder_stage(64, 128)
        self.encoder5 = self._encoder_stage(128, 256)

        # Decoder; input widths of stages 4..1 include the concatenated skip.
        self.decoder5 = self._decoder_stage(256, 128)
        self.decoder4 = self._decoder_stage(256, 64)
        self.decoder3 = self._decoder_stage(128, 32)
        self.decoder2 = self._decoder_stage(64, 16)
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # keeps the output inside [0, 1]
        )

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        # Encoder.
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        bottleneck = self.encoder5(e4)

        # Decoder with skip connections.
        out = self.decoder5(bottleneck)
        for stage, skip in (
            (self.decoder4, e4),
            (self.decoder3, e3),
            (self.decoder2, e2),
            (self.decoder1, e1),
        ):
            out = stage(torch.cat([out, skip], dim=1))
        return out
|
|
@ -0,0 +1 @@
|
||||||
|
#
|
|
@ -0,0 +1,185 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.model_zoo as model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch; the input itself is used when it is ``None``.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        norm_layer: Normalisation layer factory; ``nn.BatchNorm2d`` when
            ``None``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv-bn-relu followed by conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: the input, optionally projected by ``downsample``.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    The block outputs ``planes * expansion`` channels.  ``downsample``, when
    given, is applied to the input to produce the shortcut branch.
    ``norm_layer`` defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # 1x1 convolution that reduces the channel count.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # 3x3 convolution that extracts features (carries the stride).
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        # 1x1 convolution that expands the channel count.
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: the input, optionally projected by ``downsample``.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
    """ResNet trunk that returns five intermediate feature maps.

    Args:
        block: Residual block class exposing an ``expansion`` attribute and
            accepting ``(inplanes, planes, stride, downsample)``.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the ``fc`` layer.  ``avgpool`` and ``fc``
            are created here but are not used by ``forward``.

    The shape notes below assume a 600x600x3 input and the resnet50
    configuration.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.inplanes = 64

        # 600,600,3 -> 300,300,64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # 300,300,64 -> 150,150,64
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
        # 150,150,64 -> 150,150,256
        self.layer1 = self._make_layer(block, 64, layers[0])
        # 150,150,256 -> 75,75,512
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # 75,75,512 -> 38,38,1024
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # 38,38,1024 -> 19,19,2048
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Convolutions: zero-mean normal with std sqrt(2 / (k_h * k_w * out_channels)).
        # Batch norms: weight 1, bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        out_channels = planes * block.expansion

        # A 1x1 projection is needed on the shortcut whenever the first
        # block changes the resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[feat1, feat2, feat3, feat4, feat5]`` from stem to last stage."""
        feat1 = self.relu(self.bn1(self.conv1(x)))
        feat2 = self.layer1(self.maxpool(feat1))
        feat3 = self.layer2(feat2)
        feat4 = self.layer3(feat3)
        feat5 = self.layer4(feat4)
        return [feat1, feat2, feat3, feat4, feat5]
|
||||||
|
|
||||||
|
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50 trunk without its classification head.

    Args:
        pretrained: When true, download the checkpoint (cached under
            ``model_data``) and load it non-strictly.
        **kwargs: Forwarded to ``ResNet``.

    Returns:
        The model with its ``avgpool`` and ``fc`` attributes removed.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        weights = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', model_dir='model_data')
        model.load_state_dict(weights, strict=False)

    # Drop the classification head.
    for head in ('avgpool', 'fc'):
        delattr(model, head)
    return model
|
|
@ -0,0 +1,94 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from nets.resnet import resnet50
|
||||||
|
from nets.vgg import VGG16
|
||||||
|
|
||||||
|
|
||||||
|
class unetUp(nn.Module):
    """U-Net decoder step: upsample, concatenate with the skip, two 3x3 convs.

    Args:
        in_size: Channel count after concatenating the skip feature with the
            upsampled deeper feature.
        out_size: Channel count produced by both convolutions.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs1, inputs2):
        """Fuse skip feature ``inputs1`` with the 2x-upsampled ``inputs2``."""
        merged = torch.cat((inputs1, self.up(inputs2)), 1)
        hidden = self.relu(self.conv1(merged))
        return self.relu(self.conv2(hidden))
|
||||||
|
|
||||||
|
class Unet(nn.Module):
    """U-Net segmentation head on top of a VGG16 or ResNet-50 encoder.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        pretrained: Forwarded to the backbone constructor.
        backbone: ``'vgg'`` or ``'resnet50'``.

    Raises:
        ValueError: If ``backbone`` is neither ``'vgg'`` nor ``'resnet50'``.
    """

    def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'):
        super().__init__()
        # in_filters are the input channel counts of the four unetUp stages.
        if backbone == 'vgg':
            self.vgg = VGG16(pretrained = pretrained)
            in_filters = [192, 384, 768, 1024]
        elif backbone == "resnet50":
            self.resnet = resnet50(pretrained = pretrained)
            in_filters = [192, 512, 1024, 3072]
        else:
            raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
        out_filters = [64, 128, 256, 512]

        # upsampling
        # 64,64,512
        self.up_concat4 = unetUp(in_filters[3], out_filters[3])
        # 128,128,256
        self.up_concat3 = unetUp(in_filters[2], out_filters[2])
        # 256,256,128
        self.up_concat2 = unetUp(in_filters[1], out_filters[1])
        # 512,512,64
        self.up_concat1 = unetUp(in_filters[0], out_filters[0])

        # Only the resnet50 variant adds one more 2x upsampling stage before
        # the final classifier.
        if backbone == 'resnet50':
            self.up_conv = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor = 2),
                nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
                nn.ReLU(),
                nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1),
                nn.ReLU(),
            )
        else:
            self.up_conv = None

        self.final = nn.Conv2d(out_filters[0], num_classes, 1)

        self.backbone = backbone

    def forward(self, inputs):
        """Return per-pixel class scores for ``inputs``."""
        # Call the sub-modules rather than their ``forward`` methods so that
        # registered hooks run.
        if self.backbone == "vgg":
            feat1, feat2, feat3, feat4, feat5 = self.vgg(inputs)
        elif self.backbone == "resnet50":
            feat1, feat2, feat3, feat4, feat5 = self.resnet(inputs)

        up4 = self.up_concat4(feat4, feat5)
        up3 = self.up_concat3(feat3, up4)
        up2 = self.up_concat2(feat2, up3)
        up1 = self.up_concat1(feat1, up2)

        if self.up_conv is not None:
            up1 = self.up_conv(up1)

        return self.final(up1)

    def _set_backbone_trainable(self, trainable):
        # Set requires_grad on every parameter of the selected backbone.
        if self.backbone == "vgg":
            encoder = self.vgg
        elif self.backbone == "resnet50":
            encoder = self.resnet
        else:
            return
        for param in encoder.parameters():
            param.requires_grad = trainable

    def freeze_backbone(self):
        """Stop gradient updates for all backbone parameters."""
        self._set_backbone_trainable(False)

    def unfreeze_backbone(self):
        """Re-enable gradient updates for all backbone parameters."""
        self._set_backbone_trainable(True)
|
|
@ -0,0 +1,113 @@
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def CE_Loss(inputs, target, cls_weights, num_classes=21):
    """Weighted cross-entropy between class scores and an index map.

    Args:
        inputs: Scores with four dimensions, unpacked as ``(n, c, h, w)``.
        target: Class indices with three dimensions, unpacked as
            ``(n, h, w)``.  Entries equal to ``num_classes`` are ignored.
        cls_weights: Per-class weights passed to ``nn.CrossEntropyLoss``.
        num_classes: Index value treated as "ignore".

    Returns:
        The scalar loss tensor.
    """
    _, c, h, w = inputs.size()
    _, ht, wt = target.size()
    # Resize when EITHER dimension differs.  With `and`, a mismatch in only
    # one dimension skipped the resize and the flattened shapes disagreed.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n*h*w, c) so that every pixel is one sample.
    temp_inputs = inputs.permute(0, 2, 3, 1).contiguous().view(-1, c)
    temp_target = target.view(-1)

    CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
    return CE_loss
|
||||||
|
|
||||||
|
def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
    """Focal loss built on the weighted per-pixel cross-entropy.

    Args:
        inputs: Scores with four dimensions, unpacked as ``(n, c, h, w)``.
        target: Class indices with three dimensions, unpacked as
            ``(n, h, w)``.  Entries equal to ``num_classes`` are ignored.
        cls_weights: Per-class weights passed to ``nn.CrossEntropyLoss``.
        num_classes: Index value treated as "ignore".
        alpha: Scale applied to the log-probability term; skipped if ``None``.
        gamma: Focusing exponent applied to ``1 - pt``.

    Returns:
        The mean focal loss over all pixels, as a scalar tensor.
    """
    _, c, h, w = inputs.size()
    _, ht, wt = target.size()
    # Resize when EITHER dimension differs.  With `and`, a mismatch in only
    # one dimension skipped the resize and the flattened shapes disagreed.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n*h*w, c) so that every pixel is one sample.
    temp_inputs = inputs.permute(0, 2, 3, 1).contiguous().view(-1, c)
    temp_target = target.view(-1)

    # Negated per-pixel cross-entropy; pt is its exponential.
    logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt = logpt * alpha
    loss = -((1 - pt) ** gamma) * logpt
    return loss.mean()
|
||||||
|
|
||||||
|
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    """Dice (F-beta) loss between softmax scores and a one-hot target.

    Args:
        inputs: Scores with four dimensions, unpacked as ``(n, c, h, w)``.
        target: Tensor with four dimensions, unpacked as ``(n, h, w, ct)``.
            The last slice along the final axis is dropped before comparing
            with the ``c`` score channels.
        beta: Weight of false negatives relative to false positives.
        smooth: Constant added to numerator and denominator.

    Returns:
        ``1 - mean(per-class score)`` as a scalar tensor.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()
    # Resize when EITHER dimension differs.  With `and`, a mismatch in only
    # one dimension skipped the resize and the flattened shapes disagreed.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n, h*w, c), then softmax over the class axis.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   Compute the dice loss
    #--------------------------------------------#
    kept_target = temp_target[..., :-1]
    tp = torch.sum(kept_target * temp_inputs, dim=[0, 1])
    fp = torch.sum(temp_inputs, dim=[0, 1]) - tp
    fn = torch.sum(kept_target, dim=[0, 1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
|
||||||
|
|
||||||
|
def weights_init(net, init_type='normal', init_gain=0.02):
    """Re-initialise the convolution and BatchNorm2d weights of ``net`` in place.

    Modules whose class name contains ``Conv`` and that have a ``weight``
    attribute are initialised with the scheme named by ``init_type``
    (``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``).
    Modules whose class name contains ``BatchNorm2d`` get weights drawn from
    N(1.0, 0.02) and zero bias.

    Raises:
        NotImplementedError: If a convolution is met and ``init_type`` is
            not one of the supported names.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            weight = m.weight.data
            if init_type == 'normal':
                torch.nn.init.normal_(weight, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(weight, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(weight, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif 'BatchNorm2d' in classname:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)
|
||||||
|
|
||||||
|
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Return a callable that maps an iteration index to a learning rate.

    Args:
        lr_decay_type: ``"cos"`` selects quadratic warm-up followed by cosine
            decay; any other value selects step decay.
        lr: Peak learning rate.
        min_lr: Final learning rate.
        total_iters: Total number of iterations of the schedule.
        warmup_iters_ratio: Fraction of ``total_iters`` used for warm-up
            (clamped to the range 1..3 iterations).
        warmup_lr_ratio: Warm-up start rate as a fraction of ``lr``
            (at least 1e-6).
        no_aug_iter_ratio: Fraction of ``total_iters`` held at ``min_lr`` at
            the end (clamped to the range 1..15 iterations).
        step_num: Number of steps of the step schedule.

    Returns:
        A ``functools.partial`` taking the iteration index as its only
        remaining argument.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        # Quadratic warm-up from warmup_lr_start to lr.
        if iters <= warmup_total_iters:
            progress = iters / float(warmup_total_iters)
            return (lr - warmup_lr_start) * pow(progress, 2) + warmup_lr_start
        # Tail of the schedule: hold the minimum rate.
        if iters >= total_iters - no_aug_iter:
            return min_lr
        # Cosine decay from lr to min_lr in between.
        phase = math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)
        return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(phase))

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        return lr * decay_rate ** (iters // step_size)

    if lr_decay_type != "cos":
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        return partial(step_lr, lr, decay_rate, step_size)

    warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
    warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
    no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
    return partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
|
||||||
|
|
||||||
|
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set the rate ``lr_scheduler_func(epoch)`` on every parameter group of ``optimizer``."""
    new_lr = lr_scheduler_func(epoch)
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
|
|
@ -0,0 +1,75 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.hub import load_state_dict_from_url
|
||||||
|
|
||||||
|
|
||||||
|
class VGG(nn.Module):
    """VGG network used as a multi-scale feature extractor.

    Args:
        features: Sequential convolutional trunk; ``forward`` slices it at
            fixed indices.
        num_classes: Output size of the classifier head.  The head and
            ``avgpool`` are created here but are not used by ``forward``.
    """

    def __init__(self, features, num_classes=1000):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        head = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
        self._initialize_weights()

    def forward(self, x):
        """Return the outputs of five consecutive slices of ``features``.

        The slices are ``[:4]``, ``[4:9]``, ``[9:16]``, ``[16:23]`` and
        ``[23:-1]``; the last layer of ``features`` is never applied.
        """
        bounds = (0, 4, 9, 16, 23, -1)
        feats = []
        for start, stop in zip(bounds, bounds[1:]):
            x = self.features[start:stop](x)
            feats.append(x)
        return feats

    def _initialize_weights(self):
        """Kaiming-normal convs, unit batch norms, N(0, 0.01) linears; zero biases."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_layers(cfg, batch_norm=False, in_channels=3):
    """Translate a layer configuration list into an ``nn.Sequential``.

    Args:
        cfg: sequence whose entries are either ``'M'`` (2x2 max-pool, stride 2)
            or an int giving the output channels of a 3x3 conv with padding 1.
        batch_norm: when true, a ``BatchNorm2d`` follows every conv.
        in_channels: channel count expected by the first conv.

    Returns:
        The assembled ``nn.Sequential``; every conv is followed by an in-place
        ReLU.
    """
    modules = []
    channels = in_channels
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        # The next conv consumes what this one produced.
        channels = entry
    return nn.Sequential(*modules)
|
||||||
|
# Feature-map sizes (H,W,C) traced through configuration 'D' for a 512x512x3 input:
# 512,512,3 -> 512,512,64 -> 256,256,64 -> 256,256,128 -> 128,128,128 -> 128,128,256 -> 64,64,256
# 64,64,512 -> 32,32,512 -> 32,32,512
# Layer tables keyed by configuration name: an int is a conv layer's output
# channel count, 'M' marks a max-pooling layer.
cfgs = {
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
}
|
||||||
|
|
||||||
|
|
||||||
|
def VGG16(pretrained, in_channels=3, **kwargs):
    """Build the configuration-'D' VGG (no batch norm) as a feature extractor.

    Args:
        pretrained: when true, download the torchvision VGG-16 checkpoint into
            ``./model_data`` and load it.
        in_channels: channel count of the input images.
        **kwargs: forwarded to ``VGG`` (for example ``num_classes``).

    Returns:
        The ``VGG`` model with its ``avgpool`` and ``classifier`` attributes
        removed.
    """
    model = VGG(make_layers(cfgs["D"], batch_norm=False, in_channels=in_channels), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir="./model_data")
        # Keep only tensors whose name and shape match this model. With the
        # default 3-channel / 1000-class layout every tensor matches, so the
        # result equals a strict load; with a different ``in_channels`` or
        # ``num_classes`` the mismatching layers keep their fresh
        # initialisation instead of aborting the whole load.
        own_state = model.state_dict()
        compatible = {
            key: tensor
            for key, tensor in state_dict.items()
            if key in own_state and own_state[key].shape == tensor.shape
        }
        model.load_state_dict(compatible, strict=False)

    # forward() only uses ``features``; drop the unused heads.
    del model.avgpool
    del model.classifier
    return model
|
|
@ -0,0 +1,169 @@
|
||||||
|
import itertools
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import cv2
|
||||||
|
from unet import Unet
|
||||||
|
from nets.U_ConvAutoencoder import U_ConvAutoencoder
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
|
||||||
|
# 定义卷积自编码器
|
||||||
|
|
||||||
|
class PreCA:
    """Class-level holder for the convolutional autoencoder post-processor.

    State lives on the class itself; call ``initialize_model`` once before
    ``load_image`` or ``infer``.
    """

    # Populated by initialize_model().
    device: torch.device = None
    model: U_ConvAutoencoder = None
    transform: transforms.Compose = None

    @classmethod
    def initialize_model(cls, u_ca_path: str) -> None:
        """Instantiate the autoencoder, load weights from ``u_ca_path`` and
        prepare the input transform."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.model = U_ConvAutoencoder().to(cls.device)
        weights = torch.load(u_ca_path, map_location=cls.device)
        cls.model.load_state_dict(weights)
        cls.model.eval()
        # Preprocessing: fixed-size resize followed by tensor conversion.
        resize_step = transforms.Resize((1728, 3392))
        cls.transform = transforms.Compose([resize_step, transforms.ToTensor()])

    @classmethod
    def load_image(cls, image: Image.Image) -> torch.Tensor:
        """Convert ``image`` to mode "L", apply the transform, add a leading
        batch dimension and move the tensor to ``cls.device``."""
        grayscale = image.convert("L")
        batch = cls.transform(grayscale).unsqueeze(0)
        return batch.to(cls.device)

    @staticmethod
    def ca_smooth(image: Image.Image) -> Image.Image:
        """Apply a morphological close, a blur call and a binary threshold.

        Returns a new PIL image built from the thresholded array.
        """
        pixels = np.array(image)
        # Morphological closing with a 3x3 rectangular structuring element.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        closed = cv2.morphologyEx(pixels, cv2.MORPH_CLOSE, kernel)
        # NOTE(review): a 1x1 Gaussian kernel presumably leaves the image
        # unchanged - confirm whether a larger kernel was intended.
        smoothed = cv2.GaussianBlur(closed, (1, 1), 0)
        _, binary = cv2.threshold(smoothed, 126, 255, cv2.THRESH_BINARY)
        return Image.fromarray(binary)

    @classmethod
    def infer(cls, image: Image.Image) -> Image.Image:
        """Run the autoencoder on ``image`` and return the result as a PIL
        image resized to (3384, 1710) with nearest-neighbour resampling."""
        batch = cls.load_image(image)
        with torch.no_grad():
            prediction = cls.model(batch)
        # Drop the batch dimension and bring the tensor back to the CPU.
        prediction = prediction.squeeze(0).cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(prediction).resize((3384, 1710), Image.NEAREST)
|
||||||
|
|
||||||
|
|
||||||
|
class PreUnet:
    """Evaluation driver: U-Net prediction refined by ``PreCA``.

    ``main`` reads its settings from module-level names that the script entry
    point defines: ``dir_origin_path``, ``dir_label_path``, ``dir_save_path``,
    ``is_save``, ``is_get_iou`` and ``unet``.
    """

    @staticmethod
    def blend_images_with_colorize(image1: Image.Image, image2: Image.Image, alpha: float = 0.5) -> None:
        """Show ``image1`` tinted red blended with ``image2`` tinted green.

        Both images are converted to mode "L" before colorizing; ``alpha`` is
        passed to ``Image.blend``.
        """
        red_image1 = ImageOps.colorize(image1.convert("L"), (0, 0, 0), (255, 0, 0))
        green_image2 = ImageOps.colorize(image2.convert("L"), (0, 0, 0), (0, 255, 0))
        Image.blend(red_image1, green_image2, alpha).show()

    @staticmethod
    def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
        """Count true-positive, false-positive and false-negative pixels.

        A pixel is foreground when its grayscale value is ``>= threshold``.

        Returns:
            ``(TP, FP, FN)`` as plain Python ints.
        """
        pred_mask = np.array(pred_image.convert('L')) >= threshold
        true_mask = np.array(true_image.convert('L')) >= threshold

        TP = int(np.count_nonzero(pred_mask & true_mask))
        FP = int(np.count_nonzero(pred_mask & ~true_mask))
        FN = int(np.count_nonzero(~pred_mask & true_mask))
        return TP, FP, FN

    @staticmethod
    def apply_mask(original_image, mask_imag):
        """Return an RGB copy of ``original_image`` painted green wherever the
        mask pixel is pure white.

        The mask is read over the original's extent; a mask smaller than the
        original raises ``IndexError``.
        """
        canvas = np.array(original_image.convert("RGB"))
        mask_rgb = np.array(mask_imag.convert("RGB"))
        height, width = canvas.shape[:2]

        # Vectorised replacement for a per-pixel Python loop.
        white = np.all(mask_rgb[:height, :width] == 255, axis=-1)
        canvas[white] = (0, 255, 0)
        return Image.fromarray(canvas)

    @classmethod
    def main(cls, ca_path: str) -> None:
        """Predict every image in ``dir_origin_path`` and optionally save the
        result and/or report IoU against ``<stem>_bin.png`` labels.

        Args:
            ca_path: autoencoder weight file handed to
                ``PreCA.initialize_model``.
        """
        PreCA.initialize_model(ca_path)
        import os
        from tqdm import tqdm

        image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
        ious: List[float] = []
        for img_name in tqdm(os.listdir(dir_origin_path)):
            if not img_name.lower().endswith(image_suffixes):
                continue
            # splitext strips only the final extension, so names that contain
            # extra dots keep their full stem.
            stem = os.path.splitext(img_name)[0]

            with Image.open(os.path.join(dir_origin_path, img_name)) as image:
                r_image = unet.detect_image(image)
            r_image = PreCA.infer(r_image)  # autoencoder refinement
            r_image = PreCA.ca_smooth(r_image)

            if is_save:
                os.makedirs(dir_save_path, exist_ok=True)
                r_image.save(os.path.join(dir_save_path, stem + '_bin.png'))
            if is_get_iou:
                label_path = os.path.join(dir_label_path, stem + '_bin.png')
                with Image.open(label_path) as label:
                    TP, FP, FN = cls.calculate_metrics(r_image, label)
                union = TP + FP + FN
                # No foreground in either image: prediction and label agree,
                # so score it 1.0 instead of dividing by zero.
                iou = TP / union if union else 1.0
                ious.append(iou)
                print(f"当前iou{iou}")

                # cls.blend_images_with_colorize(label, r_image)

        # Skip the summary when nothing was scored (mean of [] is undefined).
        if is_get_iou and ious:
            print(f"平均iou{np.mean(ious)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: defines the module-level settings read by PreUnet.main.
if __name__ == "__main__":
    name_classes: List[str] = ["background", "lane"]
    dir_origin_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\img"
    # Whether to compute IoU; if True, dir_label_path (the label directory) must be set.
    is_get_iou: bool = True
    dir_label_path: str = r"E:\git\unet_seg\unet\original_data\dataset_A\test\Label"
    # Whether to save the predicted images; if True, dir_save_path (the output directory) must be set.
    is_save: bool = False
    dir_save_path: str = "img_out/"
    # Weight file of the multi-scale supervised autoencoder.
    u_ca_path: str = 'weights/best_conv_autoencoder1.pth'
    _defaults: dict = {
        "model_path": 'model_data/best80.pth',  # U-Net weight file
        "num_classes": 2,  # number of classes including background
        "backbone": "vgg",
        "input_shape": [1696, 864],  # image size
        "mix_type": 1,
        "cuda": True,  # whether to use CUDA acceleration
    }
    unet: Unet = Unet(_defaults)
    PreUnet.main(u_ca_path)
|
|
@ -0,0 +1,12 @@
|
||||||
|
matplotlib==3.1.2
|
||||||
|
numpy==1.21.6
|
||||||
|
opencv_python==4.1.2.30
|
||||||
|
Pillow==8.2.0
|
||||||
|
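# Pillow==10.4.0  (duplicate pin disabled: it conflicts with Pillow==8.2.0 above and pip cannot install two versions of one package; keep exactly one)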
|
||||||
|
scipy==1.2.1
|
||||||
|
streamlit==1.23.1
|
||||||
|
thop==0.1.1.post2209072238
|
||||||
|
torch
|
||||||
|
torchsummary
|
||||||
|
torchvision
|
||||||
|
tqdm==4.60.0
|
|
@ -0,0 +1,30 @@
|
||||||
|
#--------------------------------------------#
|
||||||
|
# 该部分代码用于看网络结构
|
||||||
|
#--------------------------------------------#
|
||||||
|
import torch
|
||||||
|
from thop import clever_format, profile
|
||||||
|
from torchsummary import summary
|
||||||
|
|
||||||
|
from nets.unet import Unet
|
||||||
|
|
||||||
|
# Script entry point: print a layer summary plus FLOPs / parameter counts.
if __name__ == "__main__":
    input_shape = [1024, 1024]
    num_classes = 2
    backbone = 'resnet50'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Unet(num_classes = num_classes, backbone = backbone).to(device)
    summary(model, (3, input_shape[0], input_shape[1]))

    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops * 2 because profile does not count a convolution as two operations.
    #   Some papers count a convolution as a multiply plus an add (two operations): multiply by 2.
    #   Some papers count only the multiplications and ignore the additions: do not multiply by 2.
    #   This script multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))
|
|
@ -0,0 +1,255 @@
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from nets.unet import Unet
|
||||||
|
from nets.unet_training import get_lr_scheduler, set_optimizer_lr, weights_init
|
||||||
|
from utils.callbacks import EvalCallback, LossHistory
|
||||||
|
from utils.dataloader import UnetDataset, unet_dataset_collate
|
||||||
|
from utils.utils import (download_weights, seed_everything, show_config,
|
||||||
|
worker_init_fn)
|
||||||
|
from utils.utils_fit import fit_one_epoch
|
||||||
|
|
||||||
|
|
||||||
|
def train():
    """Run the full U-Net training loop.

    All settings (``Cuda``, ``distributed``, ``backbone``, ``model_path``,
    epoch / batch-size / learning-rate values, dataset path, ...) are read
    from module-level names that the script entry point defines before
    calling this function. Nothing is returned.
    """
    # ---- device / process-rank setup ----
    if distributed:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        device = torch.device("cuda", local_rank)
        if local_rank == 0:
            print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
            print("Gpu Device Count : ", ngpus_per_node)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank = 0
        rank = 0

    # ---- backbone weights: only local rank 0 downloads, the others wait at the barrier ----
    if pretrained:
        if distributed:
            if local_rank == 0:
                download_weights(backbone)
            dist.barrier()
        else:
            download_weights(backbone)

    model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
    if not pretrained:
        weights_init(model)
    if model_path != '':

        if local_rank == 0:
            print('Load weights {}.'.format(model_path))

        # Copy over only the checkpoint tensors whose key and shape match the
        # model; everything else is recorded in no_load_key and left as is.
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        load_key, no_load_key, temp_dict = [], [], {}
        for k, v in pretrained_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                load_key.append(k)
            else:
                no_load_key.append(k)
        model_dict.update(temp_dict)
        model.load_state_dict(model_dict)

        if local_rank == 0:
            print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
            print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))

    # ---- loss logging: only local rank 0 creates a timestamped log directory ----
    if local_rank == 0:
        time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
        log_dir = os.path.join(save_dir, "loss_" + str(time_str))
        loss_history = LossHistory(log_dir, model, input_shape=input_shape)
    else:
        loss_history = None

    # ---- mixed precision: a GradScaler is created only when fp16 is enabled ----
    if fp16:
        from torch.cuda.amp import GradScaler as GradScaler

        scaler = GradScaler()
    else:
        scaler = None

    model_train = model.train()

    # ---- synchronised batch norm needs several GPUs and distributed mode ----
    if sync_bn and ngpus_per_node > 1 and distributed:
        model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
    elif sync_bn:
        print("Sync_bn is not support in one gpu or not distributed.")

    # ---- wrap the model for (multi-)GPU execution ----
    if Cuda:
        if distributed:

            model_train = model_train.cuda(local_rank)
            model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank],
                                                                    find_unused_parameters=True)
        else:
            model_train = torch.nn.DataParallel(model)
            cudnn.benchmark = True
            model_train = model_train.cuda()

    # ---- read the train / val split lists ----
    with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/train.txt"), "r") as f:
        train_lines = f.readlines()
    with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"), "r") as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    if True:
        # Becomes True once the backbone has been unfrozen inside the epoch loop.
        UnFreeze_flag = False

        if Freeze_Train:
            model.freeze_backbone()

        batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size

        # Scale the learning rates by batch_size / nbs and clamp them to
        # optimizer-specific limits.
        nbs = 16
        lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
        lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
        Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
        Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)

        # Both optimizers are constructed here; only the one selected by
        # optimizer_type is kept.
        optimizer = {
            'adam': optim.Adam(model.parameters(), Init_lr_fit, betas=(momentum, 0.999), weight_decay=weight_decay),
            'sgd': optim.SGD(model.parameters(), Init_lr_fit, momentum=momentum, nesterov=True,
                             weight_decay=weight_decay)
        }[optimizer_type]

        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

        # Steps per epoch; partial batches are dropped (drop_last=True below).
        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

        train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
        val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)

        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, )
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, )
            # Per-process batch size; shuffling is delegated to the samplers.
            batch_size = batch_size // ngpus_per_node
            shuffle = False
        else:
            train_sampler = None
            val_sampler = None
            shuffle = True

        gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                         pin_memory=True,
                         drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler,
                         worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
        gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                             pin_memory=True,
                             drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler,
                             worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))

        # log_dir only exists on local rank 0, which is also the only rank
        # that builds the evaluation callback.
        if local_rank == 0:
            eval_callback = EvalCallback(model, input_shape, num_classes, val_lines, VOCdevkit_path, log_dir, Cuda, \
                                         eval_flag=eval_flag, period=eval_period)
        else:
            eval_callback = None

        for epoch in range(Init_Epoch, UnFreeze_Epoch):

            # One-time switch from the frozen stage to the unfrozen stage:
            # recompute batch size, learning rates, scheduler and data loaders.
            if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
                batch_size = Unfreeze_batch_size

                nbs = 16
                lr_limit_max = 1e-4 if optimizer_type == 'adam' else 1e-1
                lr_limit_min = 1e-4 if optimizer_type == 'adam' else 5e-4
                Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
                Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)

                lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

                model.unfreeze_backbone()

                epoch_step = num_train // batch_size
                epoch_step_val = num_val // batch_size

                if epoch_step == 0 or epoch_step_val == 0:
                    raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

                if distributed:
                    batch_size = batch_size // ngpus_per_node

                gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                                 pin_memory=True,
                                 drop_last=True, collate_fn=unet_dataset_collate, sampler=train_sampler,
                                 worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
                gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                                     pin_memory=True,
                                     drop_last=True, collate_fn=unet_dataset_collate, sampler=val_sampler,
                                     worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))

                UnFreeze_flag = True

            if distributed:
                train_sampler.set_epoch(epoch)

            # Push this epoch's scheduled learning rate into the optimizer.
            set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

            fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, dice_loss, focal_loss,
                          cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank)

            if distributed:
                dist.barrier()

        # loss_history is None on every rank other than local rank 0.
        if local_rank == 0:
            loss_history.writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: defines the module-level settings read by train().
if __name__ == "__main__":
    Cuda = True
    seed = 11
    # Whether to use multi-GPU (distributed) training
    distributed = False
    sync_bn = False
    # Whether to use half precision (fp16)
    fp16 = True
    num_classes = 2
    # Backbone network
    backbone = "vgg"
    pretrained = False
    model_path = "model_data/8414_8376.pth"
    input_shape = [1696, 864]
    # Frozen-backbone training stage
    Init_Epoch = 0
    Freeze_Epoch = 10
    Freeze_batch_size = 1
    # Unfrozen training stage
    UnFreeze_Epoch = 70
    Unfreeze_batch_size = 1
    Freeze_Train = True
    # Learning-rate settings
    Init_lr = 1e-4
    Min_lr = Init_lr * 0.01
    # Optimizer
    optimizer_type = "adam"
    momentum = 0.9
    weight_decay = 0
    lr_decay_type = 'cos'
    save_period = 5
    save_dir = 'logs'
    eval_flag = True
    eval_period = 5
    # Dataset settings
    VOCdevkit_path = 'VOCdevkit'
    dice_loss = False
    focal_loss = False
    cls_weights = np.ones([num_classes], np.float32)
    num_workers = 0
    seed_everything(seed)
    ngpus_per_node = torch.cuda.device_count()
    train()
|
|
@ -0,0 +1,131 @@
|
||||||
|
import colorsys
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from nets.unet import Unet as unet
|
||||||
|
from utils.utils import cvtColor, preprocess_input, resize_image
|
||||||
|
|
||||||
|
|
||||||
|
class Unet(object):
    """Inference wrapper around the U-Net segmentation network.

    Configuration keys (see ``_defaults``) become instance attributes:
    ``model_path``, ``num_classes``, ``backbone``, ``input_shape``,
    ``mix_type`` and ``cuda``.
    """

    _defaults = {
        "model_path": None,
        "num_classes": 2,
        "backbone": "vgg",
        "input_shape": [1696, 864],
        "mix_type": 1,
        "cuda": True,
    }

    def __init__(self, _defaults, **kwargs):
        """Store the configuration, build the colour table and load the net.

        Args:
            _defaults: dict of settings; keys it omits fall back to the
                class-level ``_defaults``.
            **kwargs: extra attributes set after the dict is applied (they
                override it).
        """
        # Merge over the class defaults so that a partial dict still yields
        # every attribute the other methods read.
        self._defaults = {**type(self)._defaults, **(_defaults or {})}
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # One colour per class: black/white for up to two classes, otherwise
        # evenly spaced hues.
        if self.num_classes <= 2:
            self.colors = [(0, 0, 0), (255, 255, 255)]
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            rgb_floats = [colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples]
            self.colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in rgb_floats]

        self.generate()

    def generate(self, onnx=False):
        """Build the network, load ``self.model_path`` and switch to eval mode.

        Unless ``onnx`` is true, the net is wrapped in ``nn.DataParallel`` and
        moved to the GPU when ``self.cuda`` is set.
        """
        self.net = unet(num_classes=self.num_classes, backbone=self.backbone)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    def detect_image(self, image, count=False, name_classes=None):
        """Segment ``image`` and return a PIL image chosen by ``mix_type``.

        Args:
            image: input PIL image.
            count: when true, print a per-class pixel count table.
            name_classes: optional class names for that table; class indices
                are printed when it is ``None``.

        Returns:
            ``mix_type == 0``: colour mask blended with the input (0.7 mask);
            ``mix_type == 1``: the colour mask alone;
            ``mix_type == 2``: the input with class-0 pixels zeroed;
            any other value: the converted input image unchanged.
        """
        image = cvtColor(image)

        # Keep an untouched copy for blending / masking later on.
        old_img = copy.deepcopy(image)
        original_h, original_w = np.array(image).shape[:2]

        # resize_image also returns the size (nw, nh) of the actual image
        # content inside the network-sized canvas.
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))

        # HWC -> CHW plus a leading batch axis.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            pr = self.net(images)[0]

            # Per-pixel class probabilities, channels last.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

            # Cut away the border so only the (nh, nw) content region remains.
            top = int((self.input_shape[0] - nh) // 2)
            left = int((self.input_shape[1] - nw) // 2)
            pr = pr[top: int(top + nh), left: int(left + nw)]

            # Back to the input resolution, then pick the most likely class.
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)

            pr = pr.argmax(axis=-1)

        if count:
            classes_nums = np.zeros([self.num_classes])
            total_points_num = original_h * original_w
            print('-' * 63)
            print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio"))
            print('-' * 63)
            for i in range(self.num_classes):
                num = np.sum(pr == i)
                ratio = num / total_points_num * 100
                if num > 0:
                    # Fall back to the class index when no names were given.
                    label = name_classes[i] if name_classes is not None else i
                    print("|%25s | %15s | %14.2f%%|" % (str(label), str(num), ratio))
                    print('-' * 63)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)

        if self.mix_type in (0, 1):
            # Map every class index to its colour; the result is (H, W, 3).
            seg_img = np.array(self.colors, np.uint8)[pr]
            image = Image.fromarray(np.uint8(seg_img))
            if self.mix_type == 0:
                # Blend the colour mask with the original picture.
                image = Image.blend(old_img, image, 0.7)

        elif self.mix_type == 2:
            # Keep the original pixels only where a non-background class won.
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            image = Image.fromarray(np.uint8(seg_img))

        return image
|
|
@ -0,0 +1 @@
|
||||||
|
#
|
|
@ -0,0 +1,210 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import scipy.signal
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import shutil
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from .utils import cvtColor, preprocess_input, resize_image
|
||||||
|
from .utils_metrics import compute_mIoU
|
||||||
|
|
||||||
|
|
||||||
|
class LossHistory():
    """Record per-epoch losses to text files, TensorBoard and a PNG plot."""

    def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: directory receiving all log artefacts.
            model: network whose graph is (best effort) written to TensorBoard.
            input_shape: two-element sequence used to size the dummy input.
            val_loss_flag: when false, validation loss is neither stored nor
                plotted.
        """
        self.log_dir = log_dir
        self.val_loss_flag = val_loss_flag

        self.losses = []
        if self.val_loss_flag:
            self.val_loss = []

        # exist_ok matches append_loss, which also tolerates an existing dir.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is optional; a failure must not stop training.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            pass

    def append_loss(self, epoch, loss, val_loss = None):
        """Store the losses for ``epoch``, append them to the text logs, emit
        TensorBoard scalars and redraw the plot."""
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        if self.val_loss_flag:
            self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        if self.val_loss_flag:
            with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
                f.write(str(val_loss))
                f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        if self.val_loss_flag:
            self.writer.add_scalar('val_loss', val_loss, epoch)

        self.loss_plot()

    def loss_plot(self):
        """Save ``epoch_loss.png`` with raw and, when possible, smoothed curves."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        if self.val_loss_flag:
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')

        try:
            # Savitzky-Golay window: shorter while the history is short.
            num = 5 if len(self.losses) < 25 else 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            if self.val_loss_flag:
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except Exception:
            # Smoothing is best effort (it fails while the history is shorter
            # than the filter window); the raw curves are still drawn.
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
|
||||||
|
|
||||||
|
class EvalCallback():
    """Periodic mIoU evaluation hook for segmentation training.

    On construction the evaluation settings are stored and, when evaluation
    is enabled, an initial ``0`` entry is appended to ``epoch_miou.txt``
    inside ``log_dir``. ``on_epoch_end`` predicts a mask for every image id,
    scores the masks with ``compute_mIoU`` and refreshes both the text log
    and the ``epoch_miou.png`` curve.
    """

    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, \
            miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        """Store the evaluation configuration.

        Args:
            net: Model invoked as ``net(images)`` to obtain predictions.
            input_shape: Indexable ``(height, width)`` of the network input.
            num_classes: Class count forwarded to ``compute_mIoU``.
            image_ids: Iterable of annotation strings; only the first
                whitespace-separated token of each entry is kept.
            dataset_path: Root folder containing the ``VOC2007`` tree.
            log_dir: Folder receiving ``epoch_miou.txt`` / ``epoch_miou.png``.
            cuda: When truthy, inputs are moved to the GPU before inference.
            miou_out_path: Scratch folder for predicted masks; it is deleted
                at the end of every evaluation.
            eval_flag: Master switch for the evaluation.
            period: Evaluate only on epochs divisible by this value.
        """
        super().__init__()

        self.net = net
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.dataset_path = dataset_path
        self.log_dir = log_dir
        self.cuda = cuda
        self.miou_out_path = miou_out_path
        self.eval_flag = eval_flag
        self.period = period

        # Annotation lines may carry extra columns; the id is the first token.
        self.image_ids = [entry.split()[0] for entry in image_ids]
        # Both histories start with a (mIoU 0, epoch 0) point.
        self.mious = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(0) + "\n")

    def get_miou_png(self, image):
        """Predict a class-index mask for ``image`` at its original size.

        Args:
            image: Image object accepted by ``cvtColor`` / ``resize_image``.

        Returns:
            An image built with ``Image.fromarray`` from the ``uint8``
            per-pixel argmax of the network output.
        """
        # Force RGB so that grayscale inputs do not break the prediction;
        # every non-RGB image is converted.
        image = cvtColor(image)
        source_shape = np.array(image).shape
        orininal_h = source_shape[0]
        orininal_w = source_shape[1]

        # Pad with gray bars for a distortion-free resize to the network input.
        target_h, target_w = self.input_shape[0], self.input_shape[1]
        image_data, nw, nh = resize_image(image, (target_w, target_h))

        # Normalise, reorder HWC -> CHW and prepend the batch dimension.
        normalised = preprocess_input(np.array(image_data, np.float32))
        image_data = np.expand_dims(np.transpose(normalised, (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # Run the network and keep the first element of its output.
            pr = self.net(images)[0]
            # Per-pixel class probabilities, channels moved to the last axis.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

            # Crop away the gray padding bars.
            top = int((target_h - nh) // 2)
            left = int((target_w - nw) // 2)
            bottom = int((target_h - nh) // 2 + nh)
            right = int((target_w - nw) // 2 + nw)
            pr = pr[top:bottom, left:right]

            # Scale the probabilities back to the original resolution.
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
            # Pick the most likely class of every pixel.
            pr = pr.argmax(axis=-1)

        return Image.fromarray(np.uint8(pr))

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mIoU for ``epoch`` when it falls on the configured period.

        Args:
            epoch: Epoch number recorded in the history.
            model_eval: Model that replaces ``self.net`` for this evaluation.
        """
        if epoch % self.period != 0 or not self.eval_flag:
            return

        self.net = model_eval
        gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        pred_dir = os.path.join(self.miou_out_path, 'detection-results')
        for folder in (self.miou_out_path, pred_dir):
            if not os.path.exists(folder):
                os.makedirs(folder)

        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            # Read the image from disk.
            image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/"+image_id+".jpg")
            source = Image.open(image_path)
            # Predict the mask and store it as a PNG for the scorer.
            predicted = self.get_miou_png(source)
            predicted.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        # Run the mIoU computation over the saved predictions.
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)
        temp_miou = np.nanmean(IoUs) * 100

        self.mious.append(temp_miou)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
            f.write(str(temp_miou) + "\n")

        plt.figure()
        plt.plot(self.epoches, self.mious, 'red', linewidth = 2, label='train miou')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Miou')
        plt.title('A Miou Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
        plt.cla()
        plt.close("all")

        print("Get miou done.")
        # The predicted masks are only needed for scoring.
        shutil.rmtree(self.miou_out_path)
|
|
@ -0,0 +1,149 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
from utils.utils import cvtColor, preprocess_input
|
||||||
|
|
||||||
|
|
||||||
|
class UnetDataset(Dataset):
    """Segmentation dataset reading VOC-layout image / mask pairs.

    Every sample is loaded from ``VOC2007/JPEGImages`` and
    ``VOC2007/SegmentationClass`` below ``dataset_path``, fitted to
    ``input_shape`` (with random augmentation when ``train`` is truthy) and
    returned together with a one-hot encoding of its mask.
    """

    def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
        """Remember the sample list and the preprocessing settings.

        Args:
            annotation_lines: Sequence of strings whose first token is the
                file stem of a sample.
            input_shape: ``(height, width)`` of the produced samples.
            num_classes: Number of classes; larger label values are clamped
                to this value.
            train: Enables random augmentation when truthy.
            dataset_path: Root folder containing the ``VOC2007`` tree.
        """
        super().__init__()
        self.annotation_lines = annotation_lines
        self.length = len(annotation_lines)
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        """Return the number of annotation lines."""
        return self.length

    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        Returns:
            ``(image, mask, one_hot)``: the image as a channel-first float64
            array divided by 255 through ``preprocess_input``, the mask as a
            2-D array of class indices, and the mask one-hot encoded with
            ``num_classes + 1`` channels.
        """
        name = self.annotation_lines[index].split()[0]

        # Read the image and its mask from disk.
        image_dir = os.path.join(self.dataset_path, "VOC2007/JPEGImages")
        mask_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass")
        jpg = Image.open(os.path.join(image_dir, name + ".jpg"))
        png = Image.open(os.path.join(mask_dir, name + ".png"))

        # Data augmentation (plain letterbox when not training).
        jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)

        jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
        png = np.array(png)
        # Everything at or above num_classes collapses into one extra index.
        png[png >= self.num_classes] = self.num_classes

        # One-hot encode. The "+ 1" channel exists because some VOC labels
        # carry a white border that has to be ignored; giving it its own
        # channel makes it easy to drop later.
        channels = self.num_classes + 1
        one_hot = np.eye(channels)[png.reshape([-1])]
        seg_labels = one_hot.reshape((int(self.input_shape[0]), int(self.input_shape[1]), channels))

        return jpg, png, seg_labels

    def rand(self, a=0, b=1):
        """Return ``np.random.rand()`` rescaled to the interval ``[a, b)``."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
        """Fit ``image`` / ``label`` to ``input_shape``, optionally augmenting.

        Args:
            image: Input picture; converted to RGB through ``cvtColor``.
            label: Mask picture aligned with ``image``.
            input_shape: Target ``(height, width)``.
            jitter: Half-width of the random aspect-ratio distortion.
            hue, sat, val: Half-widths of the random HSV gain factors.
            random: When falsy only a centred letterbox resize is applied.

        Returns:
            ``(image, label)``. Without augmentation both are PIL images;
            with augmentation the image is a ``uint8`` RGB array and the
            label a PIL image.
        """
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))

        # Source size and target size.
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            offset = ((w-nw)//2, (h-nh)//2)

            new_image = Image.new('RGB', [w, h], (128,128,128))
            new_image.paste(image.resize((nw,nh), Image.BICUBIC), offset)

            new_label = Image.new('L', [w, h], (0))
            new_label.paste(label.resize((nw,nh), Image.NEAREST), offset)
            return new_image, new_label

        # Random rescale combined with an aspect-ratio distortion.
        stretch_a = self.rand(1-jitter,1+jitter)
        stretch_b = self.rand(1-jitter,1+jitter)
        new_ar = iw/ih * stretch_a / stretch_b
        scale = self.rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)
        label = label.resize((nw,nh), Image.NEAREST)

        # Random horizontal flip, applied to image and mask alike.
        if self.rand()<.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # Paste at a random offset; the uncovered area stays gray / zero.
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        canvas = Image.new('RGB', (w,h), (128,128,128))
        mask_canvas = Image.new('L', (w,h), (0))
        canvas.paste(image, (dx, dy))
        mask_canvas.paste(label, (dx, dy))
        image = canvas
        label = mask_canvas

        image_data = np.array(image, np.uint8)

        # Colour jitter: draw one gain per HSV channel around 1.
        gains = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1

        # Move the image to HSV space.
        hue_plane, sat_plane, val_plane = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype

        # Apply the gains through lookup tables; hue wraps modulo 180 while
        # saturation and value are clipped to 0..255.
        x = np.arange(0, 256, dtype=gains.dtype)
        lut_hue = ((x * gains[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * gains[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * gains[2], 0, 255).astype(dtype)

        jittered = cv2.merge((cv2.LUT(hue_plane, lut_hue), cv2.LUT(sat_plane, lut_sat), cv2.LUT(val_plane, lut_val)))
        image_data = cv2.cvtColor(jittered, cv2.COLOR_HSV2RGB)

        return image_data, label
|
||||||
|
|
||||||
|
# Used as the collate_fn of the DataLoader
|
||||||
|
def unet_dataset_collate(batch):
    """Stack ``(image, mask, one_hot)`` samples into batched tensors.

    Args:
        batch: Sequence of 3-tuples of equally shaped array-likes.

    Returns:
        ``(images, pngs, seg_labels)`` where the images and one-hot labels
        are ``FloatTensor`` and the masks are ``long`` tensors, each with a
        new leading batch dimension.
    """
    samples = list(batch)
    images = [img for img, _, _ in samples]
    pngs = [png for _, png, _ in samples]
    seg_labels = [labels for _, _, labels in samples]

    # Go through numpy so that each field is converted in a single call.
    image_tensor = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    mask_tensor = torch.from_numpy(np.array(pngs)).long()
    label_tensor = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
    return image_tensor, mask_tensor, label_tensor
|
|
@ -0,0 +1,150 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
from utils.utils import cvtColor, preprocess_input
|
||||||
|
|
||||||
|
|
||||||
|
class UnetDataset(Dataset):
    """Segmentation dataset reading ``Images`` / ``Labels`` PNG pairs.

    Every sample is loaded from the ``Images`` and ``Labels`` folders below
    ``dataset_path``, fitted to ``input_shape`` (with random augmentation
    when ``train`` is truthy) and its label is binarised: pixels with a
    value of at most 127.5 become class 1, all others class 0.
    """

    def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
        """Remember the sample list and the preprocessing settings.

        Args:
            annotation_lines: Sequence of strings whose first token is the
                file stem of a sample.
            input_shape: ``(height, width)`` of the produced samples.
            num_classes: Number of classes; the one-hot labels get
                ``num_classes + 1`` channels.
            train: Enables random augmentation when truthy.
            dataset_path: Root folder containing ``Images`` and ``Labels``.
        """
        super().__init__()
        self.annotation_lines = annotation_lines
        self.length = len(annotation_lines)
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.train = train
        self.dataset_path = dataset_path

    def __len__(self):
        """Return the number of annotation lines."""
        return self.length

    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        Returns:
            ``(image, mask, one_hot)``: the image as a channel-first float64
            array divided by 255 through ``preprocess_input``, the binarised
            mask, and the mask one-hot encoded with ``num_classes + 1``
            channels.
        """
        name = self.annotation_lines[index].split()[0]

        # Read the image and its label from disk.
        image_dir = os.path.join(self.dataset_path, "Images")
        label_dir = os.path.join(self.dataset_path, "Labels")
        jpg = Image.open(os.path.join(image_dir, name + ".png"))
        png = Image.open(os.path.join(label_dir, name + ".png"))

        # Data augmentation (plain letterbox when not training).
        jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)

        jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
        png = np.array(png)

        # Labels are handled differently from regular VOC data here:
        # pixels at or below 127.5 are marked as the target class.
        modify_png = np.zeros_like(png)
        modify_png[png <= 127.5] = 1

        channels = self.num_classes + 1
        one_hot = np.eye(channels)[modify_png.reshape([-1])]
        seg_labels = one_hot.reshape((int(self.input_shape[0]), int(self.input_shape[1]), channels))

        return jpg, modify_png, seg_labels

    def rand(self, a=0, b=1):
        """Return ``np.random.rand()`` rescaled to the interval ``[a, b)``."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
        """Fit ``image`` / ``label`` to ``input_shape``, optionally augmenting.

        Args:
            image: Input picture; converted to RGB through ``cvtColor``.
            label: Label picture aligned with ``image``.
            input_shape: Target ``(height, width)``.
            jitter: Half-width of the random aspect-ratio distortion.
            hue, sat, val: Half-widths of the random HSV gain factors.
            random: When falsy only a centred letterbox resize is applied.

        Returns:
            ``(image, label)``. Without augmentation both are PIL images;
            with augmentation the image is a ``uint8`` RGB array and the
            label a PIL image.
        """
        image = cvtColor(image)
        label = Image.fromarray(np.array(label))

        # Source size and target size.
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            offset = ((w-nw)//2, (h-nh)//2)

            new_image = Image.new('RGB', [w, h], (128,128,128))
            new_image.paste(image.resize((nw,nh), Image.BICUBIC), offset)

            new_label = Image.new('L', [w, h], (0))
            new_label.paste(label.resize((nw,nh), Image.NEAREST), offset)
            return new_image, new_label

        # Random rescale combined with an aspect-ratio distortion.
        stretch_a = self.rand(1-jitter,1+jitter)
        stretch_b = self.rand(1-jitter,1+jitter)
        new_ar = iw/ih * stretch_a / stretch_b
        scale = self.rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)
        label = label.resize((nw,nh), Image.NEAREST)

        # Random horizontal flip, applied to image and label alike.
        if self.rand()<.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # Paste at a random offset; the uncovered area stays gray / zero.
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        canvas = Image.new('RGB', (w,h), (128,128,128))
        label_canvas = Image.new('L', (w,h), (0))
        canvas.paste(image, (dx, dy))
        label_canvas.paste(label, (dx, dy))
        image = canvas
        label = label_canvas

        image_data = np.array(image, np.uint8)

        # Colour jitter: draw one gain per HSV channel around 1.
        gains = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1

        # Move the image to HSV space.
        hue_plane, sat_plane, val_plane = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype

        # Apply the gains through lookup tables; hue wraps modulo 180 while
        # saturation and value are clipped to 0..255.
        x = np.arange(0, 256, dtype=gains.dtype)
        lut_hue = ((x * gains[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * gains[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * gains[2], 0, 255).astype(dtype)

        jittered = cv2.merge((cv2.LUT(hue_plane, lut_hue), cv2.LUT(sat_plane, lut_sat), cv2.LUT(val_plane, lut_val)))
        image_data = cv2.cvtColor(jittered, cv2.COLOR_HSV2RGB)

        return image_data, label
|
||||||
|
|
||||||
|
# Used as the collate_fn of the DataLoader
|
||||||
|
def unet_dataset_collate(batch):
    """Stack ``(image, mask, one_hot)`` samples into batched tensors.

    Args:
        batch: Sequence of 3-tuples of equally shaped array-likes.

    Returns:
        ``(images, pngs, seg_labels)`` where the images and one-hot labels
        are ``FloatTensor`` and the masks are ``long`` tensors, each with a
        new leading batch dimension.
    """
    samples = list(batch)
    images = [img for img, _, _ in samples]
    pngs = [png for _, png, _ in samples]
    seg_labels = [labels for _, _, labels in samples]

    # Go through numpy so that each field is converted in a single call.
    image_tensor = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    mask_tensor = torch.from_numpy(np.array(pngs)).long()
    label_tensor = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
    return image_tensor, mask_tensor, label_tensor
|
|
@ -0,0 +1,76 @@
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def cvtColor(image):
    """Return ``image`` unchanged if it is 3-channel, else convert it to RGB.

    An input whose ``np.shape`` has three dimensions with a last dimension
    of 3 is returned as is; anything else is passed through
    ``image.convert('RGB')``.
    """
    shape = np.shape(image)
    already_rgb = len(shape) == 3 and shape[2] == 3
    return image if already_rgb else image.convert('RGB')
|
||||||
|
|
||||||
|
def resize_image(image, size):
    """Letterbox ``image`` into a gray canvas of ``size`` keeping its ratio.

    Args:
        image: PIL image to resize.
        size: Target ``(width, height)``.

    Returns:
        ``(canvas, nw, nh)``: the padded RGB image and the width / height
        of the resized content inside it.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size

    # The smaller ratio makes the content fit in both directions.
    ratio = min(dst_w/src_w, dst_h/src_h)
    nw = int(src_w*ratio)
    nh = int(src_h*ratio)

    resized = image.resize((nw,nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128,128,128))
    # Centre the content; the remaining area stays gray.
    canvas.paste(resized, ((dst_w-nw)//2, (dst_h-nh)//2))

    return canvas, nw, nh
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed=11):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True`` and
    ``torch.backends.cudnn.benchmark`` to ``False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def worker_init_fn(worker_id, rank, seed):
    """Seed the RNGs of one DataLoader worker reproducibly and uniquely.

    The seed is derived from ``(seed, rank, worker_id)`` so that every
    worker of every process draws its own random stream. Seeding all workers
    with ``rank + seed`` alone would hand each worker of a process the same
    NumPy / ``random`` state and therefore identical augmentation draws.

    Args:
        worker_id: Index of the DataLoader worker being initialised.
        rank: Rank of the owning process.
        seed: Base seed of the run.

    All three values must be non-negative integers.
    """
    # SeedSequence mixes the three integers into one well-distributed
    # 32-bit value, which is the range np.random.seed accepts.
    mixed = np.random.SeedSequence([seed, rank, worker_id]).generate_state(1)
    worker_seed = int(mixed[0])
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
|
||||||
|
|
||||||
|
def preprocess_input(image):
    """Divide ``image`` by 255 in place and return it.

    The division uses the augmented assignment operator, so an array
    argument is modified in place and the same object is handed back.
    """
    max_pixel_value = 255.0
    image /= max_pixel_value
    return image
|
||||||
|
|
||||||
|
def show_config(**kwargs):
    """Print the keyword arguments as a two-column ``keys`` / ``values`` table."""
    rule = '-' * 70
    row = '|%25s | %40s|'

    lines = ['Configurations:', rule, row % ('keys', 'values'), rule]
    lines.extend(row % (str(key), str(value)) for key, value in kwargs.items())
    lines.append(rule)

    for line in lines:
        print(line)
|
||||||
|
|
||||||
|
def download_weights(backbone, model_dir="./model_data"):
    """Download the pretrained weights of ``backbone`` into ``model_dir``.

    Args:
        backbone: ``'vgg'`` or ``'resnet50'``.
        model_dir: Destination folder, created when it does not exist.

    Raises:
        KeyError: If ``backbone`` is not one of the known names.
    """
    import os
    from torch.hub import load_state_dict_from_url

    weight_urls = {
        'vgg'       : 'https://download.pytorch.org/models/vgg16-397923af.pth',
        'resnet50'  : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth'
    }
    # Look the URL up first so an unknown backbone fails before any I/O.
    weight_url = weight_urls[backbone]

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    load_state_dict_from_url(weight_url, model_dir)
|
|
@ -0,0 +1,272 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from nets.unet_training import CE_Loss, Dice_loss, Focal_Loss
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from utils.utils import get_lr
|
||||||
|
from utils.utils_metrics import f_score
|
||||||
|
|
||||||
|
|
||||||
|
def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
    """Run one training pass and one validation pass, then log and save.

    Args:
        model_train: Module used for the forward passes and mode switches.
        model: Module whose ``state_dict`` is written to disk.
        loss_history: Receives ``append_loss(epoch, train_loss, val_loss)``;
            its ``val_loss`` list decides whether this epoch is the best.
        eval_callback: Receives ``on_epoch_end(epoch, model_train)``.
        optimizer: Optimizer stepped once per training batch.
        epoch: Zero-based index of the current epoch.
        epoch_step: Number of training batches to consume.
        epoch_step_val: Number of validation batches to consume.
        gen: Iterable of training ``(imgs, pngs, labels)`` batches.
        gen_val: Iterable of validation ``(imgs, pngs, labels)`` batches.
        Epoch: Total number of epochs (used for display and the last save).
        cuda: When truthy, tensors are moved to GPU ``local_rank``.
        dice_loss: When truthy, ``Dice_loss`` is added to the main loss.
        focal_loss: Selects ``Focal_Loss`` over ``CE_Loss`` when truthy.
        cls_weights: NumPy array of class weights handed to the loss.
        num_classes: Class count handed to the loss.
        fp16: When truthy, training runs under autocast with ``scaler``.
        scaler: Gradient scaler used only when ``fp16`` is truthy.
        save_period: Save a numbered checkpoint every this many epochs.
        save_dir: Folder receiving the checkpoints.
        local_rank: Only rank 0 prints, logs and saves.
    """
    if fp16:
        from torch.cuda.amp import autocast

    def _prepare(batch):
        """Unpack a batch and move it and the class weights to the GPU if requested."""
        imgs, pngs, labels = batch
        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            if cuda:
                imgs = imgs.cuda(local_rank)
                pngs = pngs.cuda(local_rank)
                labels = labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
        return imgs, pngs, labels, weights

    def _forward(imgs, pngs, labels, weights):
        """Forward one batch and return ``(loss, f_score)``."""
        # Forward pass.
        outputs = model_train(imgs)
        # Loss computation.
        criterion = Focal_Loss if focal_loss else CE_Loss
        loss = criterion(outputs, pngs, weights, num_classes = num_classes)
        if dice_loss:
            loss = loss + Dice_loss(outputs, labels)
        # The f-score is a metric only, so it is computed without gradients.
        with torch.no_grad():
            batch_f_score = f_score(outputs, labels)
        return loss, batch_f_score

    is_main = local_rank == 0

    total_loss = 0
    total_f_score = 0
    val_loss = 0
    val_f_score = 0

    # ------------------------------ training ------------------------------
    if is_main:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        imgs, pngs, labels, weights = _prepare(batch)

        optimizer.zero_grad()
        if fp16:
            with autocast():
                loss, batch_f_score = _forward(imgs, pngs, labels, weights)
            # Backward pass through the gradient scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss, batch_f_score = _forward(imgs, pngs, labels, weights)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_f_score += batch_f_score.item()

        if is_main:
            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
                                'f_score'   : total_f_score / (iteration + 1),
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    # ----------------------------- validation -----------------------------
    if is_main:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        with torch.no_grad():
            imgs, pngs, labels, weights = _prepare(batch)
            loss, batch_f_score = _forward(imgs, pngs, labels, weights)

            val_loss += loss.item()
            val_f_score += batch_f_score.item()

        if is_main:
            pbar.set_postfix(**{'val_loss'  : val_loss / (iteration + 1),
                                'f_score'   : val_f_score / (iteration + 1),
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    if not is_main:
        return

    # ------------------------- logging / checkpoints -----------------------
    pbar.close()
    print('Finish Validation')
    mean_train_loss = total_loss / epoch_step
    mean_val_loss = val_loss / epoch_step_val
    loss_history.append_loss(epoch + 1, mean_train_loss, mean_val_loss)
    eval_callback.on_epoch_end(epoch + 1, model_train)
    print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' % (mean_train_loss, mean_val_loss))

    # Save the weights.
    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
        torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch + 1), mean_train_loss, mean_val_loss)))

    if len(loss_history.val_loss) <= 1 or mean_val_loss <= min(loss_history.val_loss):
        print('Save best model to best_epoch_weights.pth')
        torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

    torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
|
||||||
|
|
||||||
|
def fit_one_epoch_no_val(model_train, model, loss_history, optimizer, epoch, epoch_step, gen, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, fp16, scaler, save_period, save_dir, local_rank=0):
    """Run one training pass without validation, then log and save.

    Args:
        model_train: Module used for the forward passes.
        model: Module whose ``state_dict`` is written to disk.
        loss_history: Receives ``append_loss(epoch, train_loss)``; its
            ``losses`` list decides whether this epoch is the best.
        optimizer: Optimizer stepped once per training batch.
        epoch: Zero-based index of the current epoch.
        epoch_step: Number of training batches to consume.
        gen: Iterable of ``(imgs, pngs, labels)`` batches.
        Epoch: Total number of epochs (used for display and the last save).
        cuda: When truthy, tensors are moved to GPU ``local_rank``.
        dice_loss: When truthy, ``Dice_loss`` is added to the main loss.
        focal_loss: Selects ``Focal_Loss`` over ``CE_Loss`` when truthy.
        cls_weights: NumPy array of class weights handed to the loss.
        num_classes: Class count handed to the loss.
        fp16: When truthy, training runs under autocast with ``scaler``.
        scaler: Gradient scaler used only when ``fp16`` is truthy.
        save_period: Save a numbered checkpoint every this many epochs.
        save_dir: Folder receiving the checkpoints.
        local_rank: Only rank 0 prints, logs and saves.
    """
    if fp16:
        from torch.cuda.amp import autocast

    def _prepare(batch):
        """Unpack a batch and move it and the class weights to the GPU if requested."""
        imgs, pngs, labels = batch
        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            if cuda:
                imgs = imgs.cuda(local_rank)
                pngs = pngs.cuda(local_rank)
                labels = labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
        return imgs, pngs, labels, weights

    def _forward(imgs, pngs, labels, weights):
        """Forward one batch and return ``(loss, f_score)``."""
        # Forward pass.
        outputs = model_train(imgs)
        # Loss computation.
        criterion = Focal_Loss if focal_loss else CE_Loss
        loss = criterion(outputs, pngs, weights, num_classes = num_classes)
        if dice_loss:
            loss = loss + Dice_loss(outputs, labels)
        # The f-score is a metric only, so it is computed without gradients.
        with torch.no_grad():
            batch_f_score = f_score(outputs, labels)
        return loss, batch_f_score

    is_main = local_rank == 0

    total_loss = 0
    total_f_score = 0

    if is_main:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        imgs, pngs, labels, weights = _prepare(batch)

        optimizer.zero_grad()
        if fp16:
            with autocast():
                loss, batch_f_score = _forward(imgs, pngs, labels, weights)
            # Backward pass through the gradient scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss, batch_f_score = _forward(imgs, pngs, labels, weights)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_f_score += batch_f_score.item()

        if is_main:
            pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
                                'f_score'   : total_f_score / (iteration + 1),
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    if not is_main:
        return

    pbar.close()
    mean_train_loss = total_loss / epoch_step
    loss_history.append_loss(epoch + 1, mean_train_loss)
    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.3f' % (mean_train_loss))

    # Save the weights.
    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
        torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f.pth'%((epoch + 1), mean_train_loss)))

    if len(loss_history.losses) <= 1 or mean_train_loss <= min(loss_history.losses):
        print('Save best model to best_epoch_weights.pth')
        torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

    torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
|
|
@ -0,0 +1,182 @@
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from os.path import join
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    """Compute the mean per-class F-beta score of a segmentation batch.

    Args:
        inputs: Logits of shape ``(n, c, h, w)``.
        target: One-hot labels of shape ``(n, ht, wt, ct)``; the last
            channel is dropped before scoring.
        beta: Weight of recall relative to precision.
        smooth: Constant added to numerator and denominator to avoid a
            division by zero.
        threhold: Softmax probability above which a pixel counts as
            predicted for a class.

    Returns:
        A scalar tensor holding the F-beta score averaged over the classes.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()
    # Resample whenever either spatial dimension differs. Requiring both to
    # differ would leave a mismatch in a single dimension unhandled and make
    # the elementwise product below fail.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n, h*w, c) class probabilities.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    temp_target = target.reshape(n, -1, ct)

    # Dice-style statistics on the thresholded predictions.
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    valid_target = temp_target[..., :-1]
    tp = torch.sum(valid_target * temp_inputs, dim=[0, 1])
    fp = torch.sum(temp_inputs, dim=[0, 1]) - tp
    fn = torch.sum(valid_target, dim=[0, 1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score
|
||||||
|
|
||||||
|
# Let the label have width W and height H
|
||||||
|
def fast_hist(a, b, n):
    """Build the ``n x n`` confusion matrix of labels ``a`` vs predictions ``b``.

    Args:
        a: Flattened ground-truth labels, shape ``(H*W,)``.
        b: Flattened predictions, shape ``(H*W,)``.
        n: Number of classes.

    Returns:
        An ``(n, n)`` array whose rows index the label and whose columns
        index the prediction; the diagonal counts correctly classified
        pixels. Pixels whose label lies outside ``[0, n)`` are ignored.
    """
    valid = np.logical_and(a >= 0, a < n)
    # Encode every (label, prediction) pair as one index in [0, n*n).
    pair_index = n * a[valid].astype(int) + b[valid]
    # bincount tallies how often each of the n*n pair indices occurs.
    counts = np.bincount(pair_index, minlength=n ** 2)
    return counts.reshape(n, n)
|
||||||
|
|
||||||
|
def per_class_iu(hist):
    """Return the per-class intersection over union of a confusion matrix."""
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    # Clamp the denominator to 1 so empty classes give 0 rather than a
    # division by zero.
    return intersection / np.maximum(union, 1)
|
||||||
|
|
||||||
|
def per_class_PA_Recall(hist):
    """Return the per-class recall (pixel accuracy) of a confusion matrix."""
    correct = np.diag(hist)
    # Row sums: number of pixels labelled with each class, clamped to 1.
    labelled = np.maximum(hist.sum(1), 1)
    return correct / labelled
|
||||||
|
|
||||||
|
def per_class_Precision(hist):
    """Return the per-class precision of a confusion matrix."""
    correct = np.diag(hist)
    # Column sums: number of pixels predicted as each class, clamped to 1.
    predicted = np.maximum(hist.sum(0), 1)
    return correct / predicted
|
||||||
|
|
||||||
|
def per_Accuracy(hist):
    """Return the overall pixel accuracy of a confusion matrix."""
    correct = np.sum(np.diag(hist))
    # Clamp the total to 1 so an empty matrix gives 0 instead of dividing by 0.
    total = np.maximum(np.sum(hist), 1)
    return correct / total
|
||||||
|
|
||||||
|
def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
    """Accumulate a confusion matrix over image pairs and report the metrics.

    Args:
        gt_dir: Directory holding the ground-truth label PNGs.
        pred_dir: Directory holding the predicted segmentation PNGs.
        png_name_list: File stems (without ``.png``) present in both folders.
        num_classes: Number of classes, i.e. the confusion-matrix side length.
        name_classes: Optional class names; when given, progress and
            per-class results are printed.

    Returns:
        Tuple ``(hist, IoUs, PA_Recall, Precision)`` where ``hist`` is the
        integer confusion matrix and the others are per-class arrays.
    """
    print('Num classes', num_classes)
    # -----------------------------------------#
    #   Start from an all-zero confusion matrix
    # -----------------------------------------#
    hist = np.zeros((num_classes, num_classes))

    # ------------------------------------------------#
    #   Paths of the labels and of the predictions
    # ------------------------------------------------#
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]

    # ------------------------------------------------#
    #   Walk over every (prediction, label) pair
    # ------------------------------------------------#
    for ind, (gt_path, pred_path) in enumerate(zip(gt_imgs, pred_imgs)):
        # Load the prediction and its label as flat arrays.
        pred = np.array(Image.open(pred_path)).flatten()
        label = np.array(Image.open(gt_path)).flatten()

        # Pairs whose sizes differ cannot be compared and are skipped.
        if len(label) != len(pred):
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label), len(pred), gt_path, pred_path))
            continue

        # Add this image's num_classes x num_classes matrix to the total.
        hist += fast_hist(label, pred, num_classes)

        # Print the running averages every 10 images.
        if name_classes is not None and ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                    ind,
                    len(gt_imgs),
                    100 * np.nanmean(per_class_iu(hist)),
                    100 * np.nanmean(per_class_PA_Recall(hist)),
                    100 * per_Accuracy(hist)
                )
            )

    # ------------------------------------------------#
    #   Per-class metrics over the whole set
    # ------------------------------------------------#
    IoUs = per_class_iu(hist)
    PA_Recall = per_class_PA_Recall(hist)
    Precision = per_class_Precision(hist)

    # ------------------------------------------------#
    #   Print the metrics class by class
    # ------------------------------------------------#
    if name_classes is not None:
        for ind_class in range(num_classes):
            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
                + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))

    # -----------------------------------------------------------------#
    #   Class-averaged metrics over the whole set, ignoring NaN values
    # -----------------------------------------------------------------#
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
    # `np.int` was removed in NumPy 1.24; the builtin int gives the same dtype.
    return hist.astype(int), IoUs, PA_Recall, Precision
|
||||||
|
|
||||||
|
def adjust_axes(r, t, fig, axes):
    """Stretch the x-axis upper limit to make room for the text ``t``.

    Args:
        r: Renderer used to measure the text.
        t: Text artist whose width must fit.
        fig: Figure providing the dpi and the current width.
        axes: Axes whose x-limits are widened.
    """
    # Width of the text converted from pixels to inches.
    label_width_in = t.get_window_extent(renderer=r).width / fig.dpi
    fig_width_in = fig.get_figwidth()
    # Factor by which the figure would grow if the text were appended to it.
    scale = (fig_width_in + label_width_in) / fig_width_in
    left, right = axes.get_xlim()
    axes.set_xlim([left, right * scale])
|
||||||
|
|
||||||
|
def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
    """Draw a horizontal bar chart of per-class values and save it.

    Args:
        values: One value per class.
        name_classes: Class names used as y-axis tick labels.
        plot_title: Title of the chart.
        x_label: Label of the x-axis.
        output_path: Path the figure is saved to.
        tick_font_size: Font size of ticks and labels (title uses +2).
        plt_show: When true, also display the figure.
    """
    fig, axes = plt.gcf(), plt.gca()
    positions = range(len(values))
    plt.barh(positions, values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(positions, name_classes, fontsize=tick_font_size)

    renderer = fig.canvas.get_renderer()
    last_index = len(values) - 1
    for index, value in enumerate(values):
        # Values below 1 are shown with two decimals, others verbatim.
        caption = " {0:.2f}".format(value) if value < 1.0 else " " + str(value)
        text = plt.text(value, index, caption, color='royalblue', va='center', fontweight='bold')
        # Widen the axis once, after the last caption has been placed.
        if index == last_index:
            adjust_axes(renderer, text, fig, axes)

    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()
|
||||||
|
|
||||||
|
def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
    """Save the metric bar charts and the confusion matrix CSV.

    Args:
        miou_out_path: Output directory for the PNG charts and the CSV file.
        hist: Confusion matrix, one row per class.
        IoUs: Per-class intersection over union.
        PA_Recall: Per-class recall, plotted both as pixel accuracy and recall.
        Precision: Per-class precision.
        name_classes: Class names used for labels and CSV headers.
        tick_font_size: Font size forwarded to the chart helper.
    """
    # (values, title template, x label, file name, log prefix, show on screen)
    charts = (
        (IoUs, "mIoU = {0:.2f}%", "Intersection over Union", "mIoU.png", "Save mIoU out to ", True),
        (PA_Recall, "mPA = {0:.2f}%", "Pixel Accuracy", "mPA.png", "Save mPA out to ", False),
        (PA_Recall, "mRecall = {0:.2f}%", "Recall", "Recall.png", "Save Recall out to ", False),
        (Precision, "mPrecision = {0:.2f}%", "Precision", "Precision.png", "Save Precision out to ", False),
    )
    for values, title_template, x_label, file_name, log_prefix, show in charts:
        chart_path = os.path.join(miou_out_path, file_name)
        draw_plot_func(values, name_classes, title_template.format(np.nanmean(values)*100), x_label,
                       chart_path, tick_font_size = tick_font_size, plt_show = show)
        print(log_prefix + chart_path)

    csv_path = os.path.join(miou_out_path, "confusion_matrix.csv")
    with open(csv_path, 'w', newline='') as f:
        # Header row of class names, then one row per class of the matrix.
        rows = [[' '] + [str(c) for c in name_classes]]
        rows.extend([name_classes[i]] + [str(x) for x in row] for i, row in enumerate(hist))
        csv.writer(f).writerows(rows)
    print("Save confusion_matrix out to " + csv_path)
|
||||||
|
|
|
@ -0,0 +1,191 @@
|
||||||
|
import os
|
||||||
|
import streamlit as st
|
||||||
|
import cv2
|
||||||
|
import tempfile
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
from unet import Unet
|
||||||
|
from nets.U_ConvAutoencoder import U_ConvAutoencoder
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
# Constants and configuration
|
||||||
|
# Settings passed to the Unet wrapper when prediction is enabled.
DEFAULTS = dict(
    model_path='model_data/8414_8376.pth',
    num_classes=2,
    backbone="vgg",
    input_shape=[1696, 864],
    mix_type=1,
    cuda=torch.cuda.is_available(),
)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing for the autoencoder input: resize, then convert to a tensor.
TRANSFORM = transforms.Compose([
    transforms.Resize((1728, 3392)),
    transforms.ToTensor(),
])
|
||||||
|
|
||||||
|
|
||||||
|
class PreCA:
    """Class-level holder of the convolutional autoencoder and its helpers."""

    # Shared model instance; stays None until initialize_model() is called.
    model: U_ConvAutoencoder = None

    @classmethod
    def initialize_model(cls, u_ca_path: str) -> None:
        """Create the model on DEVICE, load its weights and enter eval mode."""
        cls.model = U_ConvAutoencoder().to(DEVICE)
        state = torch.load(u_ca_path, map_location=DEVICE)
        cls.model.load_state_dict(state)
        cls.model.eval()

    @classmethod
    def unload_model(cls) -> None:
        """Drop the model reference and release cached CUDA memory."""
        cls.model = None
        torch.cuda.empty_cache()

    @classmethod
    def load_image(cls, image: Image.Image) -> torch.Tensor:
        """Convert an image to a grayscale batch-of-one tensor on DEVICE."""
        grayscale = image.convert("L")
        batch = TRANSFORM(grayscale).unsqueeze(0)
        return batch.to(DEVICE)

    @staticmethod
    def ca_smooth(image: Image.Image) -> Image.Image:
        """Apply morphological closing, blur and a binary threshold."""
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        closed = cv2.morphologyEx(np.array(image), cv2.MORPH_CLOSE, kernel)
        blurred = cv2.GaussianBlur(closed, (1, 1), 0)
        # threshold() returns (retval, image); only the image is needed.
        _, binary = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)
        return Image.fromarray(binary)

    @classmethod
    def infer(cls, image: Image.Image) -> Image.Image:
        """Run the autoencoder on an image and return a 3384x1710 result."""
        batch = cls.load_image(image)
        with torch.no_grad():
            prediction = cls.model(batch)
        to_pil = transforms.ToPILImage()
        restored = to_pil(prediction.squeeze(0).cpu())
        return restored.resize((3384, 1710), Image.NEAREST)
|
||||||
|
|
||||||
|
|
||||||
|
class PreUnet:
    """Helpers that run the detection pipeline and score its output."""

    @staticmethod
    def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
        """Count true-positive, false-positive and false-negative pixels.

        Both images are converted to grayscale and binarized: pixels below
        ``threshold`` become 0, all others 255.
        """
        def binarize(picture):
            return picture.convert('L').point(lambda x: 0 if x < threshold else 255)

        pred_array = np.array(binarize(pred_image))
        true_array = np.array(binarize(true_image))

        pred_on = pred_array == 255
        true_on = true_array == 255
        TP = np.sum(pred_on & true_on)
        FP = np.sum(pred_on & (true_array == 0))
        FN = np.sum((pred_array == 0) & true_on)
        return TP, FP, FN

    @staticmethod
    def apply_mask(original_image: Image.Image, mask_image: Image.Image) -> Image.Image:
        """Paint the pure-white pixels of the mask green on the original."""
        size = (3384, 1710)
        base = np.array(original_image.convert("RGB").resize(size, Image.NEAREST))
        overlay = np.array(mask_image.convert("RGB").resize(size, Image.NEAREST))

        # A pixel is part of the mask only when all three channels are 255.
        white = np.all(overlay == [255, 255, 255], axis=-1)
        base[white] = [0, 255, 0]
        return Image.fromarray(base)

    @classmethod
    def process_image(cls, image: Image.Image, unet):
        """Detect, refine and overlay; return ``(overlay, smoothed_mask)``."""
        detected = unet.detect_image(image)
        refined = PreCA.infer(detected)
        smoothed = PreCA.ca_smooth(refined)
        return cls.apply_mask(image, smoothed), smoothed
|
||||||
|
|
||||||
|
|
||||||
|
def main_page():
    """Render the Streamlit page: settings, image folder and video detection."""
    st.title('自动驾驶车道线自动检测与增强')
    stframe = st.empty()
    st.sidebar.subheader("参数设置")

    is_pre = st.sidebar.checkbox('开启预测')
    if is_pre:
        # Prediction enabled: build the U-Net, then load the autoencoder.
        unet = Unet(DEFAULTS)
        PreCA.initialize_model('weights/best_conv_autoencoder1.pth')
    else:
        unet = None
        PreCA.unload_model()

    # Image folder detection controls.
    st.sidebar.subheader("图像检测")
    image_dir_path = st.sidebar.text_input('请输入图像文件夹路径:')
    is_get_iou = st.sidebar.checkbox('开启计算IOU')
    # The label folder input is only shown when IoU computation is enabled.
    label_dir_path = None
    if is_get_iou:
        label_dir_path = st.sidebar.text_input('请输入标签文件夹路径:')

    if st.sidebar.button("开始预测"):
        process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe)

    # Video detection controls.
    st.sidebar.subheader("视频检测")
    video_types = ['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'm4v']
    uploaded_video = st.sidebar.file_uploader("上传视频:", type=video_types)
    if uploaded_video is not None:
        process_video(uploaded_video, unet, is_pre, stframe)
|
||||||
|
|
||||||
|
|
||||||
|
def process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe):
    """Display every image of a folder, optionally with prediction and IoU.

    Args:
        image_dir_path: Folder containing the input images.
        label_dir_path: Folder containing ``<stem>_bin.png`` labels, or None.
        unet: U-Net wrapper used when ``is_pre`` is true.
        is_pre: Whether to run the prediction pipeline.
        is_get_iou: Whether to compute the IoU against the labels.
        stframe: Streamlit placeholder the images are drawn into.
    """
    # An empty or wrong path used to crash in os.listdir(); report it instead.
    if not image_dir_path or not os.path.isdir(image_dir_path):
        st.error(f'图像文件夹路径无效: {image_dir_path}')
        return

    img_names = os.listdir(image_dir_path)
    iou_text = st.empty()
    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    for img_name in img_names:
        if not img_name.lower().endswith(image_suffixes):
            continue

        image = Image.open(os.path.join(image_dir_path, img_name))
        if not is_pre:
            stframe.image(image, width=1024)
            continue

        result_image, smoothed_image = PreUnet.process_image(image, unet)
        stframe.image([image, result_image], width=640)

        if not (is_get_iou and label_dir_path):
            continue

        label_path = os.path.join(label_dir_path, f"{os.path.splitext(img_name)[0]}_bin.png")
        # A missing label used to abort the whole run; skip that image only.
        if not os.path.isfile(label_path):
            st.warning(f'未找到标签文件: {label_path}')
            continue

        label = Image.open(label_path)
        TP, FP, FN = PreUnet.calculate_metrics(smoothed_image, label)
        # NOTE(review): when prediction and label are both empty the union is
        # 0 and this evaluates to NaN -- decide how that case should be shown.
        iou = TP / (TP + FP + FN)
        iou_text.text(f'当前IOU: {iou}')
|
||||||
|
|
||||||
|
|
||||||
|
def process_video(uploaded_video, unet, is_pre, stframe):
    """Play an uploaded video frame by frame, optionally with prediction.

    Args:
        uploaded_video: Uploaded file object providing ``read()``.
        unet: U-Net wrapper used when ``is_pre`` is true.
        is_pre: Whether to run the prediction pipeline on each frame.
        stframe: Streamlit placeholder the frames are drawn into.
    """
    # OpenCV needs a file path, so the upload is spooled to a temporary file.
    tfile = tempfile.NamedTemporaryFile(delete=False)
    cap = None
    try:
        with tfile:
            tfile.write(uploaded_video.read())

        cap = cv2.VideoCapture(tfile.name)

        # Resume from the frame reached before the last script rerun.
        if 'frame_pos' not in st.session_state:
            st.session_state.frame_pos = 0
        cap.set(cv2.CAP_PROP_POS_FRAMES, st.session_state.frame_pos)

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            st.session_state.frame_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            # OpenCV decodes to BGR; PIL and Streamlit expect RGB.
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            if is_pre:
                processed_frame, _ = PreUnet.process_image(frame, unet)
                stframe.image(processed_frame, width=1024,use_column_width=False)
            else:
                stframe.image(frame, width=1024,use_column_width=False)
    finally:
        # Release the capture and delete the temp file even when the loop is
        # interrupted; previously every call left a file behind.
        if cap is not None:
            cap.release()
        try:
            os.remove(tfile.name)
        except OSError:
            # Best effort: the file may still be locked or already gone.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: render the Streamlit page.
if __name__ == "__main__":
    main_page()
|
|
@ -0,0 +1,78 @@
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def remove_foreground_and_fill(image_path, mask_path, crop_size):
    """Inpaint the masked foreground inside one random crop of an image.

    Args:
        image_path: Path of the image to edit.
        mask_path: Path of the mask, read as grayscale and binarized at 127.
        crop_size: ``(height, width)`` of the region to inpaint; it is
            clamped to the mask size when larger.

    Returns:
        A copy of the image in which the chosen region has been inpainted.

    Raises:
        FileNotFoundError: If the image or the mask cannot be read.
    """
    # cv2.imread returns None instead of raising when a file is unreadable.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    # Make sure the mask is strictly binary.
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # Pick a random top-left corner; clamping keeps the randint range valid
    # when the requested crop exceeds the mask.
    # NOTE(review): assumes image and mask share the same size -- confirm.
    h, w = mask.shape
    crop_h = min(crop_size[0], h)
    crop_w = min(crop_size[1], w)
    start_x = np.random.randint(0, w - crop_w + 1)
    start_y = np.random.randint(0, h - crop_h + 1)
    rows = slice(start_y, start_y + crop_h)
    cols = slice(start_x, start_x + crop_w)

    # Inpaint the masked pixels of the crop.
    result = cv2.inpaint(image[rows, cols], mask[rows, cols], 3, cv2.INPAINT_TELEA)

    # Write the inpainted crop back into a copy of the original image.
    result_image = image.copy()
    result_image[rows, cols] = result
    return result_image
|
||||||
|
|
||||||
|
|
||||||
|
def process_image(img_name, img_folder, label_folder, save_dir):
    """Inpaint one ``.jpg`` image using its ``_bin.png`` label and save it.

    Files that are not ``.jpg`` or that have no matching label are skipped.

    Args:
        img_name: File name of the image inside ``img_folder``.
        img_folder: Folder containing the images.
        label_folder: Folder containing the ``<stem>_bin.png`` masks.
        save_dir: Folder the processed image is written to.
    """
    if not img_name.lower().endswith('.jpg'):
        return

    # Label file name: image stem plus the "_bin.png" suffix.
    label_path = os.path.join(label_folder, img_name[:-4] + '_bin.png')
    # Check the path directly instead of listing the whole label folder for
    # every image, which made the run quadratic in the number of files.
    if not os.path.exists(label_path):
        return

    result_image = remove_foreground_and_fill(
        os.path.join(img_folder, img_name),
        label_path,
        (500, 1710)  # crop size passed as (height, width)
    )

    # Save under the original file name in the output folder.
    cv2.imwrite(os.path.join(save_dir, img_name), result_image)
|
||||||
|
|
||||||
|
|
||||||
|
def write_img_label_txt(base_dir, dataset_type):
    """Process every ``.jpg`` of one dataset split in a thread pool.

    Despite its name this function writes processed images, not text files:
    results go to ``<base_dir>/save_files``.

    Args:
        base_dir: Dataset root directory.
        dataset_type: Split sub-folder containing ``img`` and ``label``.
    """
    # Output folder for the processed images.
    save_dir = os.path.join(base_dir, 'save_files')
    os.makedirs(save_dir, exist_ok=True)

    # Input folders of the chosen split.
    split_root = os.path.join(base_dir, dataset_type)
    img_folder = os.path.join(split_root, 'img')
    label_folder = os.path.join(split_root, 'label')

    jpg_names = [name for name in os.listdir(img_folder) if name.lower().endswith('.jpg')]

    with ThreadPoolExecutor() as pool:
        pending = [
            pool.submit(process_image, name, img_folder, label_folder, save_dir)
            for name in jpg_names
        ]
        # result() re-raises any exception raised inside a worker.
        for finished in tqdm(as_completed(pending), total=len(pending), desc="Processing images"):
            finished.result()
|
||||||
|
|
||||||
|
|
||||||
|
# Root directory of the dataset.
base_dir = r'E:\git\unet_seg\unet\original_data\dataset_A'

# Only the 'train' split is processed here.
for dataset_type in ('train',):
    write_img_label_txt(base_dir, dataset_type)

print('All images have been processed and saved.')
|
Loading…
Reference in New Issue