波士頓房價預測
導入模塊
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
from sklearn.linear_model import LinearRegression
%matplotlib inline
font = FontProperties(fname='/Library/Fonts/Heiti.ttc')
獲取數據
housing-data.txt
文件可以加我微信獲取:a1171958281
打印數據
df = pd.read_csv('housing-data.txt', sep='\s+', header=0)
df.head()
CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | MEDV | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.00632 | 18.0 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296.0 | 15.3 | 396.90 | 4.98 | 24.0 |
1 | 0.02731 | 0.0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242.0 | 17.8 | 396.90 | 9.14 | 21.6 |
2 | 0.02729 | 0.0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242.0 | 17.8 | 392.83 | 4.03 | 34.7 |
3 | 0.03237 | 0.0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222.0 | 18.7 | 394.63 | 2.94 | 33.4 |
4 | 0.06905 | 0.0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222.0 | 18.7 | 396.90 | 5.33 | 36.2 |
特征選擇
散點圖矩陣
使用sns庫的pairplot()方法繪制的散點圖矩陣可以查看數據集內部特征之間的關系,例如可以觀察到特征間分布關系以及離群樣本。
本文只繪制了三列(RM、MEDV(標記)、LSTAT)特征和標記之間的聯系,有興趣的可以調用該方法查看其它特征之間的關系。
# 選擇三列特征
cols = ['RM', 'MEDV', 'LSTAT']
# 構造三列特征之間的聯系即構造散點圖矩陣
sns.pairplot(df[cols], height=3)
plt.tight_layout()
plt.show()
上圖可以看出第一行(RM)第二列(MEDV)的特征與標記存在線性關系;第二行(MEDV)第二列(MEDV)即MEDV值可能呈正態分布。
關聯矩陣
使用sns.heatmap()方法繪制的關聯矩陣可以看出特征之間的相關性大小,關聯矩陣是包含皮爾森積矩相關系數的正方形矩陣,用來度量特征對之間的線性依賴關系。
# 求解上述三列特征的相關系數
'''
對於一般的矩陣X,執行A=corrcoef(X)后,A中每個值的所在行a和列b,反應的是原矩陣X中相應的第a個列向量和第b個列向量的
相似程度(即相關系數)
'''
cm = np.corrcoef(df[cols].values.T)
# 控制顏色刻度即顏色深淺
sns.set(font_scale=2)
# 構造關聯矩陣
hm = sns.heatmap(cm, cbar=True, annot=True, square=True, fmt='.2f', annot_kws={
'size': 20}, yticklabels=cols, xticklabels=cols)
plt.show()
上圖可以看出特征LSTAT和標記MEDV的具有最高的相關性-0.74,但是在散點圖矩陣中會發現LSTAT和MEDV之間存在着明顯的非線性關系;而特征RM和標記MEDV也具有較高的相關性0.70,並且從散點矩陣中會發現特征RM和標記MEDV之間存在着線性關系。因此接下來將使用RM作為線性回歸模型的特征。
訓練模型
X = df[['RM']].values
y = df['MEDV'].values
lr = LinearRegression()
lr.fit(X, y)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
可視化
plt.scatter(X, y, c='r', s=30, edgecolor='white',label='訓練數據')
plt.plot(X, lr.predict(X), c='g')
plt.xlabel('平均房間數目[MEDV]', fontproperties=font)
plt.ylabel('以1000美元為計價單位的房價[RM]', fontproperties=font)
plt.title('波士頓房價預測', fontproperties=font, fontsize=20)
plt.legend(prop=font)
plt.show()
print('普通線性回歸斜率:{}'.format(lr.coef_[0]))
普通線性回歸斜率:9.10210898118031
使用RANSAC算法之后可以發現線性回歸擬合的線與未用RANSAC算法擬合出來的線的斜率不同,可以說RANSAC算法降低了離群值潛在的影響,但是這並不能說明這種方法對未來新數據的預測性能是否有良性影響。