tensorflow2.0——模型保存與加載


這里有三種方式保存模型:

  

    第一種:  只保存網絡參數,適合自己了解網絡結構

    第二種:  保存整個網絡,可以完美進行恢復

    第三個是保存格式。

 

第一種方式:

  

 

   實踐操作:

  

 

 第二種方式:(存入整個模型)

  

 

 第三種方式:(存成工業模型)

  

 

 

import tensorflow as tf

save_path = 'save_model/'
save_path2 = 'save_model2/'
save_path3 = 'save_model3/'

def preporocess(x,y):
    x = tf.cast(x,dtype=tf.float32) / 255
    x = tf.reshape(x,(-1,28 *28))                   #   鋪平
    x = tf.squeeze(x,axis=0)
    # print('里面x.shape:',x.shape)
    y = tf.cast(y,dtype=tf.int32)
    y = tf.one_hot(y,depth=10)
    return x,y

def my_create():
    #   設置超參
    iter_num = 2000  # 迭代次數
    lr = 0.01  # 學習率
    #   定義模型器和優化器
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    #   優化器
    # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)              #   定義優化器
    model.compile(optimizer= optimizer,loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])       #   定義模型配置
    model.fit(db,epochs=2,validation_data=db,validation_freq=2)          #  運行模型,參數validation_data是指在哪個測試集上進行測試
    model.evaluate(db_test)                                                     #   最后打印測試數據相關准確率數據
################    模型保存    ##################
    # #   1.    只存入model的參數
    # model_name = 'my_model1.ckpt'
    # model.save_weights(save_path+model_name)
    # #   2.      存入整個model
    # model_name = 'my_model2.h5'
    # model.save(save_path2 + model_name)
    #   3.      存成工業模型
    tf.saved_model.save(model,save_path3)
    print('保存工業模型')
    del model
def my_load():
    #   設置超參
    lr = 0.01  # 學習率
    #   定義模型器和優化器
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    #   優化器
    # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)  # 定義優化器
    model.compile(optimizer=optimizer, loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])  # 定義模型配置
#################   加載模型    ###############
    # #   加載僅有參數的model
    # model_name = 'my_model1.ckpt'
    # model.load_weights(save_path+model_name)
    # print('加載僅有參數的模型')
    # model.evaluate(db_test)
    # #   加載整個model
    # model_name = 'my_model2.h5'
    # model2 = tf.keras.models.load_model(save_path2 + model_name)
    # print('加載整個模型')
    # model2.evaluate(db_test)
    #   加載工業模型
    model2 = tf.saved_model.load(save_path3)
    print('加載工業模型')
    f = model2.signatures['serving_default']
    print(db_test)
    print(f(db_test[0]))
if __name__ == '__main__':
#   數據步驟
    #   加載手寫數字數據
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    #   處理數據
    #   訓練數據
    db = tf.data.Dataset.from_tensor_slices((train_x, train_y))  # 將x,y分成一一對應的元組
    db = db.map(preporocess)  # 執行預處理函數
    db = db.shuffle(60000).batch(20)  # 打亂加分組
    #   測試數據
    db_test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    db_test = db_test.map(preporocess)
    db_test = db_test.shuffle(10000).batch(10000)
#   操作步驟
#     my_create()
    print('_________________----------------------------____________________________')
    my_load()

 

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM