tensorflow 数据预处理


import  tensorflow as tf

from tensorflow import keras

def preprocess(x,y):
x = tf.cast(x, dtype = tf.float32) /255.
y = tf.cast(y, dtype = tf.int64)

y = tf.one_hot(y,depth = 10)
print('y shape :',y.shape)
return x,y

(x,y),(x_test,y_test) = keras.datasets.fashion_mnist.load_data()
db = tf.data.Dataset.from_tensor_slices((x,y))

db2 = db.map(preprocess).shuffle(60000).batch(100)

res = next(iter(db2))

print('res[0] shape',res[0].shape)
print('res[1] shape',res[1].shape)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM