這里有三種方式保存模型:
第一種: 只保存網絡參數,適合自己了解網絡結構
第二種: 保存整個網絡,可以完美進行恢復
第三個是保存格式。
第一種方式:
實踐操作:
第二種方式:(存入整個模型)
第三種方式:(存成工業模型)
import tensorflow as tf save_path = 'save_model/' save_path2 = 'save_model2/' save_path3 = 'save_model3/' def preporocess(x,y): x = tf.cast(x,dtype=tf.float32) / 255 x = tf.reshape(x,(-1,28 *28)) # 鋪平 x = tf.squeeze(x,axis=0) # print('里面x.shape:',x.shape) y = tf.cast(y,dtype=tf.int32) y = tf.one_hot(y,depth=10) return x,y def my_create(): # 設置超參 iter_num = 2000 # 迭代次數 lr = 0.01 # 學習率 # 定義模型器和優化器 model = tf.keras.Sequential([ tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(32, activation='relu'), tf.keras.layers.Dense(10) ]) # 優化器 # optimizer = tf.keras.optimizers.SGD(learning_rate=lr) optimizer = tf.keras.optimizers.Adam(learning_rate=lr) # 定義優化器 model.compile(optimizer= optimizer,loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy']) # 定義模型配置 model.fit(db,epochs=2,validation_data=db,validation_freq=2) # 運行模型,參數validation_data是指在哪個測試集上進行測試 model.evaluate(db_test) # 最后打印測試數據相關准確率數據 ################ 模型保存 ################## # # 1. 只存入model的參數 # model_name = 'my_model1.ckpt' # model.save_weights(save_path+model_name) # # 2. 存入整個model # model_name = 'my_model2.h5' # model.save(save_path2 + model_name) # 3. 存成工業模型 tf.saved_model.save(model,save_path3) print('保存工業模型') del model def my_load(): # 設置超參 lr = 0.01 # 學習率 # 定義模型器和優化器 model = tf.keras.Sequential([ tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(32, activation='relu'), tf.keras.layers.Dense(10) ]) # 優化器 # optimizer = tf.keras.optimizers.SGD(learning_rate=lr) optimizer = tf.keras.optimizers.Adam(learning_rate=lr) # 定義優化器 model.compile(optimizer=optimizer, loss=tf.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy']) # 定義模型配置 ################# 加載模型 ############### # # 加載僅有參數的model # model_name = 'my_model1.ckpt' # model.load_weights(save_path+model_name) # print('加載僅有參數的模型') # model.evaluate(db_test) # # 加載整個model # model_name = 'my_model2.h5' # model2 = tf.keras.models.load_model(save_path2 + model_name) # print('加載整個模型') # model2.evaluate(db_test) # 加載工業模型 model2 = tf.saved_model.load(save_path3) print('加載工業模型') f = model2.signatures['serving_default'] print(db_test) print(f(db_test[0])) if __name__ == '__main__': # 數據步驟 # 加載手寫數字數據 mnist = tf.keras.datasets.mnist (train_x, train_y), (test_x, test_y) = mnist.load_data() # 處理數據 # 訓練數據 db = tf.data.Dataset.from_tensor_slices((train_x, train_y)) # 將x,y分成一一對應的元組 db = db.map(preporocess) # 執行預處理函數 db = db.shuffle(60000).batch(20) # 打亂加分組 # 測試數據 db_test = tf.data.Dataset.from_tensor_slices((test_x, test_y)) db_test = db_test.map(preporocess) db_test = db_test.shuffle(10000).batch(10000) # 操作步驟 # my_create() print('_________________----------------------------____________________________') my_load()