Python數據預處理之打亂數據集


import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/',one_hot=True)
train_imgs = mnist.train.images
train_labels = mnist.train.labels
test_imgs = mnist.test.images
test_label_imgs = mnist.test.labels
# 取訓練數據的20%
validate_datasets = 0.2
# 打亂的索引序列
permutation = np.random.permutation(train_labels.shape[0])
validate_indexs = permutation[:int(train_labels.shape[0]*validate_datasets)]
train_indexs = permutation[int(train_labels.shape[0]*validate_datasets):]

x_train_imgs = train_imgs[train_indexs,:]
y_train_labels = train_labels[train_indexs,:]

validate_imgs = train_imgs[validate_indexs,:]
validate_labels = train_labels[validate_indexs,:]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM