UNet - Semantic Segmentation


1. What is semantic segmentation?

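Semantic segmentation combines techniques from object detection, image classification and image segmentation. Given an input image, a semantic segmentation model divides it into regions that carry semantic meaning and identifies the semantic class of each region. The final result is a segmentation map of the same size as the original image, with a semantic label for every pixel.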

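The four figures show (a) object classification, (b) recognition and localization, (c) semantic segmentation and (d) instance segmentation.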

In essence, semantic segmentation classifies every pixel in the image.
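As a minimal sketch of "classifying every pixel": assume a model has produced a score array of shape (C, H, W), one score per class per pixel (this shape convention and the random scores are illustrative assumptions). Taking the argmax over the class axis yields an H×W label map, the same size as the input image.

```python
import numpy as np

def scores_to_label_map(scores: np.ndarray) -> np.ndarray:
    """Convert per-class scores of shape (C, H, W) into an (H, W) label map."""
    if scores.ndim != 3:
        raise ValueError("expected scores with shape (C, H, W)")
    # For every pixel, pick the index of the class with the highest score.
    return np.argmax(scores, axis=0)

rng = np.random.default_rng(0)
num_classes, height, width = 4, 5, 6
scores = rng.normal(size=(num_classes, height, width))  # stand-in for a model output
label_map = scores_to_label_map(scores)
print(label_map.shape)  # (5, 6): one class index per pixel
```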

(Figure: original image and its label)

As the label shows, the pixels fall into four classes: 1. petal (red), 2. leaf (yellow), 3. stem (green), 4. background (black).
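Label images like this are often stored as color images, while a training loss expects one integer class index per pixel. The sketch below converts an RGB label image into an index map. The exact RGB values and the index assignment (background as 0, a common convention, instead of the 1 to 4 numbering in the list above) are assumptions for illustration; check them against your own dataset.

```python
import numpy as np

# Hypothetical palette: the exact RGB values depend on how the labels were drawn.
PALETTE = {
    (0, 0, 0): 0,       # background (black)
    (255, 0, 0): 1,     # petal (red)
    (255, 255, 0): 2,   # leaf (yellow)
    (0, 255, 0): 3,     # stem (green)
}

def rgb_to_index(label_rgb: np.ndarray) -> np.ndarray:
    """Map an (H, W, 3) uint8 color label image to an (H, W) class-index map."""
    index_map = np.full(label_rgb.shape[:2], -1, dtype=np.int64)
    for color, class_id in PALETTE.items():
        # Boolean mask of pixels whose three channels all match this color.
        mask = np.all(label_rgb == np.array(color, dtype=label_rgb.dtype), axis=-1)
        index_map[mask] = class_id
    if (index_map < 0).any():
        raise ValueError("label image contains colors not in PALETTE")
    return index_map

label_rgb = np.zeros((2, 2, 3), dtype=np.uint8)
label_rgb[0, 0] = (255, 0, 0)
label_rgb[0, 1] = (255, 255, 0)
label_rgb[1, 0] = (0, 255, 0)
print(rgb_to_index(label_rgb))
```

Raising an error on unknown colors is a deliberate choice: lossy formats such as JPEG can introduce intermediate colors at region borders, and silently mapping them to a class would corrupt the labels.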

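In general, the images in a dataset may have different sizes, but the images fed into the network must all have the same size. The dataset therefore needs a cropping preprocessing step that cuts the images into patches of equal size.

(Run through the dataset once to find the smallest width W and height H, then choose a crop size with w <= W and h <= H.)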
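The cropping step can be sketched as follows, assuming images and labels are NumPy arrays of shape (H, W, ...) and that patches are taken as non-overlapping tiles starting at the top-left corner. The tiling strategy and the function names are illustrative choices, not from the original. The same crop must be applied to an image and its label so that the pixels stay aligned.

```python
import numpy as np

def min_size(shapes):
    """Return the smallest height and width over a list of (H, W, ...) shapes."""
    return min(s[0] for s in shapes), min(s[1] for s in shapes)

def tile_pair(image: np.ndarray, label: np.ndarray, h: int, w: int):
    """Cut an image and its label into aligned, non-overlapping h x w patches.

    Pixels at the bottom/right edges that do not fill a whole patch are dropped.
    """
    if image.shape[:2] != label.shape[:2]:
        raise ValueError("image and label must have the same height and width")
    patches = []
    for top in range(0, image.shape[0] - h + 1, h):
        for left in range(0, image.shape[1] - w + 1, w):
            patches.append((image[top:top + h, left:left + w],
                            label[top:top + h, left:left + w]))
    return patches

images = [np.zeros((100, 120, 3)), np.zeros((90, 150, 3))]
labels = [np.zeros((100, 120)), np.zeros((90, 150))]
H, W = min_size([img.shape for img in images])   # smallest height and width
h, w = 64, 64                                     # chosen so that h <= H and w <= W
pairs = [p for img, lab in zip(images, labels) for p in tile_pair(img, lab, h, w)]
print(len(pairs))
```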

 

2. UNet

UNet is a symmetric semantic segmentation model.

 

It consists of three main parts.

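Part 1 is backbone feature extraction (the gray lines in the figure). The backbone produces a sequence of feature layers; its structure is similar to VGG16 and it yields five preliminary effective feature layers, which are later fused with the upsampled features.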
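To make the five feature layers concrete, the sketch below computes their shapes for the encoder in this article's code: channel widths 64, 128, 256, 512 and 1024, 3×3 convolutions with padding=1 (which keep the spatial size), and 2×2 max pooling with stride 2 between stages. The 572×572 input size is an arbitrary example, and the function name is hypothetical.

```python
def encoder_shapes(height, width, widths=(64, 128, 256, 512, 1024), ceil_mode=True):
    """Return (channels, H, W) of each effective feature layer in the encoder.

    Assumes size-preserving padded 3x3 convolutions and 2x2 max pooling with
    stride 2 between stages. ceil_mode=True rounds odd sizes up, False rounds down.
    """
    shapes = []
    for i, channels in enumerate(widths):
        if i > 0:
            # Pooling halves the spatial size before the next stage.
            if ceil_mode:
                height, width = (height + 1) // 2, (width + 1) // 2
            else:
                height, width = height // 2, width // 2
        shapes.append((channels, height, width))
    return shapes

for shape in encoder_shapes(572, 572):
    print(shape)
```

The rounding mode only matters once a stage has an odd size: here 143 becomes 72 with ceil_mode=True and 71 with the default floor rounding.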

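Part 2 is enhanced feature extraction. The five feature layers from part 1 are upsampled and fused: starting from the deepest layer, the features are upsampled step by step and, at each step, concatenated along the channel dimension with the part-1 feature layer of the same resolution. The result is one final effective feature layer that fuses all the features.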
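The "stacking along the channel dimension" in part 2 is a concatenation, not an addition: the channel counts add up while the spatial size stays the same. A minimal NumPy sketch in the (N, C, H, W) layout that PyTorch uses; the small spatial size is arbitrary, and the 512 + 512 channels mirror the first decoder step in the code.

```python
import numpy as np

batch, height, width = 1, 4, 4
skip = np.ones((batch, 512, height, width))        # encoder feature layer (x_copy)
upsampled = np.zeros((batch, 512, height, width))  # upsampled deeper feature layer

# Concatenate along axis 1, the channel axis in (N, C, H, W) layout.
fused = np.concatenate([skip, upsampled], axis=1)
print(fused.shape)  # (1, 1024, 4, 4)

# Element-wise addition, by contrast, would keep 512 channels.
print((skip + upsampled).shape)  # (1, 512, 4, 4)
```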

Part 3 is prediction: the final effective feature layer is used to classify every point, that is, every pixel.
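In the code below, the prediction step is a 1×1 convolution (final_conv), which applies the same linear classifier independently at every pixel. A NumPy sketch of that equivalence, with random weights and small sizes as illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
in_channels, num_classes, height, width = 64, 4, 3, 5
features = rng.normal(size=(in_channels, height, width))  # final feature layer
weight = rng.normal(size=(num_classes, in_channels))      # 1x1 kernel with spatial dims dropped
bias = rng.normal(size=num_classes)

# 1x1 convolution: scores[:, h, w] = weight @ features[:, h, w] + bias for every pixel.
scores = np.einsum("kc,chw->khw", weight, features) + bias[:, None, None]

# The same result computed pixel by pixel.
check = np.empty_like(scores)
for h in range(height):
    for w in range(width):
        check[:, h, w] = weight @ features[:, h, w] + bias
print(np.allclose(scores, check))  # True

label_map = scores.argmax(axis=0)  # (H, W): one class index per pixel
print(label_map.shape)             # (3, 5)
```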

 

```python
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F


class encoder(nn.Module):   # downsampling block
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = nn.Sequential(     # two 3x3 convolutions
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),   # batch normalization
            nn.ReLU(inplace=True),          # activation
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # ceil_mode=True rounds the output size up; the default (False) rounds down
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        out = self.down_conv(x)
        out_pool = self.pool(out)
        # Returns two tensors:
        # 1. out: the output of the two convolutions (before pooling), kept as the
        #    skip connection that is concatenated with the upsampled features later
        # 2. out_pool: the pooled output, passed on for further downsampling
        return out, out_pool


class decoder(nn.Module):   # upsampling block
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        # transposed convolution (up-conv): doubles H and W, maps in_channels -> out_channels
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # input after concatenation: out_channels (skip) + out_channels (upsampled) = in_channels
        self.up_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x_copy, x, interpolate=True):
        out = self.up(x)
        if interpolate:
            # resize by interpolation instead of padding, to match x_copy exactly
            out = F.interpolate(out, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
        else:
            # otherwise pad the upsampled tensor when its size differs from x_copy
            diffY = x_copy.size(2) - out.size(2)
            diffX = x_copy.size(3) - out.size(3)
            # F.pad order for the last two dims: (left, right, top, bottom)
            out = F.pad(out, (diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2))
        # concatenate along the channel dimension (channels are stacked, not added)
        out = torch.cat([x_copy, out], dim=1)
        out_conv = self.up_conv(out)  # two convolutions
        return out_conv


class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *inputs):
        raise NotImplementedError

    def summary(self):
        nbr_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.logger.info(f"Nbr of trainable parameters: {nbr_params}")

    def __str__(self):
        nbr_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super(BaseModel, self).__str__() + f"\nNbr of trainable parameters: {nbr_params}"


class UNet(BaseModel):
    def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
        super(UNet, self).__init__()
        # part 1: backbone feature extraction
        self.down1 = encoder(in_channels, 64)
        self.down2 = encoder(64, 128)
        self.down3 = encoder(128, 256)
        self.down4 = encoder(256, 512)
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # part 2: upsampling and feature fusion
        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        # part 3: 1x1 convolution that classifies every pixel
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()
        if freeze_bn:
            self.freeze_bn()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        x1, x = self.down1(x)   # x1: 64 channels, full resolution
        x2, x = self.down2(x)   # x2: 128 channels, 1/2
        x3, x = self.down3(x)   # x3: 256 channels, 1/4
        x4, x = self.down4(x)   # x4: 512 channels, 1/8
        x = self.middle_conv(x) # 1024 channels, 1/16
        x = self.up1(x4, x)
        x = self.up2(x3, x)
        x = self.up3(x2, x)
        x = self.up4(x1, x)
        x = self.final_conv(x)  # (N, num_classes, H, W)
        return x

    def get_backbone_params(self):
        # There is no backbone for unet, all the parameters are trained from scratch
        return []

    def get_decoder_params(self):
        return self.parameters()

    def freeze_bn(self):
        # note: a later call to model.train() switches BN layers back to training mode,
        # so call freeze_bn() again after model.train()
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
```

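When interpolate=False, the decoder has to pad the upsampled tensor to the size of the skip connection x_copy. F.pad lists padding starting from the last dimension, as (left, right, top, bottom) for a 4-D tensor, so each size difference should be split into two parts whose sum equals the difference. A pure-Python sketch of that arithmetic (the helper name pad_amounts is hypothetical):

```python
def pad_amounts(target_hw, current_hw):
    """Return (left, right, top, bottom) padding that grows current_hw to target_hw.

    Uses the ordering F.pad expects for a 4-D tensor: width (last dim) first,
    then height. Assumes the target is at least as large as the current size.
    """
    diff_y = target_hw[0] - current_hw[0]
    diff_x = target_hw[1] - current_hw[1]
    return (diff_x // 2, diff_x - diff_x // 2,
            diff_y // 2, diff_y - diff_y // 2)

print(pad_amounts((143, 143), (142, 142)))  # (0, 1, 0, 1)
print(pad_amounts((72, 70), (68, 68)))      # (1, 1, 2, 2)
```

With ceil_mode=True the upsampled tensor can also end up one pixel larger than x_copy (for example 143 pooled to 72 and upsampled to 144), which the default interpolate=True path handles directly because F.interpolate resizes to the target size in either direction.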