UNet_UAE_for_Lane_Detection/U-AE/train.py

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm  # progress bar for the training loop

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the convolutional autoencoder (U-Net style, with skip connections);
# shape comments are C x H x W and assume the 1728 x 3392 resized input defined below
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # Encoder
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # output: 16 x 864 x 1696
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # output: 32 x 432 x 848
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # output: 64 x 216 x 424
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.encoder4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # output: 128 x 108 x 212
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.encoder5 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # output: 256 x 54 x 106
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )
        # Decoder (input channels double where an encoder feature map is concatenated)
        self.decoder5 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 128 x 108 x 212
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 64 x 216 x 424
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 32 x 432 x 848
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(64, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 16 x 864 x 1696
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # output: 1 x 1728 x 3392
            nn.Sigmoid()  # Sigmoid keeps the output in [0, 1]
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)
        # Decoder with skip connections
        dec5 = self.decoder5(enc5)
        dec4 = self.decoder4(torch.cat([dec5, enc4], dim=1))
        dec3 = self.decoder3(torch.cat([dec4, enc3], dim=1))
        dec2 = self.decoder2(torch.cat([dec3, enc2], dim=1))
        dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1))
        return dec1
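
# Illustrative sketch, not part of the original training pipeline: a quick shape check
# for the skip-connected autoencoder above. Any input whose height and width are
# divisible by 2**5 = 32 reconstructs to its own size; 64 x 64 is used only to keep it cheap.
def _sanity_check_shapes():
    net = ConvAutoencoder()
    net.eval()
    dummy = torch.zeros(1, 1, 64, 64)  # (batch, channels, height, width)
    with torch.no_grad():
        out = net(dummy)
    assert out.shape == dummy.shape, f'unexpected output shape: {out.shape}'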

# Custom dataset loader
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert("L")
        label_name = img_name  # assumes image and label files share the same name
        label_path = os.path.join(self.label_dir, label_name)
        label_image = Image.open(label_path).convert("L")
        if self.transform:
            image = self.transform(image)
            label_image = self.transform(label_image)
        return image, label_image
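
# Usage sketch with illustrative folder names (the real paths are set in __main__ below,
# and `transform` is defined further down in this file):
#   ds = CustomImageDataset('./img', './label', transform=transform)
#   img, mask = ds[0]  # both are 1 x 1728 x 3392 tensors after the transform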

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# IoU computation
def compute_iou(pred, target, threshold=0.5):
    pred = (pred > threshold).float()
    target = (target > threshold).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    if union == 0:
        return 1.0 if intersection == 0 else 0.0
    iou = intersection / union
    return iou.item()
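
# Minimal usage sketch for compute_iou with made-up values (not from the dataset):
# after thresholding at 0.5, two of the three positive pixels overlap, so IoU = 2 / 3.
#   pred = torch.tensor([[0.9, 0.2], [0.8, 0.1]])
#   mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
#   compute_iou(pred, mask)  # -> 0.666...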

# Image preprocessing and data loading
transform = transforms.Compose([
    transforms.Resize((1728, 3392)),  # dimensions divisible by 2**5 for the five stride-2 stages
    transforms.ToTensor()
])
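
# Illustrative only: applying this transform to one grayscale image (hypothetical path)
# yields a float tensor of shape (1, 1728, 3392) with values scaled to [0, 1],
# matching the shape comments in ConvAutoencoder above.
#   example = Image.open('./img/example.png').convert("L")
#   transform(example).shape  # -> torch.Size([1, 1728, 3392])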

def train():
    global best_iou  # updated here so the best checkpoint tracks progress across epochs
    for epoch in range(num_epochs):
        running_loss = 0.0
        total_iou = 0.0
        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        for data in progress_bar:
            imgs, label_imgs = data
            imgs, label_imgs = imgs.to(device), label_imgs.to(device)  # move the data to the GPU
            # Forward pass
            output = model(imgs)
            loss = criterion(output, label_imgs)
            # Compute IoU
            iou = compute_iou(output, label_imgs)
            total_iou += iou
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Update the progress bar
            progress_bar.set_postfix(loss=loss.item(), iou=iou)
        # Print loss and IoU
        epoch_loss = running_loss / len(data_loader)
        epoch_iou = total_iou / len(data_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}, IoU: {epoch_iou}')
        # Save the model with the best IoU
        if epoch_iou > best_iou:
            best_iou = epoch_iou
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved with IoU: {best_iou}')
        # Save model weights every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'out_weights/conv_autoencoder_epoch_{epoch+1}.pth')
            print(f'Model weights saved at epoch {epoch+1}')

if __name__ == '__main__':
    image_dir = './img'    # replace with your image folder path
    label_dir = './label'  # replace with your label folder path
    dataset = CustomImageDataset(image_dir, label_dir, transform=transform)
    data_loader = DataLoader(dataset, batch_size=15, shuffle=True)
    # Instantiate the model, define the loss function and the optimizer
    model = ConvAutoencoder().to(device)  # move the model to the GPU
    print(f'The model has {count_parameters(model):,} trainable parameters')
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # Track the best-IoU checkpoint
    os.makedirs('out_weights', exist_ok=True)  # make sure the checkpoint directory exists
    best_iou = 0.0
    best_model_path = 'out_weights/best_conv_autoencoder.pth'
    # Train the convolutional autoencoder
    num_epochs = 100
    train()
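
# Inference sketch, assuming training has produced the best checkpoint above and that
# 'some_image.png' is a hypothetical grayscale road image (not part of this script's run):
#   model = ConvAutoencoder().to(device)
#   model.load_state_dict(torch.load('out_weights/best_conv_autoencoder.pth', map_location=device))
#   model.eval()
#   img = transform(Image.open('some_image.png').convert("L")).unsqueeze(0).to(device)
#   with torch.no_grad():
#       mask = (model(img) > 0.5).float()  # binary lane mask, shape (1, 1, 1728, 3392)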