Tensorflow之單變量線性回歸問題的解決方法


跟着網易雲課堂上面的免費公開課深度學習應用開發Tensorflow實踐學習,學到線性回歸這里感覺有很多需要總結,梳理記錄下階段性學習內容。

題目:通過生成人工數據集合,基於TensorFlow實現y=2*x+1線性回歸

使用Tensorflow進行算法設計與訓練的核心步驟

(1)准備數據

(2)構建模型

(3)訓練模型

(4)進行預測

#線性回歸問題

#******************一、准備數據:**********************

#生成人工數據集

# 在Jupter中,使用matplotlib顯示圖像需要設置為inline模式,否則不會顯示圖像
%matplotlib inline
 
import matplotlib.pyplot as plt #載入matplotlib,用於繪圖
import numpy as np #載入numpy,numpy是Python進行科學計算時的基礎模塊
import tensorflow as tf #載入Tensorflow
 
#設置隨機種子。訓練之后結果隨機,隨機種子起到固定初始值的作用,為了訓練之后得到一樣的結果
np.random.seed(5)
#直接采用np生成等差數列的方法,生成100個點,每個點的取值在-1~1之間
x_data = np.linspace(-1,1,100)

# y = 2x +1 + 噪聲,其中,噪聲的維度與x_data一致
y_data = 2 * x_data + 1.0 + np.random.randn(*x_data.shape) * 0.4

#***********************二、構建線性模型*************************

#定義訓練數據的占位符,x是特征,y是標簽值
x = tf.placeholder("float",name= "x")
y = tf.placeholder("float",name = "y")

#定義模型函數
def model(x,w,b):
    return tf.multiply(x,w) + b

#定義模型結構
#Tensorflow變量的聲明函數是tf.Variable。tf.Variable的作用是保存和更新函數,變量的初始值可以是隨機數、常數,或是通過其他變量的初始值計算得到
#構建線性函數的斜率,變量w
w = tf.Variable(1.0,name = "w0")
#構建線性函數的截距,變量b
b = tf.Variable(0.0,name = "b0")

#pred是預測值,前向計算
pred = model(x,w,b)

#************************三、訓練模型*******************************
#設置訓練參數
#迭代次數(訓練輪數)
train_epochs = 10

#學習率
learning_rate = 0.05

#定義優化器、最小損失函數

#定義損失函數,損失函數用於描述預測值與真實值之間的差別,從而指導模型收斂方向。常見損失函數:均方差、交叉熵
#采用均方差作為損失函數
loss_function = tf.reduce_mean(tf.square(y-pred))

#定義優化器
#定義優化器Optimizer,初始化一個GradientDescentOptimizer(梯度下降優化器)
#設置學習率和優化目標:最小化損失
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

#創建會話
#聲明會話
sess = tf.Session()
#變量初始化
#在真正執行計算前,需要將所有變量初始化。通過tf.global_variables_initializer函數可實現對所有變量的初始化
init = tf.global_variables_initializer()
sess.run(init)

#迭代訓練
#模型訓練階段,設置迭代輪次,每次通過將樣本逐個輸入模型,進行梯度下降優化操作。每輪迭代后,繪制出模型曲線
#開始訓練,輪次為epoch,采用SGD隨機梯度下降優化方法
for epoch in range(train_epochs):
    for xs,ys in zip(x_data,y_data):
        _,loss = sess.run([optimizer,loss_function],feed_dict={x:xs,y:ys})
    b0temp = b.eval(session=sess)
    w0temp = w.eval(session=sess)
    plt.plot(x_data,w0temp * x_data + b0temp)  #畫圖
    
#結果查看。當訓練完成后,打印查看參數。數據每次運行都可能會有所不同
print("w:",sess.run(w))     #w的值應該在2附近
print("b:",sess.run(b))     #b的值應該在1附近

#結果可視化
plt.scatter(x_data,y_data,label='Original data')
plt.plot(x_data,x_data*sess.run(w) + sess.run(b),label='Fitted line',color='r',linewidth=3)
plt.legend(loc=2) #通過參數loc指定圖例位置

#*********************四、利用學習到的模型進行預測*******************

x_test = 3.21

predict = sess.run(pred,feed_dict={x:x_test})
print("預測值: %f"%predict)

target = 2 * x_test + 1.0
print("目標值: %f"%target)

  題目二:通過生成人工數據集合,基於TensorFlow實現y=3.1234*x+2.98線性回歸

# 在Jupter中,使用matplotlib顯示圖像需要設置為inline模式,否則不會顯示圖像
%matplotlib inline
 
import matplotlib.pyplot as plt #載入matplotlib
import numpy as np #載入numpy
import tensorflow as tf #載入Tensorflow
 
#設置隨機種子
np.random.seed(5)
#直接采用np生成等差數列的方法,生成100個點,每個點的取值在-1~1之間
x_data = np.linspace(-1,1,100)
# y = 3.1234x +2.98 + 噪聲, 其中, 噪聲的唯度與x_data一致
y_data = 3.1234*x_data + 2.98 + np.random.randn(*x_data.shape)*0.4
x = tf.placeholder("float",name = "x")
y = tf.placeholder("float",name = "y")
 
def model(x,w,b):
    return tf.multiply(x,w)+b
# 構建線性函數的斜率, 變量w
w = tf.Variable(1.0,name="w")
# 構建線性函數的截距,變量b
b = tf.Variable(0.0, name="b0")
#pred是預測值,前向計算
pred = model(x,w,b)
 
# 迭代次數(訓練輪數)
train_epochs = 10
# 學習率
learning_rate = 0.05
# 采用均方差作為損失函數
loss_function = tf.reduce_mean(tf.square(y-pred))
# 梯度下降優化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
# 開始訓練,輪數為 epoch,采用SGD隨機梯度下降優化方法
#zip為組裝,x,y都為一維數組. zip 把x,y組裝起來也為一維數組,每個單元為(x,y)
 
for epoch in range(train_epochs):
    for xs,ys in zip(x_data, y_data):
        #優化器給了一個下划線,loss_function 給了loss
        _, loss=sess.run([optimizer,loss_function],feed_dict={x: xs, y:ys})
 
plt.scatter(x_data,y_data,label='Original data')
plt.plot(x_data,x_data*sess.run(w)+sess.run(b),\
        label='Fitted line',color='r',linewidth=3)
plt.legend(loc=2)#通過參數loc指定圖例位置
 
print("w: ", sess.run(w)) #w的值應該在3.1234附近
print("b: ",sess.run(b)) #b的值應該在2.98附近

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM