Train a model with Keras while displaying the loss/acc curves in real time (the important thing, said three times: real time! real time! real time!), and export the loss/acc values in real time as well. The export simply appends the loss/acc values to a text file every epoch, so other modules (for example a frontend) can read that file directly; a small reader sketch is included after the training code. The post also touches on the plt plotting calls involved.
PS: the code below was adapted from a snippet found online; if this infringes on your work, please contact me!
Here is the code:
from keras import Sequential
from keras.callbacks import Callback
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPooling2D, Dense, Dropout, Flatten
import keras
import os
import pylab as pl
from IPython import display


# Callback that plots the loss/acc curves in real time and appends the
# per-epoch loss/acc values to a text file.
class DrawCallback(Callback):
    def __init__(self, runtime_plot=True):
        # Initialise state
        self.init_loss = None
        self.init_val_loss = None
        self.init_acc = None
        self.init_val_acc = None
        self.runtime_plot = runtime_plot
        self.xdata = []    # epoch indices
        self.ydata = []    # loss
        self.ydata2 = []   # val_loss
        self.ydata3 = []   # acc
        self.ydata4 = []   # val_acc

    def _plot(self, epoch=None):
        epochs = self.params.get("epochs")

        pl.subplot(121)                        # 1 row, 2 columns, panel 1: loss
        pl.ylim(0, self.init_loss * 1.2)       # clamp the y axis
        pl.xlim(0, epochs)
        # xdata/ydata are growing 1-D lists; colour, line style and legend label are set here
        pl.plot(self.xdata, self.ydata, 'r', label='loss')
        pl.plot(self.xdata, self.ydata2, 'b--', label='val_loss')
        pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))  # axis label updates every epoch
        pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
        pl.legend()                            # without this the labels above are not shown
        pl.title('loss')

        pl.subplot(122)                        # panel 2: accuracy
        pl.ylim(0, 1.2)
        pl.xlim(0, epochs)
        pl.plot(self.xdata, self.ydata3, 'r', label='acc')
        pl.plot(self.xdata, self.ydata4, 'b--', label='val_acc')
        pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
        pl.ylabel('Acc {:.4f}'.format(self.ydata3[-1]))
        pl.legend()
        pl.title('acc')

    def _runtime_plot(self, epoch):
        # Redraw: clear the previous output, show the current figure, then reset it
        self._plot(epoch)
        display.clear_output(wait=True)
        display.display(pl.gcf())
        pl.gcf().clear()

    def plot(self):
        self._plot()
        pl.show()                              # show the final figure once

    def on_epoch_end(self, epoch, logs=None):
        # Collect this epoch's metrics and update xdata/ydata
        logs = logs or {}
        epochs = self.params.get("epochs")
        loss = logs.get("loss")
        val_loss = logs.get("val_loss")
        acc = logs.get("acc", logs.get("accuracy"))              # key name differs across Keras versions
        val_acc = logs.get("val_acc", logs.get("val_accuracy"))

        # Convert to strings for the text file, keeping roughly 4 decimal places
        loss_str = str(loss)[0:6]
        val_loss_str = str(val_loss)[0:6]
        acc_str = str(acc)[0:6]
        val_acc_str = str(val_acc)[0:6]

        # Append ('a') one line per epoch, so the line number equals the epoch index
        os.makedirs('logs_r', exist_ok=True)
        with open('logs_r/record.txt', 'a') as f:
            f.write('epochs:{}_loss:{}_val_loss:{}_acc:{}_val_acc:{}\n'.format(
                epochs, loss_str, val_loss_str, acc_str, val_acc_str))

        if self.init_loss is None:
            self.init_loss = loss
            self.init_val_loss = val_loss
        self.xdata.append(epoch)
        self.ydata.append(loss)
        self.ydata2.append(val_loss)
        self.ydata3.append(acc)
        self.ydata4.append(val_acc)
        if self.runtime_plot:
            self._runtime_plot(epoch)


# Build the data, the model and the training call that Keras needs
def viz_keras_fit(runtime_plot=False):
    d = DrawCallback(runtime_plot=runtime_plot)   # instantiate the callback

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(-1, 28, 28, 1)
    x_test = x_test.reshape(-1, 28, 28, 1)
    input_shape = (28, 28, 1)
    x_train = x_train / 255
    x_test = x_test / 255
    y_train = keras.utils.to_categorical(y_train, 10)
    y_test = keras.utils.to_categorical(y_test, 10)

    # Use a small subset of MNIST to keep the demo cheap
    x_train = x_train[0:600, :, :, :]
    x_test = x_test[0:100, :, :, :]
    y_train = y_train[0:600, :]
    y_test = y_test[0:100, :]

    model = Sequential()                          # instantiate a model
    # Stack the layers
    model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                     input_shape=input_shape, name='conv1'))
    model.add(Conv2D(64, (3, 3), activation='relu', name='conv2'))
    model.add(MaxPooling2D(pool_size=(2, 2), name='pool2'))
    model.add(Dropout(0.25, name='dropout1'))
    model.add(Flatten(name='flat1'))
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5, name='dropout2'))
    model.add(Dense(10, activation='softmax', name='output'))

    # Compile: loss function / optimizer / metrics to monitor
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

    # Train
    model.fit(x=x_train,
              y=y_train,
              epochs=30,
              verbose=0,                          # set to 1 to also print the usual progress bar
              validation_data=(x_test, y_test),   # without validation data there is no val_loss/val_acc
              callbacks=[d])                      # attach the callback
    return d
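Because the callback appends exactly one line per epoch to logs_r/record.txt, any other module (for example a frontend that polls the file) can parse it directly. Below is a minimal reader sketch, not part of the original code: it assumes the exact line format written by on_epoch_end above, and the function name read_records is made up for illustration.

import re

def read_records(path='logs_r/record.txt'):
    # Each line looks like: epochs:30_loss:0.1234_val_loss:0.2345_acc:0.9876_val_acc:0.9765
    # The 1-based line number is the epoch index, as noted in the callback.
    pattern = re.compile(r'epochs:(?P<epochs>\S+?)'
                         r'_loss:(?P<loss>\S+?)'
                         r'_val_loss:(?P<val_loss>\S+?)'
                         r'_acc:(?P<acc>\S+?)'
                         r'_val_acc:(?P<val_acc>\S+)')
    records = []
    with open(path) as f:
        for epoch, line in enumerate(f, start=1):
            m = pattern.match(line.strip())
            if m is None:
                continue                      # skip malformed or partially written lines
            records.append({'epoch': epoch,
                            'loss': float(m.group('loss')),
                            'val_loss': float(m.group('val_loss')),
                            'acc': float(m.group('acc')),
                            'val_acc': float(m.group('val_acc'))})
    return records

records = read_records()
if records:
    print(records[-1])                        # metrics of the most recent epoch

Since the writer opens, appends and closes the file once per epoch, a reader polling the file should at worst see one partially written last line, which the match check above simply skips.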
Finally, run it:
viz_keras_fit(runtime_plot=True)  # call the function
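The live redraw relies on IPython's display.clear_output, so it is most useful inside a Jupyter notebook. If you only want a single static figure after training, you can skip the per-epoch redraw and call the callback's plot method instead:

d = viz_keras_fit(runtime_plot=False)  # train without redrawing every epoch
d.plot()                               # draw the final loss/acc curves once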
The resulting display: