U-Net Semantic Segmentation


1. What is semantic segmentation?

Semantic segmentation combines techniques from object detection, image classification, and image segmentation. Given an input image, a semantic segmentation model partitions it into region blocks that carry semantic meaning, identifies the semantic class of each block, and finally produces a segmentation map of the same size as the input with a per-pixel semantic label.

[Figure] The four panels show (a) object classification, (b) recognition and localization, (c) semantic segmentation, and (d) instance segmentation.

Semantic segmentation is, in essence, classifying every pixel in the image.

[Figure] Original image (left) and its pixel-wise label (right).

As the label shows, there are four classes: 1. petals (red), 2. leaves (yellow), 3. stems (green), 4. background (black).
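Concretely, a label of this kind is just an integer class index at every pixel. A toy sketch (the 4x4 mask and the index-to-class mapping are assumptions for illustration):

import numpy as np

# Hypothetical 4x4 label mask: one class index per pixel
# assumed mapping: 0 = background, 1 = petal, 2 = leaf, 3 = stem
label = np.array([
    [0, 0, 1, 1],
    [0, 1, 1, 1],
    [2, 2, 3, 0],
    [2, 3, 3, 0],
])

num_classes = 4
# one-hot encoding, shape (num_classes, H, W), as used by per-pixel losses
one_hot = (np.arange(num_classes)[:, None, None] == label[None]).astype(np.float32)
print(one_hot.shape)   # (4, 4, 4)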

In general, the images in a dataset may vary in size, but the images fed into the network must all be the same size. The dataset therefore needs a cropping preprocessing step that cuts every image to a common size.

(Scan the dataset once, find the smallest width W and height H among the images, and choose a crop size w <= W, h <= H.)
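A minimal preprocessing sketch along these lines (the directory paths and the 512 crop size are assumptions; uses Pillow):

import os
from PIL import Image

image_dir = "dataset/images"          # hypothetical input directory
out_dir = "dataset/images_cropped"    # hypothetical output directory
os.makedirs(out_dir, exist_ok=True)
paths = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir))]

# Pass 1: scan the dataset once to find the smallest width W and height H
W = min(Image.open(p).size[0] for p in paths)
H = min(Image.open(p).size[1] for p in paths)
w, h = min(512, W), min(512, H)       # pick a crop size with w <= W, h <= H

# Pass 2: center-crop every image to w x h
# (the exact same crop must also be applied to the label images)
for p in paths:
    img = Image.open(p)
    left = (img.size[0] - w) // 2
    top = (img.size[1] - h) // 2
    img.crop((left, top, left + w, top + h)).save(os.path.join(out_dir, os.path.basename(p)))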

 

2. U-Net

U-Net is a symmetric encoder-decoder semantic segmentation model.

[Figure] The U-Net architecture diagram.

The network consists of three main parts.

The first part is backbone feature extraction (the gray lines in the diagram). The backbone produces successive feature maps and is similar in structure to VGG16; this part yields five preliminary effective feature maps, which are later fused with the upsampled features.

The second part is enhanced feature extraction. The five feature maps from the first part are upsampled and fused (stacked along the channel dimension), producing one final effective feature map that combines all the features.
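Stacking along the channel dimension means the feature maps are concatenated, not added. A toy sketch (the shapes are illustrative assumptions):

import torch

a = torch.randn(1, 64, 56, 56)    # feature map from the downsampling path (skip connection)
b = torch.randn(1, 64, 56, 56)    # upsampled feature map of the same spatial size
fused = torch.cat([a, b], dim=1)  # shape (1, 128, 56, 56): the channels stack side by side
print(fused.shape)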

The third part is prediction: using the final effective feature map, each point is classified, i.e., every pixel is assigned a class.
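The PyTorch implementation below follows this three-part structure: an encoder class builds the downsampling backbone, a decoder class performs the upsampling and channel-wise fusion, and a final 1x1 convolution maps the fused features to per-pixel class scores.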

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging


class encoder(nn.Module):   # downsampling block
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = nn.Sequential(     # two 3x3 convolutions
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),   # normalization
            nn.ReLU(inplace=True),          # activation
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # ceil_mode=True rounds the output size up; the default (False) rounds down
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        out = self.down_conv(x)
        out_pool = self.pool(out)
        return out, out_pool    # return both: 1. the pre-pooling feature map (after the two
                                # convolutions), used for the skip connection with the decoder;
                                # 2. the pooled feature map, passed on to the next stage


class decoder(nn.Module):   # upsampling block
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        # transposed convolution (up-conv)
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        self.up_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_copy, x, interpolate=True):
        out = self.up(x)
        if interpolate:
            # interpolate to the skip connection's size instead of padding; gives better results
            out = F.interpolate(out, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True
                                )
        else:
            # pad if the upsampled map and the skip connection differ in spatial size
            diffY = x_copy.size()[2] - out.size()[2]
            diffX = x_copy.size()[3] - out.size()[3]
            out = F.pad(out, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
        # concatenate along the channel dimension (stacking the channels, not adding values)
        out = torch.cat([x_copy, out], dim=1)
        out_conv = self.up_conv(out)  # two further convolutions
        return out_conv


class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self):
        raise NotImplementedError

    def summary(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        nbr_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info(f'Nbr of trainable parameters: {nbr_params}')

    def __str__(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        nbr_params = sum([np.prod(p.size()) for p in model_parameters])
        return super(BaseModel, self).__str__() + f"\nNbr of trainable parameters: {nbr_params}"


class UNet(BaseModel):
    def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
        super(UNet, self).__init__()
        self.down1 = encoder(in_channels, 64)
        self.down2 = encoder(64, 128)
        self.down3 = encoder(128, 256)
        self.down4 = encoder(256, 512)
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )

        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()
        if freeze_bn:
            self.freeze_bn()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        x1, x = self.down1(x)
        x2, x = self.down2(x)
        x3, x = self.down3(x)
        x4, x = self.down4(x)
        x = self.middle_conv(x)
        x = self.up1(x4, x)
        x = self.up2(x3, x)
        x = self.up3(x2, x)
        x = self.up4(x1, x)
        x = self.final_conv(x)
        return x

    def get_backbone_params(self):
        # There is no backbone for unet, all the parameters are trained from scratch
        return []

    def get_decoder_params(self):
        return self.parameters()

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()