AI - TensorFlow - Classification vs. Regression


Classification and Regression

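Classification and regression differ in the type of the output variable.
Informally, a quantitative output is called regression (prediction of a continuous variable), and a qualitative output is called classification (prediction of a discrete variable).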

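Regression predicts a continuous value, such as a house price or a future temperature.
A common regression algorithm is linear regression (LR, Linear Regression).
When a neural network is used for regression, its output layer does not apply softmax; it simply outputs the weighted sum of the previous layer's outputs (plus a bias).
Regression approximates the true value: a prediction can be closer to or farther from the target.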
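As a minimal sketch of linear regression in plain Python (not TensorFlow; the function name fit_linear and the learning rate and epoch count are made up for illustration), the model below outputs the raw weighted sum w*x + b with no softmax and is trained by gradient descent on the mean squared error:

```python
def fit_linear(xs, ys, lr=0.05, epochs=2000):
    """Fit y = w*x + b by full-batch gradient descent on mean squared error.

    Hypothetical helper for illustration only.
    """
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        # Regression output: the raw weighted sum, no softmax or other activation.
        preds = [w * x + b for x in xs]
        # Gradients of mean((pred - y)^2) with respect to w and b.
        grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


# Data generated from y = 2x + 1; the fit should approach w = 2, b = 1.
w, b = fit_linear([0, 1, 2, 3, 4], [1, 3, 5, 7, 9])
print(w, b)
```

Because the output is an unbounded real number, the loss measures how far each prediction is from the true value, which is what "approximating the true value" means here.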

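Classification predicts a discrete value: it assigns each item a label (a class).
Classification is often built on top of regression: the model first computes continuous scores, and in multi-class problems the last layer usually applies softmax to turn those scores into class probabilities, with the most probable class taken as the prediction.
Classification has no notion of approximation: there is exactly one correct class, a wrong answer is simply wrong, and there is no "almost correct".
The most common classification method is logistic regression (Logistic Regression), also called logistic classification.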
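A minimal plain-Python sketch of the classification side (the helper names softmax, predict_class, sigmoid and logistic_predict are hypothetical): softmax turns a vector of scores into probabilities, argmax picks the single predicted class, and binary logistic regression thresholds a sigmoid at 0.5.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(scores):
    """Return (index of the most probable class, probability list)."""
    probs = softmax(scores)
    return max(range(len(probs)), key=lambda i: probs[i]), probs


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logistic_predict(weights, bias, features):
    """Binary logistic regression: a linear score squashed by sigmoid, then thresholded."""
    z = sum(w * x for w, x in zip(weights, features)) + bias
    return 1 if sigmoid(z) >= 0.5 else 0


cls, probs = predict_class([2.0, 1.0, 0.1])
print(cls, probs)
```

The output is a class index, not a number to be approximated: predicting class 1 when the answer is class 0 is just as wrong as predicting class 2.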

 

MNIST Dataset

MNIST (Modified National Institute of Standards and Technology database) is a computer-vision dataset of handwritten digits:

  • Official download page: http://yann.lecun.com/exdb/mnist/
  • It contains 70,000 grayscale images of handwritten digits: 60,000 training images and 10,000 test images;
  • Each image is 28×28 pixels.
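In the sample program below, each 28×28 image is fed to the network as a 784-element row (matching the [None, 784] placeholder) and each label as a 10-element one-hot vector (matching one_hot=True). A plain-Python sketch of that conversion, with hypothetical helper names; dividing by 255 to scale pixels to [0, 1] is an assumed preprocessing step here:

```python
def flatten(image):
    """Flatten a 28x28 nested list of 0-255 pixels into 784 values in [0, 1]."""
    return [pixel / 255.0 for row in image for pixel in row]


def one_hot(label, num_classes=10):
    """Encode a digit label as a one-hot vector of length num_classes."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


white = [[255] * 28 for _ in range(28)]  # hypothetical all-white image
print(len(flatten(white)))  # 784
print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```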

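If the MNIST dataset cannot be downloaded over the network, download the following files from the official site and put them in a new MNIST_data directory next to the script: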

  • MNIST_data\train-images-idx3-ubyte.gz
  • MNIST_data\train-labels-idx1-ubyte.gz
  • MNIST_data\t10k-images-idx3-ubyte.gz
  • MNIST_data\t10k-labels-idx1-ubyte.gz
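The downloaded files are gzip-compressed IDX files: a 4-byte magic number (two zero bytes, a type code where 0x08 means unsigned byte, and the number of dimensions), then each dimension as a big-endian 32-bit integer, then the raw data. A minimal standard-library reader for checking hand-downloaded files, assuming only the unsigned-byte type used by MNIST (parse_idx and load_idx_gz are hypothetical names):

```python
import gzip
import struct


def parse_idx(data):
    """Parse IDX bytes into (shape tuple, flat list of ints).

    Only the unsigned-byte type (0x08) is supported, which covers the MNIST files.
    """
    zero, dtype, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    count = 1
    for dim in shape:
        count *= dim
    values = list(data[offset:offset + count])
    if len(values) != count:
        raise ValueError("truncated IDX file")
    return shape, values


def load_idx_gz(path):
    """Decompress a .gz IDX file and parse it."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())


# Usage (the training images should have shape (60000, 28, 28)):
#     shape, pixels = load_idx_gz("MNIST_data/train-images-idx3-ubyte.gz")
```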

 

Sample program

# coding=utf-8
from __future__ import print_function
import os

# Hide TensorFlow's C++ INFO/WARNING logs; set this before importing tensorflow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # requires TensorFlow 1.x (placeholder/Session API)
from tensorflow.examples.tutorials.mnist import input_data  # MNIST helper (TensorFlow 1.x only)

old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)  # silence deprecation warnings while loading
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # prepare data (downloaded if not present locally)
tf.logging.set_verbosity(old_v)  # restore the previous log level


def add_layer(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b  # linear output, as used for regression
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs


xs = tf.placeholder(tf.float32, [None, 784])  # input: 784 (28*28) features per image
ys = tf.placeholder(tf.float32, [None, 10])  # output: 10 classes (one-hot digits 0-9)

prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)  # activation: softmax

# Loss (objective to minimize): cross-entropy; clip to avoid log(0) = -inf
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)),
                                              axis=1))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)  # optimizer: gradient descent, learning rate 0.5

# Build the accuracy ops once, instead of adding new ops to the graph on every call
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def compute_accuracy(v_xs, v_ys):
    return sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})


sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)  # 100 images per step keeps each step fast
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
    if i % 50 == 0:  # print test-set accuracy every 50 steps
        print(compute_accuracy(mnist.test.images, mnist.test.labels))
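For reference, the loss and accuracy the program computes can be written out in plain Python. This is a sketch with hypothetical function names, operating on lists of one-hot labels and softmax probabilities; eps guards against log(0):

```python
import math


def cross_entropy(labels, probs, eps=1e-10):
    """Mean over the batch of -sum(y * log(p)), with p clipped below by eps."""
    total = 0.0
    for y_row, p_row in zip(labels, probs):
        total += -sum(y * math.log(max(p, eps)) for y, p in zip(y_row, p_row))
    return total / len(labels)


def argmax(row):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(labels, probs):
    """Fraction of rows whose predicted class matches the one-hot label."""
    hits = sum(argmax(y) == argmax(p) for y, p in zip(labels, probs))
    return hits / len(labels)


labels = [[0, 1], [1, 0]]
probs = [[0.2, 0.8], [0.4, 0.6]]  # second row is misclassified
print(cross_entropy(labels, probs), accuracy(labels, probs))
```

Only the probability assigned to the correct class contributes to the loss, so training pushes that probability toward 1, while accuracy just checks whether the most probable class is the right one.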

 

Program output:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.146
0.6316
0.7347
0.7815
0.8095
0.8198
0.8306
0.837
0.8444
0.8547
0.8544
0.8578
0.8651
0.8649
0.8705
0.8704
0.8741
0.8719
0.8753
0.8756

 

Troubleshooting

Symptom

Running the program prints warnings such as "Please use tf.data to implement this functionality.":

WARNING:tensorflow:From D:/Anliven/Anliven-Code/PycharmProjects/TempTest/TempTest_2.py:13: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
......
......

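Solution

Temporarily lower TensorFlow's log verbosity to ERROR while loading the data: save the current level with old_v = tf.logging.get_verbosity(), call tf.logging.set_verbosity(tf.logging.ERROR) before input_data.read_data_sets(...), and restore it afterwards with tf.logging.set_verbosity(old_v). This only hides the warnings; the tutorial loader itself remains deprecated.

Reference: https://stackoverflow.com/questions/49901806/tensorflow-importing-mnist-warnings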
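The same idea can be packaged as a reusable context manager. This is a sketch under an assumption: that tf.logging in TensorFlow 1.x is a thin wrapper around the standard-library logger named "tensorflow", so raising that logger's level hides the Python-side deprecation warnings. quiet_logger is a hypothetical helper name.

```python
import logging
from contextlib import contextmanager


@contextmanager
def quiet_logger(name="tensorflow", level=logging.ERROR):
    """Temporarily raise a logger's level, restoring the old level on exit.

    Assumption: TensorFlow 1.x routes its Python warnings through the
    standard-library logger called "tensorflow".
    """
    logger = logging.getLogger(name)
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(old_level)  # restored even if loading raises


# Usage (TensorFlow 1.x):
#     with quiet_logger():
#         mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
```

The try/finally mirrors the save/restore of old_v, but also restores the level if read_data_sets raises an exception.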

 

