# 加載圖片
data = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = data.load_data()
plt.imshow(x_train[0], cmap='gray')
# 創建10個文件夾存放每一類圖片
for i in range(10):
os.makedirs(f"../datas/mnist/train/{i}")
os.makedirs(f"../datas/mnist/test/{i}")
# 保存圖片並生成圖像列表
# 訓練集數據
with open("../datas/mnist/train/image_list.txt", 'w') as img_list:
i = 1
for img, label in zip(x_train, y_train):
img = Image.fromarray(img) # 將array轉化成圖片
img_save_path = f"../datas/mnist/train/{label}/{i}.jpg" # 圖片保存路徑
img.save(img_save_path) # 保存圖片
img_list.write(img_save_path + "\t" + str(label) + "\n")
i += 1
# 測試集數據
with open("../datas/mnist/test/image_list.txt", 'w') as img_list:
i = 1
for img, label in zip(x_test, y_test):
img = Image.fromarray(img) # 將array轉化成圖片
img_save_path = f"../datas/mnist/test/{label}/{i}.jpg" # 圖片保存路徑
img.save(img_save_path)
img_list.write(img_save_path + "\t" + str(label) + "\n")
i += 1