Multi-class Semantic Segmentation of Remote Sensing Images (Based on Pytorch-Unet)

Preface

I got interested in this area over a year ago, but at the time I only implemented binary semantic segmentation and never looked into the multi-class case. It is still a fairly hot topic, and the models keep evolving, from FCN to Unet to DeepLabv3+.

Approach
  1. First, reproduce the FCN (VOC2012) semantic segmentation code to get a feel for the overall layout.
  2. Then, modify the binary-segmentation code (based on Pytorch-Unet).
Core Code and Step-by-Step Walkthrough
  1. Dataloader

    import torch.utils.data as data
    import PIL.Image as Image
    import os
    import numpy as np
    import torch
    
    def make_dataset(root1, root2):
        '''
        @func: read the data and collect it into a list
        @root1: src directory
        @root2: label directory
        '''
        imgs = []                                    # walk the folders, appending (image, label) path pairs
        for i in range(650, 811):
            img = os.path.join(root1, "%s.png" % i)
            mask = os.path.join(root2, "%s.png" % i)
            imgs.append((img, mask))
        return imgs
    
    
    class LiverDataset(data.Dataset):
        '''
        @root1: src directory
        @root2: label directory
        @transform: normalize/standardize src; finally convert to tensor
        @target_transform: leave labels untouched; values are class indices 0/1/2/3... (long); finally convert to tensor
        '''
        def __init__(self, root1, root2, transform=None, target_transform=None):
        def __init__(self, root1, root2, transform=None, target_transform=None):
            imgs = make_dataset(root1, root2)             
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform
    
        def __getitem__(self, index):
            x_path, y_path = self.imgs[index]
            img_x = Image.open(x_path)    
            img_y = Image.open(y_path)
            if self.transform is not None:
                img_x = self.transform(img_x)
            if self.target_transform is not None:
                img_y = self.target_transform(img_y)
            else:
                img_y = np.array(img_y)  # PIL -> ndarray
                img_y = torch.from_numpy(img_y).long() 
            return img_x, img_y
    
        def __len__(self):
            return len(self.imgs)
    

    The crucial part of this step is the transform. Here src is an RGB image, and the label is a single-channel grayscale map whose values 0, 1, 2, ... each denote one class. Applying normalization and mean/std standardization to src improves training efficiency and accuracy; the label needs no processing beyond conversion to long.
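
    Before training, it is worth double-checking that the labels really are single-channel class-index maps. A minimal sketch (the path follows the make_dataset() convention above):

    import numpy as np
    import PIL.Image as Image

    mask = Image.open("data/多分類/label/650.png")   # first label in the 650..810 range
    arr = np.array(mask)
    print(arr.shape)       # (H, W): no channel dimension
    print(np.unique(arr))  # e.g. [0 1 2 3 4 5 6]: one integer per class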

  2. Building the Unet model
    import torch.nn as nn
    import torch
    
    class DoubleConv(nn.Module):
        def __init__(self, in_ch, out_ch):
            super(DoubleConv, self).__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
    
        def forward(self, x):  # renamed from `input` to avoid shadowing the builtin
            return self.conv(x)
    
    
    class Unet(nn.Module):
        def __init__(self, in_ch, out_ch):
            super(Unet, self).__init__()
    
            self.conv1 = DoubleConv(in_ch, 64)
            self.pool1 = nn.MaxPool2d(2)
            self.conv2 = DoubleConv(64, 128)
            self.pool2 = nn.MaxPool2d(2)
            self.conv3 = DoubleConv(128, 256)
            self.pool3 = nn.MaxPool2d(2)
            self.conv4 = DoubleConv(256, 512)
            self.pool4 = nn.MaxPool2d(2)
            self.conv5 = DoubleConv(512, 1024)
            self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
            self.conv6 = DoubleConv(1024, 512)
            self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
            self.conv7 = DoubleConv(512, 256)
            self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
            self.conv8 = DoubleConv(256, 128)
            self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
            self.conv9 = DoubleConv(128, 64)
            self.conv10 = nn.Conv2d(64, out_ch, 1)
    
        def forward(self, x):
            c1 = self.conv1(x)
            p1 = self.pool1(c1)
            c2 = self.conv2(p1)
            p2 = self.pool2(c2)
            c3 = self.conv3(p2)
            p3 = self.pool3(c3)
            c4 = self.conv4(p3)
            p4 = self.pool4(c4)
            c5 = self.conv5(p4)
            up_6 = self.up6(c5)
            merge6 = torch.cat([up_6, c4], dim=1)
            c6 = self.conv6(merge6)
            up_7 = self.up7(c6)
            merge7 = torch.cat([up_7, c3], dim=1)
            c7 = self.conv7(merge7)
            up_8 = self.up8(c7)
            merge8 = torch.cat([up_8, c2], dim=1)
            c8 = self.conv8(merge8)
            up_9 = self.up9(c8)
            merge9 = torch.cat([up_9, c1], dim=1)
            c9 = self.conv9(merge9)
            c10 = self.conv10(c9)
            return c10
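
    A quick sanity check of the model (a sketch assuming 256×256 inputs, matching the shape comments in the training code below): the output keeps the spatial size and has one channel per class.

    net = Unet(3, 7)                        # 3 input channels (RGB), 7 classes
    out = net(torch.randn(1, 3, 256, 256))
    print(out.shape)                        # torch.Size([1, 7, 256, 256])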
    
    1. Important: for multi-class segmentation, do not apply softmax to the output. The reason is that nn.CrossEntropyLoss() is used later to compute the loss, and that function internally applies softmax (as log_softmax) plus NLLLoss to the net() output.

    2. For binary classification with nn.BCELoss(), however, no probability transform is built in, so sigmoid has to be computed separately, usually at the end of the Unet model; the sketch below contrasts the two.
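
    A minimal sketch of the contrast (shapes are hypothetical, one logit channel for the binary case):

    import torch
    import torch.nn as nn

    logits = torch.randn(1, 1, 256, 256)                     # raw net() output
    target = torch.randint(0, 2, (1, 1, 256, 256)).float()   # BCELoss wants float 0/1

    # BCELoss expects probabilities, so sigmoid is applied by hand...
    loss_a = nn.BCELoss()(torch.sigmoid(logits), target)
    # ...while BCEWithLogitsLoss fuses the sigmoid in, the way CrossEntropyLoss fuses softmax.
    loss_b = nn.BCEWithLogitsLoss()(logits, target)
    assert torch.allclose(loss_a, loss_b, atol=1e-6)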

    3. train & test

    This part matters most, so it is broken into pieces.

    The most important thing is the arguments to nn.CrossEntropyLoss()(outputs, labels):

    outputs: the result of net(); in the multi-class case these are raw, un-normalized logits.

    labels: the labels read by the dataloader, here a single-channel grayscale array of class indices (0/1/2/3...).

    Internally, CrossEntropyLoss:

    applies softmax + log + NLLLoss to outputs;

    treats labels as class indices (mathematically equivalent to one-hot encoding them and summing against the log-probabilities), as the sketch below demonstrates.
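
    A minimal sketch of the shapes and the internal computation (random tensors, batch size 1, 7 classes):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    outputs = torch.randn(1, 7, 256, 256)         # (N, C, H, W) float logits
    labels = torch.randint(0, 7, (1, 256, 256))   # (N, H, W) long class indices

    loss = nn.CrossEntropyLoss()(outputs, labels)
    # Equivalent manual pipeline: log_softmax over the class dim, then NLLLoss.
    manual = F.nll_loss(F.log_softmax(outputs, dim=1), labels)
    assert torch.allclose(loss, manual)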

    # 0. imports / setup shared by the snippets below
    import numpy as np
    import skimage.io
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import optim
    from torch.utils.data import DataLoader
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # `args` (batch_size, ckp) comes from an argparse parser, omitted here

    # 1. train
    def train_model(model, criterion, optimizer, dataload, num_epochs=5):
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs-1))
            print('-' * 10)
            dt_size = len(dataload.dataset)
            epoch_loss = 0
            step = 0
            for x, y in dataload:
                step += 1
                inputs = x.to(device)
                labels = y.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                
                # Visualize the output, for debugging
                # probs1 = F.softmax(outputs, dim=1)  # 1 7 256 256
                # probs = torch.argmax(probs1, dim=1) # 1 256 256
                # print(0 in probs)
                # print(1 in probs)
                # print(2 in probs)
                # print(3 in probs)
                # print(4 in probs)
                # print(5 in probs)
                # print(probs.max())
                # print(probs.min())
                # print(probs)
    
                # print("\n")
                # print(labels.max())
                # print(labels.min())
    
                # labels:  1 x 256 x 256
                # outputs: 1 x 7 x 256 x 256
                # With CrossEntropyLoss, outputs are softmaxed internally, so no manual call is
                # needed (BCELoss earlier needed an explicit sigmoid, since it is not built in).
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
            print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
        return model
    
    def train():
        model = Unet(3, 7).to(device)
        batch_size = args.batch_size
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        liver_dataset = LiverDataset("data/多分類/src", "data/多分類/label", transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
        train_model(model, criterion, optimizer, dataloaders)
    
    # 2. transforms use torchvision's built-in functions
    x_transforms = transforms.Compose([
        transforms.ToTensor(),  # -> [0,1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 
    ])
    y_transforms = None  # labels are left untouched
    
    # 3. test: visualize the predictions
    def test():
        model = Unet(3, 7)
        model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
        liver_dataset = LiverDataset("data/多分類/src", "data/多分類/label", transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(liver_dataset, batch_size=1)
        model.eval()
        import matplotlib.pyplot as plt
        plt.ion()
        k = 0
        with torch.no_grad():
            for x, _ in dataloaders:
                y = model(x)  
                
                # Convert the raw outputs to probabilities, take the index of the largest
                # probability per pixel, then drop the batch dimension.
                y = F.softmax(y, dim=1)                        # 1 x 7 x 256 x 256
                y = torch.argmax(y, dim=1)                     # 1 x 256 x 256
                y = torch.squeeze(y).numpy().astype(np.uint8)  # 256 x 256, class indices
                
                plt.imshow(y)
                
                # debug
                print(y.max())
                print(y.min())
                print("\n")
                
                skimage.io.imsave('E:/Tablefile/u_net_liver-master_multipleLabels/savetest/{}.jpg'.format(k), y)
                plt.pause(0.1)
                k = k+1
            plt.show()
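
    One caveat: saving the raw class indices (0-6) directly produces an almost-black image, since the values are tiny on a 0-255 scale (and JPEG compression can corrupt the indices). A minimal sketch that maps indices to colours before saving; the palette here is arbitrary:

    import numpy as np
    import skimage.io

    # One RGB colour per class index (7 classes assumed, as in Unet(3, 7)).
    PALETTE = np.array([
        [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255],
        [255, 255, 0], [255, 0, 255], [0, 255, 255],
    ], dtype=np.uint8)

    def colorize(mask):
        # mask: (H, W) array of class indices -> (H, W, 3) uint8 RGB image
        return PALETTE[mask]

    # skimage.io.imsave('savetest/0.png', colorize(y))  # PNG keeps exact values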
    
Points to Note
  1. Choice of loss function.

    Binary classification uses BCELoss; multi-class uses CrossEntropyLoss.

    BCELoss applies no probability transform (sigmoid) itself.

    CrossEntropyLoss applies softmax + log + NLLLoss internally.

  2. transform

    src images get normalization and mean/std standardization.

    Labels get no processing (single-channel array; the values 0/1/2/3... denote classes).

  3. Poor predictions may come from an incorrectly computed loss, or from imperfect dataset annotations.

  4. Note the squeeze() before computing the loss: it strips redundant dimensions so the array has the shape the loss function expects. (Note: BCELoss and CrossEntropyLoss require different label shapes; see the sketch after this list.)

  5. At prediction time in the binary case, the net() output first goes through sigmoid(), then values above 0.5 map to 1 and values below 0.5 map to 0.
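
    A minimal sketch of the shape difference and the binary thresholding (random tensors, hypothetical 256×256 size):

    import torch
    import torch.nn as nn

    # CrossEntropyLoss: outputs (N, C, H, W) float logits, labels (N, H, W) long indices.
    ce = nn.CrossEntropyLoss()(torch.randn(1, 7, 256, 256),
                               torch.randint(0, 7, (1, 256, 256)))

    # BCELoss: outputs and labels must share the same shape, with float 0/1 labels,
    # so a (N, 1, H, W) output is typically squeeze()d down to match a (N, H, W) label.
    probs = torch.sigmoid(torch.randn(1, 1, 256, 256)).squeeze(1)   # (1, 256, 256)
    bce = nn.BCELoss()(probs, torch.randint(0, 2, (1, 256, 256)).float())

    # Binary prediction: threshold the probabilities at 0.5.
    pred = (probs > 0.5).long()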

Results
[segmentation result image]
Afterword
  1. Reproduce a few more semantic segmentation models (deeplabv3+/segnet/fcn.../unet+).

  2. Understand the concrete roles of the architecture's convolution, pooling, and regularization components.

  3. Learn tuning techniques (optimizer, learning rate, etc.).

  4. Learn transfer learning methods to save training time.

