Multi-Class Semantic Segmentation of Remote Sensing Images (Based on Pytorch-Unet)
Preface
I became interested in this area more than a year ago, but at the time I had only implemented binary semantic segmentation and had not looked into the multi-class case. This field is still quite active: from FCN to Unet to deeplabv3+, the models keep evolving.
Approach
- First, I reproduced the FCN (VOC2012) semantic segmentation code to get a rough sense of the overall layout.
- Then I modified the binary-segmentation code (based on Pytorch-Unet).
Core Code and Steps Explained
- Dataloader

```python
import os

import numpy as np
import PIL.Image as Image
import torch
import torch.utils.data as data


def make_dataset(root1, root2):
    '''
    @func: read the data and store (image, label) path pairs in a list
    @root1: src directory
    @root2: label directory
    '''
    imgs = []
    # walk the directories, appending image and label paths to the list
    for i in range(650, 811):
        img = os.path.join(root1, "%s.png" % i)
        mask = os.path.join(root2, "%s.png" % i)
        imgs.append((img, mask))
    return imgs


class LiverDataset(data.Dataset):
    '''
    @root1
    @root2
    @transform: apply normalisation and mean/std processing to src,
                finally converting it to a tensor
    @target_transform: leave the label untouched; its values are the class
                       ids 0/1/2/3... (long), finally converted to a tensor
    '''
    def __init__(self, root1, root2, transform=None, target_transform=None):
        self.imgs = make_dataset(root1, root2)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        else:
            img_y = np.array(img_y)                # PIL -> ndarray
            img_y = torch.from_numpy(img_y).long() # class ids as long tensor
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
```

The crucial part of this step is the transform. Here src is an RGB image, while the label is a single-channel greyscale image whose values 0, 1, 2, ... each denote one class. Applying normalisation and mean/standard-deviation processing to src improves both computational efficiency and accuracy; the label is left untouched and simply converted to long.

- Building the Unet model
```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        return c10
```

- Make sure the multi-class output is NOT mapped to probabilities (no softmax). The reason is that nn.CrossEntropyLoss() is used later to compute the loss, and that function automatically applies softmax as well as log and nllloss() to the net() output.
- For binary segmentation, however, if the loss is nn.BCELoss(), that function applies no probability mapping itself, so a sigmoid has to be computed separately, usually right at the output at the end of the Unet model.
- train & test
This part matters most, so I will break it into a few pieces.
The most important thing is the input to nn.CrossEntropyLoss()(outputs, labels):
- outputs: the result of net(); in the multi-class case these are raw values that have not been turned into probabilities.
- labels: the label read by the dataloader, here a single-channel greyscale array of class ids (0/1/2/3...).
CrossEntropyLoss applies softmax + log + nllloss() to outputs, and treats labels as if one-hot encoded (conceptually expanded into a multi-dimensional 0/1 matrix before entering the computation; in practice PyTorch indexes the log-probabilities by class id).
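The softmax + log + nllloss() claim can be checked concretely on random tensors. This is an illustrative sketch (shapes chosen to mirror the 7-class setup, not part of the original code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 1, 7, 4, 4                    # batch, classes, height, width
outputs = torch.randn(N, C, H, W)          # raw net() scores, no softmax
labels = torch.randint(0, C, (N, H, W))    # long class-index map, like the greyscale label

ce = nn.CrossEntropyLoss()(outputs, labels)

# equivalent decomposition: softmax+log over the class dim, then NLL on the indices
manual = F.nll_loss(F.log_softmax(outputs, dim=1), labels)

assert torch.allclose(ce, manual, atol=1e-6)
```

Note the shape contract this implies: outputs is (N, C, H, W) while labels is (N, H, W) with dtype long, exactly the pair the training loop below passes in.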
```python
# module-level imports used by the snippets below (from the full script)
import skimage.io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms


# 1. train
def train_model(model, criterion, optimizer, dataload, num_epochs=5):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # visualise the output, for debugging
            # probs1 = F.softmax(outputs, dim=1)   # 1 x 7 x 256 x 256
            # probs = torch.argmax(probs1, dim=1)  # 1 x 256 x 256
            # print(0 in probs, 1 in probs, 2 in probs,
            #       3 in probs, 4 in probs, 5 in probs)
            # print(probs.max(), probs.min(), probs)
            # print(labels.max(), labels.min())
            # labels:  1 x 256 x 256
            # outputs: 1 x 7 x 256 x 256
            loss = criterion(outputs, labels)
            # CrossEntropyLoss applies softmax to outputs automatically, so no
            # manual softmax here; with BCELoss earlier a sigmoid had to be
            # added by hand because BCELoss does not include one
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model


def train():
    model = Unet(3, 7).to(device)
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("data/多分類/src", "data/多分類/label",
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=1)
    train_model(model, criterion, optimizer, dataloaders)


# 2. transforms use pytorch built-ins
x_transforms = transforms.Compose([
    transforms.ToTensor(),                                   # -> [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
y_transforms = None  # the label is left untouched


# 3. test, with visualised results
def test():
    model = Unet(3, 7)
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
    liver_dataset = LiverDataset("data/多分類/src", "data/多分類/label",
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    k = 0
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x)
            # turn the raw scores into probabilities, take the index of the
            # largest probability per pixel, then drop the batch dimension
            y = F.softmax(y, dim=1)       # 1 x 7 x 256 x 256
            y = torch.argmax(y, dim=1)    # 1 x 256 x 256
            y = torch.squeeze(y).numpy()  # 256 x 256
            plt.imshow(y)
            # debug
            print(y.max())
            print(y.min())
            print("\n")
            skimage.io.imsave('E:/Tablefile/u_net_liver-master_multipleLabels/savetest/{}.jpg'.format(k), y)
            plt.pause(0.1)
            k = k + 1
    plt.show()
```
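One caveat about the saving step in test(): writing the raw class indices (0..6) to an image file produces a near-black picture when opened in a viewer. A small colour-mapped save could be used instead; this is only a sketch, and the PALETTE values below are hypothetical (one RGB colour per class, 7 classes assumed):

```python
import numpy as np
from PIL import Image

# hypothetical palette: one RGB colour per class id (7 classes assumed)
PALETTE = np.array([
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
    [0, 0, 128], [128, 0, 128], [0, 128, 128],
], dtype=np.uint8)


def save_mask(class_map, path):
    """Map an (H, W) array of class indices to an RGB image and save it."""
    rgb = PALETTE[class_map]      # fancy indexing: (H, W) -> (H, W, 3)
    Image.fromarray(rgb).save(path)
```

Something like save_mask(y, 'savetest/%d.png' % k) could then replace the skimage call; PNG also avoids the JPEG compression artefacts that blur hard class boundaries.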
Points to Note
- Choice of loss function:
  binary segmentation uses BCELoss; multi-class uses CrossEntropyLoss.
  BCELoss does not apply any probability mapping (sigmoid) itself.
  CrossEntropyLoss applies softmax + log + nllloss internally.
- transform
  The src image gets normalisation and mean/standard-deviation processing.
  The label is left untouched (a single-channel array whose values 0/1/2/3... denote the class).
- Poor predictions may be caused by a wrongly computed loss, or by a dataset that is not labelled well enough.
- Mind the squeeze() call before the loss is computed; it removes redundant dimensions so the array has the shape the loss function expects. (Note: BCELoss and CrossEntropyLoss require different label shapes.)
- For binary prediction, the net() output is first mapped to probabilities with sigmoid(), then values above 0.5 become 1 and values below 0.5 become 0.
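The binary post-processing described above can be sketched in a few lines (a minimal illustration on a random tensor, assuming a single-channel net() output):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 1, 4, 4)   # raw single-channel output of net()
probs = torch.sigmoid(logits)      # probabilities in (0, 1); BCELoss would need this step too
mask = (probs > 0.5).float()       # 1 = foreground, 0 = background

assert set(mask.unique().tolist()) <= {0.0, 1.0}
```

This mirrors the multi-class test() path, where softmax + argmax plays the role that sigmoid + thresholding plays here.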
Results
Afterword
- Reproduce a few more semantic segmentation models (deeplabv3+ / segnet / fcn... / unet+)
- Understand what the convolution, pooling, and regularisation pieces of the architecture actually mean
- Learn the tricks of hyper-parameter tuning (optimiser, learning rate, etc.)
- Learn transfer learning methods to save training time
