【tensorflow】神經網絡:斷點續訓


斷點續訓,即在一次訓練結束后,可以先將得到的最優訓練參數保存起來,待到下次訓練時,直接讀取最優參數,在此基礎上繼續訓練。

 

讀取模型參數:

存儲模型參數的文件格式為 ckpt(checkpoint)。

生成 ckpt 文件時,會同步生成索引表,所以可通過判斷是否存在索引表來判斷是否存在模型參數。

# 模型參數保存路徑
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)

 

保存模型參數:

# 定義回調函數,在模型訓練時,回調函數會被執行,完成保留參數操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(   # 文件保存路徑
  filepath=checkpoint_save_path,   # 是否只保留模型參數
  save_weights_only=True,   # 是否只保留最優結果
  save_best_only=True ) # 執行訓練過程,保存新的訓練參數
history = model.fit(x_train, y_train,             batch_size=32, epochs=5,             validation_data=(x_test, y_test),             validation_freq=1,             callbacks=[cp_callback])

 

 

代碼:

import tensorflow as tf import os # 讀取輸入特征和標簽
mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() # 數據歸一化,減小計算量,方便神經網絡吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 聲明網絡結構
model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax") ]) # 配置訓練方法
model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) # 如果存在參數文件,直接讀取,在此基礎上繼續訓練
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  # 模型參數保存路徑
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path) # 定義回調函數,在模型訓練時,完成保留參數操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True) # 執行訓練過程,保存新的訓練參數
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) # 打印網絡結構和參數
model.summary()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM