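Note: this exercise comes from Andrew Ng's Machine Learning course.
The problem, translated:
You own a restaurant and want to open branches in other cities, so you have collected some data (listed at the end of this post) on the population of various cities and the profit a branch in each city brings in. The first column is the city's population; the second column is the profit from opening a store in that city.
Now, assuming θ0 and θ1 are both 0, compute the cost function J(θ).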
First, the linear regression hypothesis for this problem is:
$$h_\theta(x) = \theta_0 + \theta_1 x$$
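As a quick illustration, here is a minimal Python sketch of the hypothesis evaluated for a single population value. The function name `hypothesis` is hypothetical and is not part of the exercise code.

```python
def hypothesis(theta0, theta1, x):
    """Hypothetical helper: h_theta(x) = theta0 + theta1 * x for one population value x."""
    return theta0 + theta1 * x

# With theta0 = theta1 = 0 (the setting in this exercise), every prediction is 0
print(hypothesis(0.0, 0.0, 6.1101))  # 0.0
# With other parameters the line is no longer flat, e.g. 1 + 2 * 3
print(hypothesis(1.0, 2.0, 3.0))     # 7.0
```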
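Put simply, in this problem θ0 and θ1 are both 0, so the model predicts 0 for every city, and we are computing the cost of that all-zero prediction.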
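Next, we define the loss. Informally, the loss measures the difference between the predicted value and the actual value. Linear regression uses the squared difference for each example, averages it over all examples, and includes an extra factor of 1/2.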
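This gives the definition of the cost function J(θ):

$$J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)}) - y^{(i)}\right)^2$$

where m is the total number of training examples, i.e. the number of rows in the data.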
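To make the formula concrete, this sketch computes J(θ) term by term with a plain Python loop, before any matrix tricks. The name `cost_loop` is hypothetical, and the three-point data set is a toy example, not ex1data1.txt.

```python
def cost_loop(xs, ys, theta0, theta1):
    """Hypothetical helper: J(theta) computed term by term with a plain loop."""
    m = len(xs)                              # number of training examples
    total = 0.0
    for x, y in zip(xs, ys):
        error = (theta0 + theta1 * x) - y    # h_theta(x_i) - y_i
        total += error ** 2                  # accumulate the squared error
    return total / (2 * m)

# Toy data (not from ex1data1.txt): x = 1, 2, 3 and y = 1, 2, 3
xs = [1.0, 2.0, 3.0]
ys = [1.0, 2.0, 3.0]
print(cost_loop(xs, ys, 0.0, 0.0))  # (1 + 4 + 9) / (2 * 3) = 14 / 6, about 2.3333
print(cost_loop(xs, ys, 0.0, 1.0))  # the line y = x fits exactly, so 0.0
```

The vectorized NumPy version used in the steps below computes exactly the same sum, just without the explicit loop.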
Step 1: import the packages.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Step 2: read in the data and plot it to take a look:
path = 'ex1data1.txt'
data = pd.read_csv(path, header=None, names=['Population', 'Profit'])
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12, 8))
plt.show()
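Figure: scatter plot of Profit against Population.
Step 3: define the cost function.
def computeCost(X, y, theta):
    # X: (m, 2) matrix, y: (m, 1) matrix, theta: (1, 2) matrix (all np.matrix)
    # With np.matrix, * is matrix multiplication, so X * theta.T is the (m, 1) vector of predictions
    inner = np.power(((X * theta.T) - y), 2)   # squared errors
    return np.sum(inner) / (2 * len(X))        # len(X) is m, the number of rows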
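Because `header=None` and `names=...` do the real work in Step 2, this small self-contained sketch checks what `read_csv` produces. It inlines the first three rows of ex1data1.txt through `io.StringIO`, an assumption made only so the snippet runs without the file.

```python
import io
import pandas as pd

# First three rows of ex1data1.txt, inlined so the snippet runs without the file
sample = io.StringIO("6.1101,17.592\n5.5277,9.1302\n8.5186,13.662\n")
# header=None: the file has no header row; names= supplies the column names
data = pd.read_csv(sample, header=None, names=['Population', 'Profit'])
print(data.shape)               # (3, 2)
print(data.columns.tolist())    # ['Population', 'Profit']
```

On the full file the same call should give 97 rows and 2 columns.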
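The NumPy documentation no longer recommends `np.matrix`, even for linear algebra. As a sketch under that assumption, the same cost can be written for ordinary 2-D arrays, using `@` for matrix multiplication. `compute_cost_array` is a hypothetical name, and the tiny X and y are toy values, not the exercise data.

```python
import numpy as np

def compute_cost_array(X, y, theta):
    """Hypothetical ndarray version of computeCost: X is (m, n), y is (m, 1), theta is (1, n)."""
    residuals = X @ theta.T - y                  # (m, 1) prediction errors
    return np.sum(residuals ** 2) / (2 * X.shape[0])

# Tiny hand-checkable example: x = 1, 2, 3 and y = 1, 2, 3
X = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])  # first column is the bias column of ones
y = np.array([[1.0], [2.0], [3.0]])
print(compute_cost_array(X, y, np.array([[0.0, 0.0]])))  # (1 + 4 + 9) / 6, about 2.3333
print(compute_cost_array(X, y, np.array([[0.0, 1.0]])))  # perfect fit, 0.0
```

With `@`, the code no longer depends on `*` meaning matrix multiplication, which is only true for `np.matrix`.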
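Step 4: split X and Y out of data, and add a column of ones to the left of X (it multiplies θ0).
After splitting, X is 97×2, Y is 97×1, and θ is 1×2.
computeCost multiplies matrix X by the transpose of θ (97×2 times 2×1 gives a 97×1 vector of predictions) and compares the result with the actual Y values to compute the cost.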
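The shape bookkeeping in Step 4 can be checked in isolation. This sketch builds an X with the same layout (a bias column of ones plus one feature) from placeholder population values generated with `np.linspace`; these numbers are NOT the real data, only the shapes matter here.

```python
import numpy as np

m = 97                                           # number of rows in ex1data1.txt
population = np.linspace(5.0, 22.0, m)           # placeholder values, NOT the real data
X = np.column_stack([np.ones(m), population])    # (97, 2): bias column + feature
theta = np.zeros((1, 2))                         # (1, 2): theta0 = theta1 = 0
predictions = X @ theta.T                        # (97, 2) @ (2, 1) -> (97, 1)
print(X.shape, theta.shape, predictions.shape)   # (97, 2) (1, 2) (97, 1)
```

Since θ is all zeros here, every entry of `predictions` is 0, matching the setting of this exercise.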
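data.insert(0, 'Ones', 1)              # add a column of ones on the left (for theta0)
rows = data.shape[0]                   # number of examples (97)
cols = data.shape[1]                   # 3 columns: Ones, Population, Profit
X = data.iloc[:, 0:cols - 1]           # all columns except the last: Ones, Population
Y = data.iloc[:, cols - 1:cols]        # the last column: Profit
theta = np.matrix('0,0')               # 1x2 matrix, theta0 = theta1 = 0 (np.mat was removed in NumPy 2.0)
X = np.matrix(X.values)
Y = np.matrix(Y.values)
cost = computeCost(X, Y, theta)
print(cost)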
Reference answer:
32.072733877455676
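One way to sanity-check this number: with θ0 = θ1 = 0 every prediction is 0, so the cost reduces to J(0, 0) = Σ (y^(i))² / (2m), the sum of squared profits divided by 2m. A minimal sketch of that shortcut follows; `cost_at_zero` is a hypothetical helper name, and the three profits are toy values.

```python
import numpy as np

def cost_at_zero(y):
    """Hypothetical helper: J(theta) for theta0 = theta1 = 0, i.e. sum(y^2) / (2m)."""
    y = np.asarray(y, dtype=float).ravel()   # accept a list, a 1-D array, or an (m, 1) column
    return np.sum(y ** 2) / (2 * y.size)

# Toy profits (not from ex1data1.txt): (1 + 4 + 9) / (2 * 3), about 2.3333
print(cost_at_zero([1.0, 2.0, 3.0]))
```

Applied to the Profit column of ex1data1.txt (for example `cost_at_zero(data['Profit'].values)`), it should agree with the value printed by computeCost, up to floating-point rounding.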
Appendix: the dataset ex1data1.txt

6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705