TensorFlow slim(二) 使用TF-slim編程模板(一)


  TF-slim 模塊是TensorFLow中比較實用的API之一,是一個用於模型構建、訓練、評估復雜模型的輕量化庫

  最近,在使用TF-slim API編寫了一些項目模型后,發現TF-slim模塊在搭建網絡模型時具有相同的編寫模式。這個編寫模式主要包含四個部分:

  • __init__():
  • build_model():
  • fit():
  • predict():

1. __init__():

  這部分相當於是一個main()函數,其中包含參數的設置,模型整體的連接等操作。具體來說:

  a. 設置參數

  由於是類的構造函數,所以需要在其中設置一些模型網絡結構的參數、模型訓練時的參數等等。例如

  • 學習率
  • batch_size
  • 訓練代數
  • 各種文件的存放地址
  • ...
  • 對於網絡結構復雜的模型,還可以將網絡結構的table以列表的形式進行保存。便於后續建立模型時可以循環獲取每層的超參數。
1 self.lr = lr
2 self.batch_size = batch_size
3 self.epoch = epoch
4 self.checkpoint_dir_load = checkpoint_dir
5 self.checkpoint_dir = os.path.join(checkpoint_dir, filename + ".ckpt")
6 self.logdir = logdir
7 self.result_dir = result_dir

  b. 設置輸入、輸出的占位符placeholder

  由於TF-slim框架仍然采用的是tensorflow的那一套,不像tf.keras可以使用keras.layer.Input(),所以還需要使用占位符。例如

1 self.input_image = tf.placeholder(tf.float32, shape=[None, 6000])
2 self.input_image_raw = tf.reshape(self.input_image, shape=[-1, 6000, 1])
3 
4 self.input_image_label = tf.placeholder(tf.float32, shape=[None, 1, 10])
5 self.input_label = tf.reshape(self.input_image_label, shape=[-1, 10])

  c. 初始化網絡結構,生成訓練輸出和測試輸出

  用於后續損失的計算以及優化器的生成,以及訓練結果和測試結果的調用。

  此處會涉及到網絡參數的重用,需要使用tf.variable_scope()來管理參數。

1 with tf.variable_scope("Network_Structure") as scope:
2     self.train_digits = self.build_model(is_trained=True)
3     scope.reuse_variables()
4     self.test_digits = self.build_model(is_trained=False)

  d. 損失函數和優化器的聲明

  此處損失聲明使用的是 輸出的占位符和訓練的輸出。例如:

1 self.loss = slim.losses.softmax_cross_entropy(logits=self.train_digits, onehot_labels=self.input_label, scope="loss")
2 
3 self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss=self.loss)

  e. 最終訓練輸出結果和測試輸出結果的計算

  由於網絡輸出的結果不一定是最終的結果。對於多分類問題,需要將one_hot編碼的結果顯示為類值;對於回歸問題,輸出結果可能會需要反歸一化。等等..

  如下述代碼,多分類問題的one_hot轉化為類標簽,並進行准確率的計算。

 1 # result and accuracy of test
 2 self.predicts = tf.math.argmax(self.test_digits, 1)   # 將one_hot轉化為類標簽
 3 self.test_correction = tf.equal(self.predicts, tf.math.argmax(self.input_label, 1))
 4 self.accuracy = tf.reduce_mean(tf.cast(self.test_correction, "float"))
 5 tf.summary.scalar("test_accuracy", self.accuracy)
 6 
 7 # result and accuracy of train
 8 self.train_result = tf.math.argmax(self.train_digits, 1)
 9 self.train_correlation = tf.equal(self.train_result, tf.math.argmax(self.input_label, 1))
10 self.train_accuracy = tf.reduce_mean(tf.cast(self.train_correlation, "float"))
11 tf.summary.scalar("train_accuracy", self.accuracy)

2. build_model():【可以是別的名字】

  這部分是為了使用tf-slim搭建網絡模型結構。有些模型可能一個函數實現不了,需要多個函數。例如具有共享層的Siamese Network,在共享層后還有其他層。

  這一部分也實現了如同tf.keras搭建的模型"樂高式"堆疊,不需要手動為各層生成權重、偏執等參數。也是代碼瘦身的重要環節。

 1 with slim.arg_scope([slim.conv1d], padding="SAME", stride=2, activation_fn=tf.nn.relu,
 2                     weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
 3                     weights_regularizer=slim.l2_regularizer(0.005)
 4                     ):
 5     net = slim.conv1d(self.input_image_raw, num_outputs=16, kernel_size=8, padding="VALID", scope='conv_1')
 6     tf.summary.histogram("conv_1", net)
 7     net = slim.conv1d(net, num_outputs=16, kernel_size=8, scope='conv_2')
 8     tf.summary.histogram("conv_2", net)
 9     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_3")
10     net = def_max_pool(net)
11     # net = slim.nn.max_pool1d(net, ksize=2, strides=None, padding="VALID", data_format="NWC", name="max_pool_3")
12     tf.summary.histogram("max_pool_3", net)
13     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_4")
14     tf.summary.histogram("conv_4", net)
15     net = slim.conv1d(net, num_outputs=64, kernel_size=4, scope="conv_5")
16     tf.summary.histogram("conv_5", net)
17     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_6")
18     net = def_max_pool(net)
19     # net = slim.nn.max_pool1d(net, ksize=2, strides=1, padding="VALID", name="max_pool_6")
20     tf.summary.histogram("max_pool_6", net)
21     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_7")
22     tf.summary.histogram("conv_7", net)
23     net = slim.conv1d(net, num_outputs=256, kernel_size=4, scope="conv_8")
24     tf.summary.histogram("conv_8", net)
25     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_9")
26     net = def_max_pool(net)
27     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_9")
28     tf.summary.histogram("max_pool_9", net)
29     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_10")
30     tf.summary.histogram("conv_10", net)
31     net = slim.conv1d(net, num_outputs=512, kernel_size=2, stride=1, scope="conv_11")
32     tf.summary.histogram("conv_11", net)
33     def_max_pool = tf.layers.MaxPooling1D(pool_size=2, strides=2, padding="VALID", name="max_pool_12")
34     net = def_max_pool(net)
35     # net = slim.nn.max_pool1d(net, ksize=1, strides=1, padding="VALID", name="max_pool_12")
36     tf.summary.histogram("max_pool_12", net)
37     net = tf.reduce_mean(net, axis=1, name="global_max_pool_13")   # 起全局平均池化的作用
38     tf.summary.histogram("global_max_pool_13", net)
39     net = slim.dropout(net, keep_prob=0.5, scope="dropout")
40     tf.summary.histogram("dropout", net)
41     digits = slim.fully_connected(net, num_outputs=num_class, activation_fn=tf.nn.softmax, scope="fully_connected_14")
42     tf.summary.histogram("fully_connected_14", digits)
43 return digits

3. fit():

  看名字就知道這一部分需要完成的是訓練部分的代碼

  這一部分需要包含會話的啟動、模型保存器的初始化、循環迭代、batch設置、數據集輸入、輸出數據獲取、喂到網絡中、保存模型、會話關閉等操作。如下述代碼

 1 sess = tf.Session()  # 啟動會話
 2 
 3 merge_summary_op = tf.summary.merge_all()
 4 summary_writer = tf.summary.FileWriter(self.logdir, sess.graph)
 5 
 6 saver = tf.train.Saver(max_to_keep=1)  # 生成保存器
 7 sess.run(tf.global_variables_initializer())   # 變量激活
 8 
 9 for step in range(self.epoch):    # 迭代
10     print("Epoch:%d"%step)
11     avg_cost = 0
12     acc = 0
13     total_batch = int(input_x.shape[0]/self.batch_size)   # 划分batch
14     for batch_num in range(total_batch):   # batch迭代
15         # 獲取數據
16         batch_xs = input_x[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
17         batch_ys = input_y[batch_num*self.batch_size:(batch_num+1)*self.batch_size, :]
18         batch_ys = sess.run(tf.one_hot(batch_ys, depth=10))
19         # 喂到損失 優化器等等
20         _, loss, acc = sess.run([self.optimizer, self.loss, self.train_accuracy],
21                                                         feed_dict={self.input_image: batch_xs,
22                                                          self.input_image_label: batch_ys})
23         avg_cost += loss / total_batch
24         acc += acc /total_batch
25 
26         summary_str = sess.run(merge_summary_op, feed_dict={self.input_image: batch_xs,
27                                                             self.input_image_label: batch_ys})
28         summary_writer.add_summary(summary_str, global_step=step)
29         print("Epoch:%d, batch: %d, avg_cost: %g, accuracy: %g" % (step, batch_num, avg_cost, acc))
30     # 保存模型
31     saver.save(sess, self.checkpoint_dir, global_step=step)
32 sess.close()   # 會話關閉

4. predict():

  從函數名可以知道這一部分是實現預測部分的代碼。其相對於訓練的過程要更簡單。主要包括會話的啟動、保存器的生成、權重的導入(模型的恢復)、預測、關閉會話。如下述代碼

 1 sess = tf.Session()   # 會話的啟動
 2 
 3 saver = tf.train.Saver()  # 保存器的生成
 4 
 5 module_file = tf.train.latest_checkpoint(self.checkpoint_dir_load)
 6 saver.restore(sess, module_file)    # 模型的恢復
 7 
 8 input_y = sess.run(tf.one_hot(input_y, depth=10))  # 獲取輸出
 9 # 獲取預測結果和預測精度
10 predicts, acc_test = sess.run([self.predicts, self.accuracy], feed_dict={self.input_image: input_x,
11 # 關閉會話                                                                            self.input_image_label: input_y})
12 sess.close()
13 # print("test_accuracy: %f" %acc_test)
14 return predicts, acc_test

  上述四步完成后,便可以編寫一個main函數來調用這個類,實現需要的功能。.fit()和.predict()主要是在main()函數來調用。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM