Python鳶尾花分類實現


#coding:utf-8

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import matplotlib.pyplot as plt

iris_dataset = load_iris() # 獲取數據
# print("keys of iris_dataset:\n{}".format(iris_dataset.keys()))
# print(iris_dataset["DESCR"][:193]+"\n...")
# print("target names:{}".format(iris_dataset["target_names"]))
# print("feature names:{}".format(iris_dataset["feature_names"]))
# print(iris_dataset["data"][:5])
# print(iris_dataset["data"], iris_dataset["target"])
# 對數據進行拆分,分為訓練數據和測試數據
x_train, x_test, y_train, y_test = train_test_split(iris_dataset["data"], iris_dataset["target"], random_state=0)
# print(x_train, x_test, y_train, y_test)

knn = KNeighborsClassifier(n_neighbors=1) # 獲取KNN對象
knn.fit(x_train, y_train) # 訓練模型

# 評估模型
y_pre = knn.predict(x_test)
score = knn.score(x_test, y_test) # 調用打分函數
print("test set predictions:\n{}".format(y_test))
print("test set score:{:.2f}".format(score))
if score > 0.9:
x_new = np.array([[5, 2.9, 1, 0.3]])
print("x_new.shape:{}".format(x_new.shape))
prediction = knn.predict(x_new) # 預測
print("prediction:{}".format(prediction))
print("predicted target name:{}".format(iris_dataset["target_names"][prediction]))

# 可視化展示
plt.title("KNN Classification")
plt.plot(x_train, y_train, "b.") # 訓練數據打點
plt.plot(x_test, y_test, "y.") # 測試數據打點
plt.plot(x_new, prediction, "ro") # 預測數據打點
plt.show()
else:
print("used train or test data is not available !")

 

 




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM