RuntimeError: multi-target not supported at


1. 出錯代碼行
計算交叉熵是出現異常提示:RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1549635019666/work/aten/src/THNN/generic/ClassNLLCriterion.c:21

loss = criterion(prediction, target)
2 原因:
CrossEntropyLoss does not expect a one-hot encoded vector as the target, but class indices
pytorch 中計計算交叉熵損失函數時, 輸入的正確 label 不能是 one-hot 格式。函數內部會自己處理成 one hot 格式。所以不需要輸入 [ 0 0 0 0 1],只需要輸入 4 就行。

print(prediction.size())
print(target.size())
print("target = ", target)
loss = criterion(prediction, target)
# 輸出給過如下
torch.Size([2, 145]) # 輸入兩個數據,每個數據的有145個值
torch.Size([2]) # target(ground true) 是兩個值,每個數據一個值
target = tensor([4, 4]) # 兩個數據具體的 target 值,都是4

3. 解決方法:
更改 dataloader 中 dataset 中 def __getitem__(self, index) 返回的 target 的內容(將 one hot 格式改成 數字格式 就行)。
如果 target 的size 不是一維的話,需要添加一行代碼,如下:
target = target.squeeze() # 添加這一行,用於降維度(從 torch.Size([2, 1]) 降成torch.Size([2]) ) ,即 target 從[ [4], [4]] 降成 [ 4, 4 ]
loss = criterion(prediction, target)

4. 總結
pytorch 中使用神經網絡進行多分類時,網路的輸出 prediction 是 one hot 格式,但計算 交叉熵損失函數時,loss = criterion(prediction, target) 的輸入 target 不能是 one hot 格式,直接用數字來表示就行(4 表示 one hot 中的 0 0 0 0 1)。
所以,自己構建數據集,返回的 target 不需要是 one hot 格式。

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM