這個代碼寫的好。
模塊主要用到了:os(主要作用文件目錄和路徑)、scipy(圖片讀取、保存、縮放,需要依賴PIL)、numpy、tensorflow、time(計算代碼塊耗時)。
首先判斷是訓練還是測試(即生產),如果是測試的話是測試video還是圖片,測試的圖片是否是彩色的;如果是訓練的話是否要打印輸出詳細的信息(即debug)。debug可以作為切分代碼塊的標記使用。
訓練模型在代碼上可以分為訓練前、訓練中、訓練后。
訓練前:數據的獲取、計算圖的構建(包括網絡結構、損失函數)、計算圖參數的初始化、模型保存對象
- 計算圖網絡結構的重點是卷積核的定義(也就是權重的定義【kernel+bias】)kernel = tf.Variable(tf.truncated_normal(shape, stddev=WEIGHT_INIT_STDDEV), name='kernel'),定義好卷積核之后就可以開始卷積操作了out = tf.nn.conv2d(x_padded, kernel, strides=[1, 1, 1, 1], padding='VALID')。
- 計算圖中損失函數的定義是難點 train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
訓練中:主要是對EPOCHS-N_BATCHES的迭代,還有驗證
- 先根據損失函數調整網絡參數--->sess.run(train_op, feed_dict={original: original_batch})。
- 然后再計算、存儲和打印各部分的損失-->_ssim_loss, _loss, _p_loss = sess.run([ssim_loss, loss, pixel_loss], feed_dict={original: original_batch})
- 因為每迭代一次就會更新一次權重,事實上每更新一次權重就應該利用更新好的權重狀態對驗證數據集進行一次驗證