遇到的問題
數據是png圖像的時候,如果用PIL讀取圖像,獲得的是單通道的,不是多通道的。雖然使用opencv讀取圖片可以獲得三通道圖像數據,如下:
def __getitem__(self, idx):
image_root = self.train_image_file_paths[idx]
image_name = image_root.split(os.path.sep)[-1]
image = cv.imread(image_root)
if self.transform is not None:
image = self.transform(image)
label = ohe.encode(image_name.split('_')[0])
return image, label
但是會出現報錯:
TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>
File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 77, in <module>
main(args)
File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 47, in main
predict_labels = cnn(images)
File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\models\resnet.py", line 192, in forward
x = self.conv1(x)
File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\conv.py", line 338, in forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 7 7, expected input[64, 60, 160, 3] to have 3 channels, but got 60 channels instead
最終解決方案:
class mydataset(Dataset):
def __init__(self, folder, transform=None):
self.train_image_file_paths = [os.path.join(folder, image_file) for image_file in os.listdir(folder)]
self.transform = transforms.Compose([
transforms.ToTensor(), # 轉化為pytorch中的tensor
transforms.Lambda(lambda x: x.repeat(1,1,1)), # 由於圖片是單通道的,所以重疊三張圖像,獲得一個三通道的數據
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]) # 主要改這個地方
def __len__(self):
return len(self.train_image_file_paths)
def __getitem__(self, idx):
image_root = self.train_image_file_paths[idx]
image_name = image_root.split(os.path.sep)[-1]
image = Image.open(image_root)
if self.transform is not None:
image = self.transform(image)
label = ohe.encode(image_name.split('_')[0])
return image, label
pytorch transform 知識點:https://blog.csdn.net/u011995719/article/details/85107009
PIL PNG格式通道問題的解決方法 : https://www.cnblogs.com/wzjbg/p/8516531.html