一、項目結構
二、代碼
1 data_loader = torch.utils.data.DataLoader( 2 torchvision.datasets.ImageFolder('traing_dataset', 3 transform=torchvision.transforms.Compose([ 4 torchvision.transforms.Resize([28, 28]), # 裁剪圖片 5 torchvision.transforms.Grayscale(1), # 單通道 6 torchvision.transforms.ToTensor(), # 將圖片數據轉成tensor格式 7 torchvision.transforms.Normalize( # 歸一化 8 (0.1307,), (0.3081,)) 9 ])), 10 batch_size=10, shuffle=False) # 10張圖片
三、顯示效果
1 def plot_image(img, label, name): 2 fig = plt.figure() 3 for i in range(6): # 只顯示6張 4 plt.subplot(2, 3, i+1) # 2行3列第i+1張 5 plt.tight_layout() 6 plt.imshow(img[i][0]*0.3081+0.1307, cmap='Greys', interpolation='none') 7 plt.title("{}:{}".format(name, label[i].item())) # 標題名稱 8 plt.xticks([]) 9 plt.yticks([]) 10 plt.show() 11 12 x, y = next(iter(data_loader)) # 文件夾的名稱即為圖片的label 13 print(x.shape, y.shape) 14 plot_image(x, y, 'image')