TensorFlow High-Level Machine Learning API (tf.contrib.learn)



1. tf.contrib.learn.datasets.base.load_csv_with_header: load data in CSV format

2. tf.contrib.learn.DNNClassifier: build the DNN model (classifier)

3. classifier.fit: train the model

4. classifier.evaluate: evaluate the model

5. classifier.predict: predict classes for new samples
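
As a rough illustration of step 1, the sketch below mimics what a header-aware CSV loader does. It assumes that the first row of the file is a header whose first two fields are the number of samples and the number of features, and that the last column of every data row is the integer class label. The name `load_csv_with_header_sketch`, the `Dataset` tuple, and the two sample rows are made up for this example. It uses plain lists instead of NumPy arrays and is not the library's implementation.

```python
import csv
import io
from collections import namedtuple

# Hypothetical container exposing .data and .target, like the loaded datasets.
Dataset = namedtuple("Dataset", ["data", "target"])


def load_csv_with_header_sketch(csv_file, target_dtype=int, features_dtype=float):
    """Parse a CSV whose header row starts with n_samples,n_features
    and whose last column holds the label (assumed layout)."""
    reader = csv.reader(csv_file)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])

    data, target = [], []
    for row in reader:
        if not row:
            continue  # skip blank lines
        target.append(target_dtype(row[-1]))
        data.append([features_dtype(value) for value in row[:-1]])

    # Check the parsed rows against the counts declared in the header.
    if len(data) != n_samples or any(len(r) != n_features for r in data):
        raise ValueError("CSV contents do not match the header counts")
    return Dataset(data=data, target=target)


# Two illustrative rows in the assumed layout.
sample = io.StringIO(
    "2,4,setosa,versicolor,virginica\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
)
dataset = load_csv_with_header_sketch(sample)
print(dataset.data)
print(dataset.target)
```

The features and labels end up in separate fields, which is why the training and evaluation calls can pass `.data` as `x` and `.target` as `y`.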

Complete code:

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import tensorflow as tf
    import numpy as np

    # Data sets
    IRIS_TRAINING = "iris_training.csv"
    IRIS_TEST = "iris_test.csv"

    # Load datasets.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TRAINING,
        target_dtype=int,
        features_dtype=np.float32)
    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TEST,
        target_dtype=int,
        features_dtype=np.float32)

    # Specify that all features have real-value data
    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

    # Build 3 layer DNN with 10, 20, 10 units respectively.
    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                hidden_units=[10, 20, 10],
                                                n_classes=3,
                                                model_dir="/tmp/iris_model")

    # Fit model.
    classifier.fit(x=training_set.data,
                   y=training_set.target,
                   steps=2000)

    # Evaluate accuracy.
    accuracy_score = classifier.evaluate(x=test_set.data,
                                         y=test_set.target)["accuracy"]
    print('Accuracy: {0:f}'.format(accuracy_score))

    # Classify two new flower samples.
    new_samples = np.array(
        [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
    y = list(classifier.predict(new_samples, as_iterable=True))
    print('Predictions: {}'.format(str(y)))
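
To make the shape of the network concrete, here is a minimal pure-Python sketch of a forward pass through layers of size 4, 10, 20, 10 and 3, matching `dimension=4`, `hidden_units=[10, 20, 10]` and `n_classes=3`. It assumes ReLU activations in the hidden layers and a softmax over the output logits. The weights are random and untrained, and the helper names (`dense`, `relu`, `softmax`, `init_layer`, `forward`) are hypothetical. It only illustrates the layer structure, not how tf.contrib.learn builds or trains the model.

```python
import math
import random


def dense(inputs, weights, biases):
    """Fully connected layer; weights[j] holds the input weights of unit j."""
    return [
        sum(x * w for x, w in zip(inputs, unit_weights)) + b
        for unit_weights, b in zip(weights, biases)
    ]


def relu(values):
    return [max(0.0, v) for v in values]


def softmax(logits):
    # Subtract the maximum logit for numerical stability.
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def init_layer(n_in, n_out, rng):
    """Random (untrained) weights and zero biases, for illustration only."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    biases = [0.0] * n_out
    return weights, biases


rng = random.Random(0)
layer_sizes = [4, 10, 20, 10, 3]  # inputs, three hidden layers, classes
layers = [
    init_layer(n_in, n_out, rng)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]


def forward(features):
    activations = features
    for weights, biases in layers[:-1]:
        activations = relu(dense(activations, weights, biases))
    out_weights, out_biases = layers[-1]
    return softmax(dense(activations, out_weights, out_biases))


probabilities = forward([6.4, 3.2, 4.5, 1.5])
n_parameters = sum(len(w) * len(w[0]) + len(b) for w, b in layers)
print(probabilities)
print(n_parameters)
```

Under these assumptions the network has 513 trainable parameters (50 + 220 + 210 + 33), and the predicted class is the index of the largest probability.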

Result:

Accuracy: 0.966667
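
The accuracy is the fraction of test samples whose predicted class equals the true label. As a small worked example, assuming a test set of 30 samples (the post does not state the test set size), exactly one misclassified sample gives 29/30, which prints as 0.966667. The labels and predictions below are invented for illustration, and `accuracy` is a hypothetical helper, not part of tf.contrib.learn.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that equal the corresponding true label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    correct = sum(1 for p, t in zip(predictions, labels) if p == t)
    return correct / len(labels)


# Hypothetical 30-sample test set with exactly one wrong prediction.
labels = [0, 1, 2] * 10
predictions = list(labels)
predictions[0] = 1  # introduce a single error

score = accuracy(predictions, labels)
print('Accuracy: {0:f}'.format(score))  # prints "Accuracy: 0.966667"
```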

