UNet_UAE_for_Lane_Detection/nets/U_ConvAutoencoder.py

75 lines
2.8 KiB
Python
Raw Normal View History

2024-08-23 19:42:44 +08:00
import torch
import torch.nn as nn
class U_ConvAutoencoder(nn.Module):
    """U-Net-style convolutional autoencoder with five down/up-sampling stages.

    Encoder: five ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> ReLU`` stages with
    16, 32, 64, 128 and 256 output channels. Each stage maps a spatial size
    ``n`` to ``ceil(n / 2)``.

    Decoder: stride-2 ``ConvTranspose2d`` stages that mirror the encoder. The
    output of each decoder stage is concatenated (along channels) with the
    encoder feature map of the same resolution before the next stage. A final
    ``Sigmoid`` keeps the output in ``[0, 1]``.

    Every transposed convolution is given the exact spatial size of the tensor
    it has to line up with. As a result, inputs whose height/width are not
    multiples of 32 work, and the output always has the input's height/width.

    Args:
        in_channels: Channels of the input tensor (default 1).
        out_channels: Channels of the output tensor (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder: each stage halves H and W (rounding up).
        self.encoder1 = self._down(in_channels, 16)
        self.encoder2 = self._down(16, 32)
        self.encoder3 = self._down(32, 64)
        self.encoder4 = self._down(64, 128)
        self.encoder5 = self._down(128, 256)
        # Decoder: decoder4..decoder1 take twice the channels of the previous
        # decoder's output because that output is concatenated with an encoder
        # skip tensor of the same channel count (see forward).
        self.decoder5 = self._up(256, 128)
        self.decoder4 = self._up(256, 64)
        self.decoder3 = self._up(128, 32)
        self.decoder2 = self._up(64, 16)
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.Sigmoid(),  # use Sigmoid so the output lies in [0, 1]
        )

    @staticmethod
    def _down(in_ch, out_ch):
        """Build one encoder stage: stride-2 conv + BatchNorm + ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up(in_ch, out_ch):
        """Build one decoder stage: stride-2 transposed conv + BatchNorm + ReLU."""
        # output_padding=1 only sets the default size (2 * input); forward()
        # overrides it with an explicit output_size.
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _decode(stage, x, ref):
        """Run decoder ``stage`` on ``x`` so the result matches ``ref``'s H and W.

        The first module of every decoder stage is a stride-2 transposed conv,
        whose output size is ambiguous (``2n - 1`` or ``2n``). Passing
        ``output_size`` picks the size of ``ref``, which is always one of the
        two because the encoder produced ``ceil(size / 2)``.
        """
        upconv, *rest = stage
        x = upconv(x, output_size=ref.shape[-2:])
        for layer in rest:
            x = layer(x)
        return x

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: 4-D tensor ``(N, in_channels, H, W)``; H and W need not be
                multiples of 32.

        Returns:
            Tensor ``(N, out_channels, H, W)`` with values in ``[0, 1]``.
        """
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)
        # Decoder: each stage is sized to the skip tensor it is concatenated
        # with next; the last stage is sized to the input itself.
        dec5 = self._decode(self.decoder5, enc5, enc4)
        dec4 = self._decode(self.decoder4, torch.cat([dec5, enc4], dim=1), enc3)
        dec3 = self._decode(self.decoder3, torch.cat([dec4, enc3], dim=1), enc2)
        dec2 = self._decode(self.decoder2, torch.cat([dec3, enc2], dim=1), enc1)
        dec1 = self._decode(self.decoder1, torch.cat([dec2, enc1], dim=1), x)
        return dec1