tensorflow 數據預處理


import  tensorflow as tf

from tensorflow import keras

def preprocess(x,y):
x = tf.cast(x, dtype = tf.float32) /255.
y = tf.cast(y, dtype = tf.int64)

y = tf.one_hot(y,depth = 10)
print('y shape :',y.shape)
return x,y

(x,y),(x_test,y_test) = keras.datasets.fashion_mnist.load_data()
db = tf.data.Dataset.from_tensor_slices((x,y))

db2 = db.map(preprocess).shuffle(60000).batch(100)

res = next(iter(db2))

print('res[0] shape',res[0].shape)
print('res[1] shape',res[1].shape)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM