python——sklearn完整例子整理示范(有監督,邏輯回歸范例)(原創)


sklearn使用方法,包括從制作數據集,拆分數據集,調用模型,保存加載模型,分析結果,可視化結果

 1 import pandas as pd
 2 import numpy as np
 3 from sklearn.model_selection import train_test_split #訓練測試集拆分
 4 from sklearn.linear_model import LogisticRegression  #邏輯回歸模型
 5 import matplotlib.pyplot as plt #畫圖函數
 6 
 7 from sklearn.externals import joblib #保存加載模型函數joblib
 8 
 9 #以下為sklearn評測指標的一些函數
10 from sklearn.metrics import precision_score
11 from sklearn.metrics import classification_report
12 from sklearn.metrics import confusion_matrix
13 
14 #1. 若有文件,建議用read_csv加載,用sep代表按照該符號分割,若文件無列標簽名,則header設置為None,自定義標簽名names
15 
16 #file = "XXX_file"
17 #df = pd.read_csv(file, sep='###',header = None, names = ['flag','uuid','features'],engine = 'python')
18 #df.head()
19 
20 
21 #2. 准備好特征集合x 和 標簽集合y
22 
23 #x = df['features']  #x存儲特征
24 #y = df['flag']      #y存儲標簽
25 x = np.random.rand(100,3)
26 print("x:\n",x)
27 print(x.shape)
28 y = np.array([1 if i.sum()>1.2 else 0 for i in x])  #若三個維度之和大於1.2,則y分類為1,否則為0
29 print("y:\n",y)
30 print(y.shape)   #注意y的形式必須是(n,),即numpy中的一維格式
31 #當同時有 if 和 else 時,列表生成式構造為 [最終表達式 - 條件分支判斷 - 范圍選擇]
32 
33 
34 #3. 拆分訓練集和測試集(7:3)
35 x_train, x_test, y_train, y_test = train_test_split(x,y, random_state=666, train_size = 0.7)
36 
37 
38 #4. 生成模型,並喂入數據
39 clf = LogisticRegression()
40 clf.fit(x_train, y_train)
41 
42 
43 #5. 保存模型(用joblib,不用pickle)
44 joblib.dump(clf,"lr.model")    #from sklearn.externals import joblib
45 #加載模型是: clf = joblib.load("lr.model")
46 
47 
48 #6. 預測結果,並評測
49 y_pred = clf.predict(x_test)  #預測出來的值計做y_pred
50 y_true = y_test               #真實值計做y_true,和sklearn參數一模一樣
51 
52 target_names = ['class 0', 'class 1']
53 print(classification_report(y_true, y_pred, target_names=target_names)) #可以參考sklearn官網API
54 print(confusion_matrix(y_true, y_pred)) #混淆矩陣(記住!sklearn定義的混淆矩陣m行n列含義是:該樣本真實值是m,預測值是n)
55 print("precision_score:", precision_score(y_test,y_pred)) #打印精確率(記住!默認是positive,即標注為1的精確率)
56 
57 
58 #7. 附加:結果可視化,利用plt(用seaborn也可以)
59 """
60 #神秘代碼,主要是保證plt字體顯示正確
61 plt.rcParams['font.sans-serif'] = ['SimHei']   
62 plt.rcParams['font.family']='sans-serif' 
63 plt.rcParams['axes.unicode_minus'] = False
64 """
65 plt.plot(y_pred,"b.", label = "y_pred")   #blue,點號
66 plt.plot(y_true,"r*", label = "y_true")   #red,星號
67 plt.legend()
68 plt.show()  #畫的比較簡略,可以進一步美化

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM