Pytorch從本地獲取數據集
- 在學習
pytorch
的過程中需要從MNIST
獲取數據集,然而下載是讓人頭疼的事,從網上尋找數據資源比較便捷 - 獲取到的數據如何在
pytorch
中加載呢
1 下載數據集
https://download.csdn.net/download/wangxiaobei2017/12238192
2. 從本地進行數據加載
-
獲取測試集與訓練集
直接運行后,發現依舊是下載數據,那我本地的數據集怎么才能被加載
mnist_train = torchvision.datasets.FashionMNIST(root='./MNIST', train=True, download=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./MNIST', train=False, download=True,transform=transforms.ToTensor())
- 查找數據源url
按下Ctrl
,左鍵點擊FashionMNIS
,進入mnist.py
,在resources
下可以看到,這里是數據集的下載路徑,需要將其修改為本地文件的路徑
-
查找本地數據源
-
將本地數據源替換之前的路徑
特別要注意后面的None
,這個是md5
校驗碼,如果不填會報錯
-
運行程序,加載數據集
mnist_train = torchvision.datasets.FashionMNIST(root='./MNIST', train=True, download=True,transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root='./MNIST', train=False, download=True,transform=transforms.ToTensor()) print(type(mnist_train)) print(len(mnist_train), len(mnist_test))
完成