"""Load MNIST and carve a random validation split out of the training set.

Module-level names produced by this script:
    train_imgs, train_labels          -- the full MNIST training split
    test_imgs, test_label_imgs        -- the MNIST test split
                                         (``test_label_imgs`` holds labels,
                                         not images; ``test_labels`` is a
                                         clearer alias for it)
    x_train_imgs, y_train_labels      -- training rows left after the split
    validate_imgs, validate_labels    -- the held-out validation rows
"""
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# NOTE(review): ``tensorflow.examples.tutorials`` ships only with TensorFlow
# 1.x (it was removed in TF 2.x), so this import needs a TF 1.x install.
# one_hot=True requests one-hot encoded label vectors; data lives under 'data/'.
mnist = input_data.read_data_sets('data/', one_hot=True)

train_imgs = mnist.train.images
train_labels = mnist.train.labels
test_imgs = mnist.test.images
test_label_imgs = mnist.test.labels
# Clearer alias: this array holds the test *labels*, despite its historical name.
test_labels = test_label_imgs

# Fraction of the training data held out as the validation set (20%).
validate_datasets = 0.2

# Set to an int to make the validation split reproducible. None leaves NumPy's
# global RNG state untouched, which matches the original behavior.
RANDOM_SEED = None
if RANDOM_SEED is not None:
    np.random.seed(RANDOM_SEED)

n_train_total = train_labels.shape[0]
# Number of rows sent to validation; computed once so both slices below agree
# on the exact split point.
n_validate = int(n_train_total * validate_datasets)

# Shuffled sequence of row indices: the first n_validate go to validation,
# the rest stay in training, so the two sets are disjoint and cover all rows.
permutation = np.random.permutation(n_train_total)
validate_indexs = permutation[:n_validate]
train_indexs = permutation[n_validate:]

x_train_imgs = train_imgs[train_indexs, :]
y_train_labels = train_labels[train_indexs, :]
validate_imgs = train_imgs[validate_indexs, :]
validate_labels = train_labels[validate_indexs, :]