使用tf.keras API 構建神經網絡(基礎)


tf2.0推薦的模型搭建方法是:

  1. 繼承tf.keras.Model類,進行擴展以定義自己的新模型。
  2. 手工編寫模型訓練、評估模型的流程。

    (優點:靈活度高;與其他深度學習框架共通)

 

以CNN處理單通道圖片作為示例:

class CNN(tf.keras.Model):
    def __init__(self): #定義類的構造方法(這里是初始化預定義好的網絡結構)
        super().__init__() #這個類是繼承tf.keras.Model類,因此執行父類的初始化
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32,             # 卷積層神經元(卷積核)數目
            kernel_size=[5, 5],     # 感受野大小
            padding='same',         # padding策略(vaild 或 same)
            activation=tf.nn.relu   # 激活函數
        )
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64,
            kernel_size=[5, 5],
            padding='same',
            activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
        self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)
 
    def call(self, inputs):
        x = self.conv1(inputs)                  # [batch_size, 28, 28, 32]
        x = self.pool1(x)                       # [batch_size, 14, 14, 32]
        x = self.conv2(x)                       # [batch_size, 14, 14, 64]
        x = self.pool2(x)                       # [batch_size, 7, 7, 64]
        x = self.flatten(x)                     # [batch_size, 7 * 7 * 64]
        x = self.dense1(x)                      # [batch_size, 1024]
        x = self.dense2(x)                      # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

下面解釋一下這種網絡構建方法:

  1. 我們定義了一個類CNN來繼承tf.keras.Model類,目的是為了相較於原類能夠有更多自定義的方法,更靈活
  2. 自定義的類中,首先在__init__中定義類的構造方法。構造方法中我們定義了模型中的各個層、以及對各個層的參數賦值(將tf.keras.layers中包裝的‘層’實例化)。(建議定義的順序按照設計的CNN網絡架構的順序排列,便於理解)
  3. 定義一個call方法,一個類只要實現了call方法,這個類的實例就可以用函數一樣的形式進行調用,如CNN_obj = CNN(); CNN_obj()這種形式,並可以向其傳遞參數。
  4. 在我們自定義的類中,call方法要接受訓練數據的特征,特征在定義的層中順序傳遞,最后輸出預測值,用於后續計算。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM