關於Keras公用數據集的獲取和使用


Keras是Tensorflow2.0的核心高階API,其內置了一些常用的公共數據集,給開發者使用。

 

 以波士頓房價數據集為例,其涵蓋了麻省波士頓的506個不同郊區的房屋數據。有404條訓練數據集和102條測試數據集。

每條數據有14個字段,包含13個屬性和一個房價數據

 

獲取波士頓房價數據集:

1 import tensorflow as tf
2 boston_housing = tf.keras.datasets.boston_housing  #在線加載數據集
3 
4 (train_x,train_y),(test_x,test_y) = boston_housing.load_data() #獲取訓練集和測試機

程序會首先Keras官網下載數據集,然后保存在默認的路徑下面(C:\Users\Administrator.SG-20151030VCPR\.keras\datasets),這個路徑最好不要改,反正數據也不大。

數據拿到了,就順便看看各個屬性和房價之前的關系吧,這里對每個屬性和房價的關系進行可視化:

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 boston_housing = tf.keras.datasets.boston_housing  #在線加載數據集
 4 
 5 (train_x,train_y),(_,_) = boston_housing.load_data(test_split=0) #獲取訓練集
 6 
 7 title = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS',
 8         'RAD', 'TAX', 'PTRATIO', 'B-1000', 'LSTAT']
 9 plt.figure(figsize = (12,12))                #設置畫布大小為12*12英寸
10 
11 for i in range(len(title)):
12     plt.subplot(4,4,i+1)                    #繪制 4*4 子圖
13     plt.scatter(train_x[:,i], train_y)      #繪制散點圖
14     
15     plt.xlabel(title[i])                         #X軸標簽
16     plt.ylabel("Price($1000)'s")                 #Y軸標簽
17     plt.title(str(i+1)+'.'+title[i]+' - Price')  #設置子圖標題
18     
19 plt.tight_layout()#使標題坐標軸不重疊
20 plt.suptitle('各個屬性與房價的關系', x=0.5, y=1.02, fontsize=20)  #全局標題
21 plt.show()

來看看結果:

 

 

然后就可以使用這些數據來進行后續的數據清洗、模型訓練和結果評價了。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM