Python data preprocessing: shuffling a dataset


import numpy as np
# Note: tensorflow.examples.tutorials is part of TensorFlow 1.x and is not included in TensorFlow 2.x.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data/', one_hot=True)
train_imgs = mnist.train.images
train_labels = mnist.train.labels
test_imgs = mnist.test.images
test_labels = mnist.test.labels

# Hold out 20% of the training data for validation
validate_fraction = 0.2
num_train = train_labels.shape[0]
num_validate = int(num_train * validate_fraction)

# Shuffled sequence of row indices
permutation = np.random.permutation(num_train)
validate_indices = permutation[:num_validate]
train_indices = permutation[num_validate:]

# Index images and labels with the same indices so each image keeps its label
x_train_imgs = train_imgs[train_indices, :]
y_train_labels = train_labels[train_indices, :]

validate_imgs = train_imgs[validate_indices, :]
validate_labels = train_labels[validate_indices, :]
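
The same shuffle-and-split idea can be tried without downloading MNIST. The following is a minimal sketch, not the original author's code: the helper name `split_train_validation`, its `seed` parameter, and the small synthetic arrays are assumptions made for illustration. It uses NumPy's `default_rng` so that the shuffle is reproducible for a given seed.

```python
import numpy as np


def split_train_validation(images, labels, validate_fraction=0.2, seed=None):
    """Shuffle row indices once, then split both arrays with the same indices.

    Hypothetical helper for illustration; not part of NumPy or TensorFlow.
    """
    if images.shape[0] != labels.shape[0]:
        raise ValueError("images and labels must have the same number of rows")

    num_samples = images.shape[0]
    num_validate = int(num_samples * validate_fraction)

    # One permutation of the row indices drives both splits
    rng = np.random.default_rng(seed)
    permutation = rng.permutation(num_samples)
    validate_idx = permutation[:num_validate]
    train_idx = permutation[num_validate:]

    train_split = (images[train_idx], labels[train_idx])
    validate_split = (images[validate_idx], labels[validate_idx])
    return train_split, validate_split


# Synthetic stand-in for MNIST: 10 samples, 4 features, one-hot labels over 3 classes.
# Row i starts with the value 4 * i, so each row can be traced back to its label.
images = np.arange(40, dtype=np.float32).reshape(10, 4)
labels = np.eye(3, dtype=np.float32)[np.arange(10) % 3]

(x_train, y_train), (x_val, y_val) = split_train_validation(
    images, labels, validate_fraction=0.2, seed=0
)
print(x_train.shape, x_val.shape)  # (8, 4) (2, 4)
```

Because images and labels are indexed with the same index arrays, every image stays paired with its own label after shuffling. Passing a fixed `seed` gives the same split on every run, which helps when comparing experiments.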

 

