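Keras is built into TensorFlow as the API tf.keras (it became part of the core TensorFlow API in version 1.4, so it is available in TF 1.8 and later).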
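In addition, the old MNIST download statement below now raises an error. It relies on the tensorflow.examples.tutorials.mnist module (input_data), which is no longer shipped with TensorFlow 2.x: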
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
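With one_hot=True, the old read_data_sets call returned each label as a one-hot vector. mnist.load_data() (in both keras.datasets and tf.keras.datasets) returns plain integer labels 0 to 9 instead, so the labels have to be converted when a loss expects one-hot targets. The NumPy sketch below shows what that conversion does; the helper name one_hot is hypothetical, and it assumes every label is an integer in [0, num_classes).

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Convert integer class labels to one-hot rows (hypothetical helper).

    Assumes every label is an integer in [0, num_classes).
    """
    labels = np.asarray(labels, dtype=np.int64)
    encoded = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Put a single 1.0 in each row, at the column given by that row's label.
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

y = np.array([5, 0, 4])
print(one_hot(y))
```

This is the same kind of encoding that keras.utils.to_categorical produces in the Keras code.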
Below is training code for both approaches, standalone Keras and TensorFlow (tf.keras), together with code to check a prediction:
Keras:
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Flatten each 28x28 image into a 784-dimensional vector
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
# Normalize pixel values to [0, 1]
x_train = x_train / 255
x_test = x_test / 255
# One-hot encode the labels, as required by categorical_crossentropy
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))
# Print the network structure
model.summary()

model.compile(optimizer=SGD(), loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=64, epochs=5, validation_data=(x_test, y_test))

score = model.evaluate(x_test, y_test)
# Print loss and accuracy
print("loss", score[0])
print("accu", score[1])

# Generate output predictions for the input samples
predictions = model.predict(x_test)
# Predicted class of test sample 23
print(np.argmax(predictions[23]))
# Show the image to check whether the prediction is correct
# (x_test was flattened to 784 values, so reshape it back to 28x28 first)
plt.imshow(x_test[23].reshape(28, 28), cmap='gray')
plt.show()
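To make it clearer what the Sequential model above computes, here is a NumPy sketch of the forward pass of the same 784 → 512 → 256 → 10 stack. Each Dense layer computes activation(x @ W + b); the weights here are random placeholders rather than trained values, and the names relu, softmax, params and forward are hypothetical. The parameter count matches what model.summary() reports for this architecture.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract the row max for numerical stability; the result is unchanged.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical random (W, b) pairs with the same shapes as the Keras layers.
sizes = [784, 512, 256, 10]
params = [(rng.normal(0.0, 0.05, (m, n)), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]

def forward(x):
    """x: (batch, 784) array scaled to [0, 1]; returns (batch, 10) probabilities."""
    h = x
    for i, (W, b) in enumerate(params):
        z = h @ W + b
        # ReLU on the hidden layers, softmax on the output layer.
        h = softmax(z) if i == len(params) - 1 else relu(z)
    return h

# Fake 28x28 "images", flattened and normalized like x_train / 255.
images = rng.integers(0, 256, size=(3, 28, 28))
x = images.reshape(3, 784) / 255.0
probs = forward(x)
print(probs.shape)                                 # (3, 10)
print(np.argmax(probs, axis=1))                    # predicted class per sample
print(sum(W.size + b.size for W, b in params))     # 535818 trainable parameters
```

np.argmax over the last axis is the same step the Keras code uses to turn predictions[23] into a class label.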
TensorFlow:
The code below is taken from the official TensorFlow website.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)
# The verification code is the same as above and is omitted
# (here x_test keeps its 28x28 shape, so no reshape is needed before plt.imshow)
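Unlike the Keras version, the tf.keras example does not call to_categorical: it uses sparse_categorical_crossentropy, which takes integer labels directly. Both losses compute the same mean negative log-probability of the true class. The NumPy sketch below checks this on a toy batch; the function names mirror the Keras loss names but are hypothetical re-implementations, and the probabilities are made up.

```python
import numpy as np

def categorical_crossentropy(y_onehot, probs):
    # Mean over samples of -sum_k y_k * log(p_k), with one-hot targets.
    return float(np.mean(-np.sum(y_onehot * np.log(probs), axis=1)))

def sparse_categorical_crossentropy(y_int, probs):
    # Same quantity, but picks log(p) of the true class by integer index.
    return float(np.mean(-np.log(probs[np.arange(len(y_int)), y_int])))

# Made-up softmax outputs for a batch of 2 samples over 3 classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = np.array([0, 2])
onehot = np.eye(3)[labels]   # one-hot version of the same labels

a = categorical_crossentropy(onehot, probs)
b = sparse_categorical_crossentropy(labels, probs)
print(a, b)  # the two values agree
```

Choosing the sparse variant simply avoids building the (num_samples, 10) one-hot label matrix.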
For the differences between Keras and tf.keras, see: https://www.zhihu.com/question/313111229/answer/606660552