tensorflow2實現線性回歸例子


%tensorflow_version 2.x
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import initializers as init
from tensorflow import losses
from tensorflow.keras import optimizers
from tensorflow import data as tfdata

#1.生成數據
num_inputs = 2#數據有兩個特征
num_examples = 1000#共有1000條數據
true_w = [2, -3.4]#兩個特征的權重
true_b = 4.2#偏置
features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)#隨機生成一個1000*2的矩陣,每行代表一條數據
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b#計算y值
labels += tf.random.normal(labels.shape, stddev=0.01)#加上一個偏差

#2.組合數據
batch_size = 10
# 將訓練數據的特征和標簽組合
dataset = tfdata.Dataset.from_tensor_slices((features, labels))#按第0維進行切分,和標簽組合
# 隨機讀取小批量
dataset = dataset.shuffle(buffer_size=num_examples)#隨機打亂1000
dataset = dataset.batch(batch_size)
data_iter = iter(dataset)#生成一個迭代器

輸出一個batch看一下:

這里是其中一個batch,它包含10條原數據。

 

 

model = keras.Sequential()#定義模型
model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))#定義網絡層

loss = losses.MeanSquaredError()#定義損失
trainer = optimizers.SGD(learning_rate=0.03)#定義優化器為隨機梯度下降

loss_history = []
num_epochs = 3
for epoch in range(1, num_epochs + 1):#全體數據循環三次
    for (batch, (X, y)) in enumerate(dataset):#對每一個batch循環
        with tf.GradientTape() as tape:#定義梯度
            l = loss(model(X, training=True), y)
        loss_history.append(l.numpy().mean())#記錄該batch的損失
        grads = tape.gradient(l, model.trainable_variables)#tape.gradient找到變量的梯度
        trainer.apply_gradients(zip(grads, model.trainable_variables))#更新權重
    
    l = loss(model(features), labels)#遍歷完一次全體數據后的損失
    print('epoch %d, loss: %f' % (epoch, l))

因為我們要求循環所有數據3次,而每一次循環都是小批量循環,每個小批量里都有10條數據,所以首先寫出兩個for循環,最里層的循環是每次循環10條數據。

我們通過調用tensorflow.GradientTape記錄動態圖梯度,之前定義的損失函數是均方誤差,需要真實值和模型值,於是把model(X)和y輸入loss里。

我們可以記錄每個batch的損失,添加到loss_history中。

通過 model.trainable_variables 找到需要更新的變量,並用 trainer.apply_gradients 更新權重,完成一步訓練。

查看訓練出來的參數和原參數的對比:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM