Pytorch錯誤Expected input batch_size (324) to match target batch_size (4) Log In


參考鏈接:

https://blog.csdn.net/qq_41429220/article/details/104973805

Pytorch Error: ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

1.ERROR原因

使用pytorch訓練一個自定義的模型,參照網上的博客直接照搬網絡,但是在修改自定義數據集時,出現這個錯誤。很明顯是一個圖像參數不匹配問題,自定義數據集的圖片大小規格不統一且與網絡接受的大小不匹配。

ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

2.解決思路

首先,在錯誤的網絡結構處前后加入print來查看網絡結構。

# 構建CNN模型
class CNNNet(nn.Module):
    
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.fc1 = nn.Linear(128*53*53, 1024)
        self.fc2 = nn.Linear(1024, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
#         print(x.shape)
        x = x.view(-1, 128*53*53)
#         print(x.shape)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

即我注釋的這個地方,可以得到輸入前的數據格式。

torch.Size([4, 128, 53, 53])

根據輸出的形狀來更改view里的參數。

x = x.view(-1, 128 * 53 * 53)

后面的Linear層也需要對應修改,使其與數據輸入匹配:

self.fc1 = nn.Linear(128 * 53 * 53, 1024)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM