1. What is semantic segmentation?
Semantic segmentation combines techniques from object detection, image classification, and image segmentation. Given an input image, a semantic segmentation model partitions it into regions that carry semantic meaning, identifies the semantic class of each region, and finally outputs a segmentation map of the same size as the input with a semantic label for every pixel.
The four images represent (a) object classification, (b) recognition and localization, (c) semantic segmentation, and (d) instance segmentation.
In essence, semantic segmentation classifies every pixel of the image.
Example:
Original image | Label
As the figure shows, the pixels fall into four classes: 1. petals (red), 2. leaves (yellow), 3. stems (green), and 4. background (black).
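For training, a color-coded label image like this is usually converted into a per-pixel class-index mask. The sketch below assumes a hypothetical color-to-class palette matching the description above; the real mapping depends on how the dataset was annotated.

# A minimal sketch: turn a color-coded label image into a per-pixel class-index mask.
# The palette below is an assumption for illustration only.
import numpy as np
from PIL import Image

COLOR_TO_CLASS = {          # hypothetical palette
    (0, 0, 0): 0,           # background (black)
    (255, 0, 0): 1,         # petal (red)
    (255, 255, 0): 2,       # leaf (yellow)
    (0, 255, 0): 3,         # stem (green)
}

def label_image_to_mask(path):
    rgb = np.array(Image.open(path).convert("RGB"))    # H x W x 3 color label
    mask = np.zeros(rgb.shape[:2], dtype=np.int64)      # H x W class indices
    for color, cls in COLOR_TO_CLASS.items():
        mask[np.all(rgb == color, axis=-1)] = cls
    return mask

The resulting mask holds one integer class id per pixel and can be fed directly to a per-pixel loss such as cross-entropy.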
In general, the images in a dataset may come in different sizes, but the images fed into the network must all be the same size, so the dataset needs a cropping preprocessing step that cuts every image to a uniform size.
(Scan the dataset once to find the smallest width W and height H it contains, then choose a crop size w <= W and h <= H.)
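A minimal sketch of this preprocessing step, assuming the dataset is a folder of image files readable by Pillow; the directory arguments and the center-crop strategy are illustrative choices, not part of the original text.

# Sketch: find the smallest image size in the dataset, then crop everything to one size.
import os
from PIL import Image

def find_min_size(image_dir):
    # One pass over the dataset to find the smallest width W and height H.
    min_w = min_h = float("inf")
    for name in os.listdir(image_dir):
        with Image.open(os.path.join(image_dir, name)) as img:
            w, h = img.size
            min_w, min_h = min(min_w, w), min(min_h, h)
    return int(min_w), int(min_h)

def center_crop_all(image_dir, out_dir, crop_w, crop_h):
    # Crop every image to the same (crop_w, crop_h), with crop_w <= W and crop_h <= H.
    os.makedirs(out_dir, exist_ok=True)
    for name in os.listdir(image_dir):
        with Image.open(os.path.join(image_dir, name)) as img:
            w, h = img.size
            left, top = (w - crop_w) // 2, (h - crop_h) // 2
            img.crop((left, top, left + crop_w, top + crop_h)).save(os.path.join(out_dir, name))

Note that an image and its label mask must be cropped identically so that the per-pixel correspondence is preserved.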
2. UNet
UNet is a symmetric encoder-decoder model for semantic segmentation.
It consists of three main parts.
Part 1, backbone feature extraction (the gray lines in the diagram): the backbone, which is similar to VGG16, produces one feature layer after another; this part yields 5 preliminary effective feature layers that are fused with the upsampled features in the next part.
Part 2, enhanced feature extraction: the 5 feature layers from part 1 are upsampled and fused (i.e. stacked along the channel dimension), producing one final effective feature layer that combines all of the features.
Part 3, prediction: the final effective feature layer is used to classify every point, i.e. every pixel is assigned a class. A complete PyTorch implementation of this architecture is shown below.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging


class encoder(nn.Module):  # downsampling block
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = nn.Sequential(  # two convolution layers
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),  # normalization
            nn.ReLU(inplace=True),         # activation
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # ceil_mode=True rounds the output size up; the default (False) rounds down
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        out = self.down_conv(x)
        out_pool = self.pool(out)
        # Returns two tensors:
        # 1. the output of the two convolutions (before pooling), kept for the skip
        #    connection with the corresponding decoder block
        # 2. the pooled output, passed on to the next downsampling block
        return out, out_pool


class decoder(nn.Module):  # upsampling block
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        # transposed convolution (up-conv)
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        self.up_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_copy, x, interpolate=True):
        out = self.up(x)
        if interpolate:
            # interpolate instead of padding, which tends to give better results
            out = F.interpolate(out, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
        else:
            # pad if the spatial sizes differ
            diffY = x_copy.size()[2] - x.size()[2]
            diffX = x_copy.size()[3] - x.size()[3]
            out = F.pad(out, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
        # concatenate along the channel dimension (stacking channels, not element-wise addition)
        out = torch.cat([x_copy, out], dim=1)
        out_conv = self.up_conv(out)  # apply the two convolutions
        return out_conv


class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self):
        raise NotImplementedError

    def summary(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        nbr_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info(f'Nbr of trainable parameters: {nbr_params}')

    def __str__(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        nbr_params = sum([np.prod(p.size()) for p in model_parameters])
        return super(BaseModel, self).__str__() + f"\nNbr of trainable parameters: {nbr_params}"


class UNet(BaseModel):
    def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
        super(UNet, self).__init__()
        self.down1 = encoder(in_channels, 64)
        self.down2 = encoder(64, 128)
        self.down3 = encoder(128, 256)
        self.down4 = encoder(256, 512)
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initalize_weights()
        if freeze_bn:
            self.freeze_bn()

    def _initalize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        x1, x = self.down1(x)
        x2, x = self.down2(x)
        x3, x = self.down3(x)
        x4, x = self.down4(x)
        x = self.middle_conv(x)
        x = self.up1(x4, x)
        x = self.up2(x3, x)
        x = self.up3(x2, x)
        x = self.up4(x1, x)
        x = self.final_conv(x)
        return x

    def get_backbone_params(self):
        # There is no backbone for unet, all the parameters are trained from scratch
        return []

    def get_decoder_params(self):
        return self.parameters()

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
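A quick usage sketch, assuming the classes above are in scope: build the model, push a dummy batch through it, and take an argmax over the channel dimension to obtain one class id per pixel.

# Sketch: instantiate UNet and run a forward pass on a dummy batch.
model = UNet(num_classes=4, in_channels=3)
model.summary()                      # logs the number of trainable parameters
x = torch.randn(1, 3, 256, 256)      # dummy batch: N x C x H x W
with torch.no_grad():
    logits = model(x)                # shape: (1, num_classes, 256, 256)
pred = logits.argmax(dim=1)          # shape: (1, 256, 256), one class id per pixel
print(logits.shape, pred.shape)

During training, these raw logits would typically be passed to nn.CrossEntropyLoss together with an (N, H, W) class-index mask like the one built in section 1.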