"""Download, cache and load the MNIST handwritten-digit dataset.

The four gzip'd IDX files are fetched into the directory containing this
module, converted to NumPy arrays, and cached as ``mnist.pkl`` next to them.
"""
import gzip
import os
import os.path
import pickle

try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')

import numpy as np

# Primary download location (tried first).
url_base = 'http://yann.lecun.com/exdb/mnist/'
# Fallback locations serving the same file names, tried in order when the
# primary host is unreachable or refuses the request.
mirror_urls = (
    'https://ossci-datasets.s3.amazonaws.com/mnist/',
)
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = os.path.join(dataset_dir, "mnist.pkl")

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784


def _download(file_name):
    """Fetch ``file_name`` into ``dataset_dir`` unless it is already present.

    The data is written to ``<file>.part`` first and only renamed into place
    once the transfer succeeds, so an interrupted download is never mistaken
    for a complete file on the next run.

    Raises
    ------
    OSError
        The error from the last source tried (``urllib.error.URLError`` and
        ``HTTPError`` are ``OSError`` subclasses) if every source fails.
    """
    file_path = os.path.join(dataset_dir, file_name)

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    tmp_path = file_path + ".part"
    last_error = None
    for base in (url_base,) + tuple(mirror_urls):
        try:
            urllib.request.urlretrieve(base + file_name, tmp_path)
        except OSError as err:
            last_error = err
            continue
        os.replace(tmp_path, file_path)
        print("Done")
        return

    # Every source failed: drop any partial data before propagating.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
    raise last_error


def download_mnist():
    """Download every file listed in ``key_file`` that is not yet on disk."""
    for v in key_file.values():
        _download(v)


def _load_label(file_name):
    """Read a gzip'd IDX label file into a 1-D uint8 array.

    The first 8 bytes of the decompressed file are header and are skipped.
    """
    file_path = os.path.join(dataset_dir, file_name)

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print("Done")

    return labels


def _load_img(file_name):
    """Read a gzip'd IDX image file into a uint8 array of shape (N, img_size).

    The first 16 bytes of the decompressed file are header and are skipped.
    """
    file_path = os.path.join(dataset_dir, file_name)

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")

    return data


def _convert_numpy():
    """Load all four downloaded files into a dict keyed like ``key_file``."""
    return {
        'train_img': _load_img(key_file['train_img']),
        'train_label': _load_label(key_file['train_label']),
        'test_img': _load_img(key_file['test_img']),
        'test_label': _load_label(key_file['test_label']),
    }


def init_mnist():
    """Download the raw files if needed and (re)write the pickle cache."""
    download_mnist()
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print("Done!")


def _change_one_hot_label(X):
    """Return a float array of shape (X.size, 10) with a 1 at each label index."""
    T = np.zeros((X.size, 10))
    # Row i gets a 1 in column X[i]; equivalent to a per-row loop.
    T[np.arange(X.size), X] = 1

    return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset, downloading and caching it on first use.

    Parameters
    ----------
    normalize : bool
        If True, convert pixel values to float32 in the range 0.0-1.0;
        otherwise keep them as uint8 in 0-255.
    flatten : bool
        If True, each image is a 1-D array of ``img_size`` values; otherwise
        it has shape ``img_dim`` (1, 28, 28).
    one_hot_label : bool
        If True, labels are returned as one-hot arrays, e.g.
        ``[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]`` for label 2; otherwise as the
        plain integer labels.

    Returns
    -------
    (train_images, train_labels), (test_images, test_labels)
        All returned arrays are writable.
    """
    if not os.path.exists(save_file):
        init_mnist()

    # The cache is produced locally by init_mnist(). Never point save_file
    # at an untrusted pickle: unpickling can execute arbitrary code.
    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, *img_dim)

    # np.frombuffer() over bytes yields read-only arrays, and pickle
    # protocol 5 keeps that flag. Arrays not replaced above would therefore
    # come back read-only, so copy them to let callers modify in place.
    for key, arr in dataset.items():
        if not arr.flags.writeable:
            dataset[key] = arr.copy()

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


def img_show(img):
    """Show ``img`` with PIL's image viewer after casting it with np.uint8.

    Expects raw 0-255 pixel values (e.g. from ``load_mnist(normalize=False)``).
    """
    # Imported lazily so loading the dataset does not require Pillow.
    from PIL import Image

    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


if __name__ == '__main__':
    init_mnist()
    (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

    img = x_train[1]
    label = t_train[1]
    print(label)  # 0 (the second training label)

    print(img.shape)  # (784,)
    img = img.reshape(28, 28)  # restore the flattened image to 28x28
    print(img.shape)  # (28, 28)

    img_show(img)