import torch import torch.nn as nn from nets.resnet import resnet50 from nets.vgg import VGG16 class unetUp(nn.Module): def __init__(self, in_size, out_size): super(unetUp, self).__init__() self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1) self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1) self.up = nn.UpsamplingBilinear2d(scale_factor = 2) self.relu = nn.ReLU(inplace = True) def forward(self, inputs1, inputs2): outputs = torch.cat([inputs1, self.up(inputs2)], 1) outputs = self.conv1(outputs) outputs = self.relu(outputs) outputs = self.conv2(outputs) outputs = self.relu(outputs) return outputs class Unet(nn.Module): def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'): super(Unet, self).__init__() if backbone == 'vgg': self.vgg = VGG16(pretrained = pretrained) in_filters = [192, 384, 768, 1024] elif backbone == "resnet50": self.resnet = resnet50(pretrained = pretrained) in_filters = [192, 512, 1024, 3072] else: raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone)) out_filters = [64, 128, 256, 512] # upsampling # 64,64,512 self.up_concat4 = unetUp(in_filters[3], out_filters[3]) # 128,128,256 self.up_concat3 = unetUp(in_filters[2], out_filters[2]) # 256,256,128 self.up_concat2 = unetUp(in_filters[1], out_filters[1]) # 512,512,64 self.up_concat1 = unetUp(in_filters[0], out_filters[0]) if backbone == 'resnet50': self.up_conv = nn.Sequential( nn.UpsamplingBilinear2d(scale_factor = 2), nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), nn.ReLU(), nn.Conv2d(out_filters[0], out_filters[0], kernel_size = 3, padding = 1), nn.ReLU(), ) else: self.up_conv = None self.final = nn.Conv2d(out_filters[0], num_classes, 1) self.backbone = backbone def forward(self, inputs): if self.backbone == "vgg": [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs) elif self.backbone == "resnet50": [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs) up4 = self.up_concat4(feat4, feat5) up3 = self.up_concat3(feat3, up4) up2 = self.up_concat2(feat2, up3) up1 = self.up_concat1(feat1, up2) if self.up_conv != None: up1 = self.up_conv(up1) final = self.final(up1) return final def freeze_backbone(self): if self.backbone == "vgg": for param in self.vgg.parameters(): param.requires_grad = False elif self.backbone == "resnet50": for param in self.resnet.parameters(): param.requires_grad = False def unfreeze_backbone(self): if self.backbone == "vgg": for param in self.vgg.parameters(): param.requires_grad = True elif self.backbone == "resnet50": for param in self.resnet.parameters(): param.requires_grad = True