數據集下載地址:
鏈接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取碼:2xq4
之前在:https://www.cnblogs.com/xiximayou/p/12398285.html創建好了數據集,將它上傳到谷歌colab

在colab上的目錄如下:

在utils中的rdata.py定義了讀取該數據集的代碼:
from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import torch #預處理 transform = transforms.Compose([transforms.ToTensor()]) path = "/content/drive/My Drive/colab notebooks/data/dogcat" train_path=path+"/train" test_path=path+"/test" #使用torchvision.datasets.ImageFolder讀取數據集指定train和test文件夾 train_data = torchvision.datasets.ImageFolder(train_path, transform=transform) train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1) test_data = torchvision.datasets.ImageFolder(test_path, transform=transform) test_loader = DataLoader(test_data, batch_size=32, shuffle=True, num_workers=1) print(train_data.classes) #根據分的文件夾的名字來確定的類別 print(train_data.class_to_idx) #按順序為這些類別定義索引為0,1... print(train_data.imgs) #返回從所有文件夾中得到的圖片的路徑以及其類別 print(test_data.classes) #根據分的文件夾的名字來確定的類別 print(test_data.class_to_idx) #按順序為這些類別定義索引為0,1... print(test_data.imgs) #返回從所有文件夾中得到的圖片的路徑以及其類別
ImageFolder可以讀取我們的train或test下面的文件夾,並為每一個標簽進行編碼,同時將圖片與標簽進行對應。
在test.ipynb中運行rdata.py

說明我們創建的數據集是可以用的了。
有了數據集,接下來就是網絡的搭建以及訓練和測試了。
