TF-slim 模塊是TensorFLow中比較實用的API之一,是一個用於模型構建、訓練、評估復雜模型的輕量化庫。
最近,在使用TF-slim API編寫了一些項目模型后,發現TF-slim模塊在搭建網絡模型時具有相同的編寫模式。這個編寫模式主要包含四個部分:
- __init__():
- build_model():
- fit():
- predict():
1. __init__():
這部分相當於是一個main()函數,其中包含參數的設置,模型整體的連接等操作。具體來說:
a. 設置參數
由於是類的構造函數,所以需要在其中設置一些模型網絡結構的參數、模型訓練時的參數等等。例如
- 學習率
- batch_size
- 訓練代數
- 各種文件的存放地址
- ...
- 對於網絡結構復雜的模型,還可以將網絡結構的table以列表的形式進行保存。便於后續建立模型時可以循環獲取每層的超參數。
1 self.lr = lr 2 self.batch_size = batch_size 3 self.epoch = epoch 4 self.checkpoint_dir_load = checkpoint_dir 5 self.checkpoint_dir = os.path.join(checkpoint_dir, filename + ".ckpt") 6 self.logdir = logdir 7 self.result_dir = result_dir
b. 設置輸入、輸出的占位符placeholder
由於TF-slim框架仍然采用的是tensorflow的那一套,不像tf.keras可以使用keras.layer.Input(),所以還需要使用占位符。例如
1 self.input_image = tf.placeholder(tf.float32, shape=[None, 6000]) 2 self.input_image_raw = tf.reshape(self.input_image, shape=[-1, 6000, 1]) 3 4 self.input_image_label = tf.placeholder(tf.float32, shape=[None, 1, 10]) 5 self.input_label = tf.reshape(self.input_image_label, shape=[-1, 10])
c. 初始化網絡結構,生成訓練輸出和測試輸出
用於后續損失的計算以及優化器的生成,以及訓練結果和測試結果的調用。
此處會涉及到網絡參數的重用,需要使用tf.variable_scope()來管理參數。
1 with tf.variable_scope("Network_Structure") as scope: 2 self.train_digits = self.build_model(is_trained=True) 3 scope.reuse_variables() 4 self.test_digits = self.build_model(is_trained=False)
d. 損失函數和優化器的聲明
此處損失聲明使用的是 輸出的占位符和訓練的輸出。例如:
1 self.loss = slim.losses.softmax_cross_entropy(logits=self.train_digits, onehot_labels=self.input_label, scope="loss") 2 3 self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss=self.loss)
e. 最終訓練輸出結果和測試輸出結果的計算
由於網絡輸出的結果不一定是最終的結果。對於多分類問題,需要將one_hot編碼的結果顯示為類值;對於回歸問題,輸出結果可能會需要反歸一化。等等..
如下述代碼,多分類問題的one_hot轉化為類標簽,並進行准確率的計算。
1 # result and accuracy of test 2 self.predicts = tf.math.argmax(self.test_digits, 1) # 將one_hot轉化為類標簽 3 self.test_correction = tf.equal(self.predicts, tf.math.argmax(self.input_label, 1)) 4 self.accuracy = tf.reduce_mean(tf.cast(self.test_correction, "float")) 5 tf.summary.scalar("test_accuracy", self.accuracy) 6 7 # result and accuracy of train 8 self.train_result = tf.math.argmax(self.train_digits, 1) 9 self.train_correlation = tf.equal(self.train_result, tf.math.argmax(self.input_label, 1)) 10 self.train_accuracy = tf.reduce_mean(tf.cast(self.train_correlation, "float")) 11 tf.summary.scalar("train_accuracy", self.accuracy)
2. build_model():【可以是別的名字】
這部分是為了使用tf-slim搭建網絡模型結構。有些模型可能一個函數實現不了,需要多個函數。例如具有共享層的Siamese Network,在共享層后還有其他層。
這一部分也實現了如同tf.keras搭建的模型"樂高式"堆疊,不需要手動為各層生成權重、偏執等參數。也是代碼瘦身的重要環節。
1 with slim.arg_scope([slim.conv1d], padding="SAME", stride=2, activation_fn=tf.nn.relu, 2 weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 3 weights_regularizer=slim.l2_regularizer(0.005) 4 ): 5 net = slim.conv1d(self.input_image_raw, num_outputs=16, kernel_size=8, padding="VALID", scope='conv_1') 6 tf.summary.histogram("conv_1", net) 7 net = slim.conv1d(net, num_outputs=16, kernel_size=8, scope='conv_2') 8 tf.summary.histogram("conv_2", net) 9 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_3") 10 net = def_max_pool(net) 11 # net = slim.nn.max_pool1d(net, ksize=2, strides=None, padding="VALID", data_format="NWC", name="max_pool_3") 12 tf.summary.histogram("max_pool_3", net) 13 net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_4") 14 tf.summary.histogram("conv_4", net) 15 net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_5") 16 tf.summary.histogram("conv_5", net) 17 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_6") 18 net = def_max_pool(net) 19 # net = slim.nn.max_pool1d(net, ksize=2, strides=1, padding="VALID", name="max_pool_6") 20 tf.summary.histogram("max_pool_6", net) 21 net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_7") 22 tf.summary.histogram("conv_7", net) 23 net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_8") 24 tf.summary.histogram("conv_8", net) 25 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_9") 26 net = def_max_pool(net) 27 # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_9") 28 tf.summary.histogram("max_pool_9", net) 29 net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_10") 30 tf.summary.histogram("conv_10", net) 31 net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_11") 32 tf.summary.histogram("conv_11", net) 33 def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_12") 34 net = def_max_pool(net) 35 # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_12") 36 tf.summary.histogram("max_pool_12", net) 37 net = tf.reduce_mean(net, axis=1, name="global_max_pool_13") # 起全局平均池化的作用 38 tf.summary.histogram("global_max_pool_13", net) 39 net = slim.dropout(net, keep_prob=0.5, scope="dropout") 40 tf.summary.histogram("dropout", net) 41 digits = slim.fully_connected(net, num_outputs=num_class, activation_fn=tf.nn.softmax, scope="fully_connected_14") 42 tf.summary.histogram("fully_connected_14", digits) 43 return digits
3. fit():
看名字就知道這一部分需要完成的是訓練部分的代碼。
這一部分需要包含會話的啟動、模型保存器的初始化、循環迭代、batch設置、數據集輸入、輸出數據獲取、喂到網絡中、保存模型、會話關閉等操作。如下述代碼
1 sess = tf.Session() # 啟動會話 2 3 merge_summary_op = tf.summary.merge_all() 4 summary_writer = tf.summary.FileWriter(self.logdir, sess.graph) 5 6 saver = tf.train.Saver(max_to_keep=1) # 生成保存器 7 sess.run(tf.global_variables_initializer()) # 變量激活 8 9 for step in range(self.epoch): # 迭代 10 print("Epoch:%d"%step) 11 avg_cost = 0 12 acc = 0 13 total_batch = int(input_x.shape[0]/self.batch_size) # 划分batch 14 for batch_num in range(total_batch): # batch迭代 15 # 獲取數據 16 batch_xs = input_x[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :] 17 batch_ys = input_y[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :] 18 batch_ys = sess.run(tf.one_hot(batch_ys, depth=10)) 19 # 喂到損失 優化器等等 20 _, loss, acc = sess.run([self.optimizer, self.loss, self.train_accuracy], 21 feed_dict={self.input_image: batch_xs, 22 self.input_image_label: batch_ys}) 23 avg_cost += loss / total_batch 24 acc += acc /total_batch 25 26 summary_str = sess.run(merge_summary_op, feed_dict={self.input_image: batch_xs, 27 self.input_image_label: batch_ys}) 28 summary_writer.add_summary(summary_str, global_step=step) 29 print("Epoch:%d, batch: %d, avg_cost: %g, accuracy: %g" % (step, batch_num, avg_cost, acc)) 30 # 保存模型 31 saver.save(sess, self.checkpoint_dir, global_step=step) 32 sess.close() # 會話關閉
4. predict():
從函數名可以知道這一部分是實現預測部分的代碼。其相對於訓練的過程要更簡單。主要包括會話的啟動、保存器的生成、權重的導入(模型的恢復)、預測、關閉會話。如下述代碼
1 sess = tf.Session() # 會話的啟動 2 3 saver = tf.train.Saver() # 保存器的生成 4 5 module_file = tf.train.latest_checkpoint(self.checkpoint_dir_load) 6 saver.restore(sess, module_file) # 模型的恢復 7 8 input_y = sess.run(tf.one_hot(input_y, depth=10)) # 獲取輸出 9 # 獲取預測結果和預測精度 10 predicts, acc_test = sess.run([self.predicts, self.accuracy], feed_dict={self.input_image: input_x, 11 # 關閉會話 self.input_image_label: input_y}) 12 sess.close() 13 # print("test_accuracy: %f" %acc_test) 14 return predicts, acc_test
上述四步完成后,便可以編寫一個main函數來調用這個類,實現需要的功能。.fit()和.predict()主要是在main()函數來調用。