tensorflow 2.0 學習(四)MNIST 訓練與測試


這次的mnist學習加入了測試集,看看學習的准確率,代碼如下

# encoding: utf-8

import tensorflow as tf
import matplotlib.pyplot as plt

#加載下載好的mnist數據庫 60000張訓練 10000張測試 每一張維度(28,28)
path = r'G:\2019\python\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path)

#第一層輸入256, 第二次輸出128, 第三層輸出10
#第一,二,三層參數w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))    #正態分布的一種
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

#兩種數據預處理的方法
#(一)預處理訓練數據
x = tf.convert_to_tensor(x_train, dtype = tf.float32)/255.    #0:1  ;   -1:1(不適合訓練,准確度不高)
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
#將60000組訓練數據切分為600組,每組100個數據
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000)      #盡量與樣本空間一樣大
train_db = train_db.batch(100)          #128


#(二)自定義預處理測試函數
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.     #先將類型轉化為float32,再歸一到0-1
    x = tf.reshape(x, [-1, 28*28])              #不知道x數量,用-1代替,轉化為一維784個數據
    y = tf.cast(y, dtype=tf.int32)              #轉化為整型32
    y = tf.one_hot(y, depth=10)                 #訓練數據所需的one-hot編碼
    return x, y

#將10000組測試數據預處理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000)
test_db = test_db.batch(100)        #128
test_db = test_db.map(preprocess)

lr = 0.001      #學習率
losses = []     #儲存每epoch的loss值,便於觀察學習情況
acc = []        #准確率

for epoch in range(30):     #20
    #一次性處理100組(x, y)數據
    for step, (x, y) in enumerate(train_db):    #遍歷切分好的數據step:0->599
        with tf.GradientTape() as tape:
            #向前傳播第一,二,三層
            h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256])  #可以直接寫成 +b1
            h1 = tf.nn.relu(h1)
            h2 = h1@w2 + b2
            h2 = tf.nn.relu(h2)
            out = h2@w3 + b3

            #計算mse
            loss = tf.square(y - out)
            loss = tf.reduce_mean(loss)
        #計算參數的梯度,tape.gradient為自動求導函數,loss為目標數據,目的使它越來越接近真實值
        grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
        #更新w,b
        w1.assign_sub(lr*grads[0])  #原地減去給定的值,實現參數的自我更新
        b1.assign_sub(lr*grads[1])
        w2.assign_sub(lr*grads[2])
        b2.assign_sub(lr*grads[3])
        w3.assign_sub(lr*grads[4])
        b3.assign_sub(lr*grads[5])
        #觀察學習情況
        if step%100 == 0:
            print('訓練第 ',epoch,'',', 第',step,'步, ','loss:', float(loss))
            losses.append(float(loss))          #將每100step后的loss情況儲存起來,最后觀察

        if step%500 == 0:
            total, total_correct = 0., 0.
            for x, y in test_db:
                h1 = x @ w1 + b1
                h1 = tf.nn.relu(h1)
                h2 = h1 @ w2 + b2
                h2 = tf.nn.relu(h2)
                out = h2 @ w3 + b3

                pred = tf.argmax(out, axis=1)  # 選取概率最大的類別
                y = tf.argmax(y, axis=1)  # 類似於one-hot逆編碼
                correct = tf.equal(pred, y)  # 比較真實值和預測值是否相等
                total += x.shape[0]
                # 統計正確的個數
                total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
            print('訓練第 ',epoch,'',', 第',step,'步, ', 'Evaluate Acc:', total_correct/total)
            acc.append(total_correct/total)

#plt.subplot(121)
x1 = [i*100 for i in range(len(losses))]
plt.plot(x1, losses, marker='s', label='training')
plt.xlabel('Step')
plt.ylabel('MSE')
plt.legend()
#plt.savefig('exam_mnist_forward.png')
#plt.show()

#plt.subplot(122)
plt.figure()
x2 = [i for i in range(len(acc))]
plt.plot(x2, acc, 'r',marker='d', label='testing')
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.legend()
#plt.savefig('test_mnist_forward.png')
plt.show()

誤差何准確率如下

發現和書中類似,但要注意的如下:

(1)數據預處理時,打散值選擇和數據空間一樣大;

(2)數據處理選擇0-1之間,而不用(-1 :1),是因為后者學習效率不理想!

(3)代碼還可以進行優化處理!

總的來說,代碼還是容易理解,使用也更加簡潔!

下一次更新,全連接網絡,關於汽車油耗的預測。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM