TensorFlow High-Level Machine Learning API (tf.contrib.learn)

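1. tf.contrib.learn.datasets.base.load_csv_with_header: load data in CSV format

2. tf.contrib.learn.DNNClassifier: build a DNN (deep neural network) classifier model

3. classifier.fit: train the model

4. classifier.evaluate: evaluate the model

5. classifier.predict: predict the classes of new samples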
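To make step 1 more concrete, here is a minimal plain-Python sketch of the kind of parsing load_csv_with_header performs on the iris files. It assumes the header layout used by the TensorFlow iris CSVs (first field: number of samples, second field: number of features, then the class names) and an integer class label in the last column. Dataset and load_csv_with_header_sketch are hypothetical names for illustration only, not TensorFlow APIs, and the CSV rows below are illustrative.

```python
import csv
import io
from collections import namedtuple

# Hypothetical container mirroring the .data / .target fields used in the tutorial code.
Dataset = namedtuple("Dataset", ["data", "target"])


def load_csv_with_header_sketch(fileobj, target_column=-1):
    """Parse an iris-style CSV whose header is: n_samples, n_features, class names..."""
    reader = csv.reader(fileobj)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data, target = [], []
    for row in reader:
        label = row.pop(target_column)        # remove the label column from the row
        target.append(int(label))             # integer class index
        data.append([float(v) for v in row])  # remaining columns are the features
    # Sanity-check the parsed rows against the counts declared in the header.
    if len(data) != n_samples or any(len(f) != n_features for f in data):
        raise ValueError("CSV body does not match the header counts")
    return Dataset(data=data, target=target)


sample_csv = io.StringIO(
    "3,4,setosa,versicolor,virginica\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
    "4.9,2.5,4.5,1.7,2\n"
)
ds = load_csv_with_header_sketch(sample_csv)
print(ds.target)   # [2, 1, 2]
print(ds.data[0])  # [6.4, 2.8, 5.6, 2.2]
```

The real function returns NumPy arrays cast to the requested target_dtype and features_dtype; this sketch keeps plain lists so it runs with the standard library alone.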

Complete code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# tf.contrib (and therefore tf.contrib.learn) exists only in TensorFlow 1.x.
import tensorflow as tf
import numpy as np

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets. The last CSV column holds the integer class label.
# np.int is a deprecated alias (removed in NumPy 1.24), so np.int32 is used instead.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int32,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int32,
    features_dtype=np.float32)

# Specify that all features have real-valued data (4 features per sample).
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build a 3-layer DNN with 10, 20, and 10 hidden units respectively.
# Checkpoints are written to model_dir; rerunning the script resumes from them.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

# Fit model.
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

# Evaluate accuracy on the test set.
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

# Classify two new flower samples (float32 to match the training features).
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))
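The arguments hidden_units=[10, 20, 10] and n_classes=3 describe a fully connected network that maps the 4 input features through three hidden layers to 3 output logits, and the predicted class is the index of the largest logit. The sketch below counts the weights and biases of such a stack and picks a class from a set of logits. It assumes plain dense layers with one bias per unit and ignores anything else TensorFlow may store (optimizer slots, global step); dense_param_count, softmax, and the logit values are hypothetical, for illustration only.

```python
import math


def dense_param_count(layer_sizes):
    """Weights plus biases for a chain of fully connected layers."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + one bias per output unit
    return total


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


# 4 input features -> hidden layers [10, 20, 10] -> 3 class logits
sizes = [4, 10, 20, 10, 3]
print(dense_param_count(sizes))  # 513

# Hypothetical logits for one sample; the predicted class is the argmax.
logits = [0.2, 2.1, 1.3]
probs = softmax(logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class)  # 1
```

Because softmax is monotonic, taking the argmax of the probabilities gives the same class as taking the argmax of the raw logits.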

Result:

Accuracy: 0.966667
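The accuracy reported by classifier.evaluate is the fraction of test samples whose predicted class equals the true label. Assuming the 30-sample iris_test.csv from the TensorFlow tutorial, 0.966667 corresponds to 29 correct predictions out of 30. The sketch below reproduces that arithmetic; accuracy() and the label lists are hypothetical, not the actual test data.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    correct = sum(1 for p, t in zip(predictions, labels) if p == t)
    return correct / len(labels)


# Hypothetical labels for a 30-sample test set: 10 samples of each class.
labels = [0] * 10 + [1] * 10 + [2] * 10
# Identical predictions except for one class-1 sample misclassified as class 2.
predictions = list(labels)
predictions[15] = 2

score = accuracy(predictions, labels)
print('Accuracy: {0:f}'.format(score))  # Accuracy: 0.966667
```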
