用torchvision.datasets.ImageFolder加載圖片數據集


一、項目結構

 

 

 

二、代碼

 1 data_loader = torch.utils.data.DataLoader(
 2     torchvision.datasets.ImageFolder('traing_dataset',
 3        transform=torchvision.transforms.Compose([
 4                 torchvision.transforms.Resize([28, 28]),       # 裁剪圖片
 5                 torchvision.transforms.Grayscale(1),           # 單通道
 6                 torchvision.transforms.ToTensor(),             # 將圖片數據轉成tensor格式
 7                 torchvision.transforms.Normalize(              # 歸一化
 8                      (0.1307,), (0.3081,))
 9                 ])),
10     batch_size=10, shuffle=False)                  # 10張圖片

 

三、顯示效果

 1 def plot_image(img, label, name):
 2     fig = plt.figure()
 3     for i in range(6):                                    # 只顯示6張  4         plt.subplot(2, 3, i+1)                               # 2行3列第i+1張  5         plt.tight_layout()
 6         plt.imshow(img[i][0]*0.3081+0.1307, cmap='Greys', interpolation='none')
 7         plt.title("{}:{}".format(name, label[i].item()))                # 標題名稱  8         plt.xticks([])
 9         plt.yticks([])
10     plt.show()
11 
12 x, y = next(iter(data_loader))                                # 文件夾的名稱即為圖片的label 13 print(x.shape, y.shape)
14 plot_image(x, y, 'image')

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM