最近一直在學pytorch,copy了幾個經典的入門問題。現在作一下總結。
首先,做的小項目主要有
分類問題:Mnist手寫體識別、FashionMnist識別、貓狗大戰
語義分割:Unet分割肝臟圖像、遙感圖像
先把語義分割的心得總結一下,目前只是一部分,以后還會隨着學習的深入慢慢往里面加新的感悟。
1)對於二分類問題
1. Unet輸出channel:對於二分類問題,類別數為2,channel為1,用uint8的單通道灰度圖像表示類別就行(0/1)。
2. label是單通道灰度圖像,直接傳給損失函數。
3. 損失函數:nn.sigmoid + nn.BCELoss / nn.BCEWithLogitsLoss,此時計算loss的ouput和label維度應該保持一致。batchsize*1*h*w
2)對於多分類問題
1. Unet輸出channel: 輸出channel是類別數。網絡的輸入是img,網絡的輸出是one hot編碼的多通道圖像。
2. Label是單通道灰度圖像,不同的灰度級表示不同的類別。用於傳給損失函數,計算Loss。
具體操作方面,第一步有人說先將Label進行one hot編碼(即轉換成多通道圖,一個通道一個類別),這樣才能用交叉熵計算損失;也有人說不需要one hot編碼,直接把單通道Label作為損失函數的Label。
其實這兩個人說的都不錯,但第一個人並沒有用Pytorch做,而第二個人是用Pytorch和nn.CrossEntropyLoss計算損失的。
在多分類問題中,當損失函數為nn.CrossEntropyLoss()時,它會自動把標簽轉換成one hot形式。所以,我們在運用交叉熵損失函數時不必將標簽也轉換成onehot形式。在用到這種損失函數時,直接把單通道Label作為損失函數的Label即可,而網絡輸入的img得到的輸出是one hot編碼格式。最后為了可視化輸出,用argmax取到索引,把多通道圖片轉換成單通道圖片(不同灰度級表示不同類別),再用索引對應的RGB顏色表解碼(偽彩色映射)得到分割圖。
ps. 總結一下。因為單通道的Label只是用來計算Loss的,而輸入圖片(img)到網絡的輸出又是多通道圖片(One hot),所以為了計算損失函數,Label傳遞給損失函數前是肯定要one hot一下的,只是用nn.CrossEntropyLoss時,Label自動one hot了,所以不需要你手動去轉換了。此外CrossEntropyLoss還內置了softmax函數,而BCELOSS卻沒有內置sigmoid函數,所以在網絡輸出層中,如果用前者不需要加softmax層,而后者需要加sigmoid層。
3. 損失函數:nn.CrossEntorpyLoss計算。此時計算loss的output維度為batchsize*categories*h*w,label為batchsize*h*w。此外這個損失函數內置了softmax運算。
4. 此外,這種多分類的方法有時候精度相對不高,可以轉化成多個二分類問題,最后合成在一起。
3)test時有時會取torch.argmax/torch.max來得到pred_label的索引,用於計算accuracy。這點圖像分類方面用的比較多。
train的時候一般不需要這個,直接輸入模型的輸出和Label計算Loss,再反向傳遞就可以。
pred_y = torch.max(test_output, 1)[1].data.numpy() #返回每一行中最大值的那個元素,且返回其索引(返回最大元素在這一行的列索引)
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) #准確數/batch_size,計算准確率
4)predict出圖部分。(講一點自己的看法,可能不太對)
對於二分類問題,經過sigmoid輸出float類型的概率可以直接可視化,這種情況下mask的精確度不高但是很方便,這取決於你的需求;也可以設定閾值二分類0/1再映射到255。
對於多分類問題,經過softmax輸出概率后有三種方法選擇,1是設置閾值,大於閾值為1,小於閾值為0,得到的是多通道圖像(感覺這樣閾值影響結果很大);2是對model的輸出按channel取argmax,得到的應該是單通道的圖像,索引對應channnel,這種情況下不需要用softmax取概率,只需要對原始輸出取argmax就可以;3是直接可視化softmax輸出的概率。
