tensorflow2.0——過擬合優化regularization(簡化參數結構,添加參數代價變量)


參數過多會導致模型過於復雜而出現過擬合現象,通過在loss函數添加關於參數個數的代價變量,限制參數個數,來達到減小過擬合的目的

 

以下是loss公式:

 

 

 代碼多了一個kernel_regularizer參數

  

 

 

 

import tensorflow as tf

def preporocess(x,y):
    x = tf.cast(x,dtype=tf.float32) / 255
    x = tf.reshape(x,(-1,28 *28))                   #   鋪平
    x = tf.squeeze(x,axis=0)
    # print('里面x.shape:',x.shape)
    y = tf.cast(y,dtype=tf.int32)
    y = tf.one_hot(y,depth=10)
    return x,y

def main():
    #   加載手寫數字數據
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    #   處理數據
    #   訓練數據
    db = tf.data.Dataset.from_tensor_slices((train_x, train_y))  # 將x,y分成一一對應的元組
    db = db.map(preporocess)  # 執行預處理函數
    db = db.shuffle(60000).batch(2000)  # 打亂加分組
    #   測試數據
    db_test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    db_test = db_test.map(preporocess)
    db_test = db_test.shuffle(10000).batch(10000)
    #   設置超參
    iter_num = 2000  # 迭代次數
    lr = 0.01  # 學習率
    #   定義模型器和優化器
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),       #   kernel_regularizer是loss上加了關於參數的損失變量
        tf.keras.layers.Dense(128, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(64, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(32, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(10)
    ])
    #   優化器
    # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)              #   定義優化器
    model.compile(optimizer= optimizer,loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])       #   定義模型配置
    model.fit(db,epochs=30,validation_data=db,validation_freq=2)          #  運行模型,參數validation_data是指在哪個測試集上進行測試
    model.evaluate(db_test)                                                     #   最后打印測試數據相關准確率數據

if __name__ == '__main__':
    main()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM