tensorflow 2.0 學習(三)MNIST訓練


用tensorflow2.0 版回顧了一下mnist的學習

代碼如下,感覺這個版本下的mnist學習更簡潔,更方便

關於tensorflow的基礎知識,這里就不更新了,用到什么就到網上搜索相關的知識

# encoding: utf-8

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

#加載下載好的mnist數據庫 60000張訓練 10000張測試 每一張維度(28,28)
path = r'G:\2019\python\mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
f.close()

#預處理輸入數據
x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10)

#第一層輸入256, 第二次輸出128, 第三層輸出10
#第一,二,三層參數w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))    #正態分布的一種
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

#將60000組數據切分為600組,每組100個數據
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)
lr = 0.001      #學習率
losses = []     #儲存每epoch的loss值,便於觀察學習情況

for epoch in range(20):
    #一次性處理100組(x, y)數據
    for step, (x, y) in enumerate(train_db):    #遍歷切分好的數據step:0->599
        with tf.GradientTape() as tape:
            #向前傳播第一,二,三層
            h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256])  #可以直接寫成 +b1
            h1 = tf.nn.relu(h1)
            h2 = h1@w2 + b2
            h2 = tf.nn.relu(h2)
            out = h2@w3 + b3
            #計算mse
            loss = tf.square(y - out)
            loss = tf.reduce_mean(loss)
        #計算參數的梯度,tape.gradient為自動求導函數,loss為目標數據,目的使它越來越接近真實值
        grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
        #更新w,b
        w1.assign_sub(lr*grads[0])  #原地減去給定的值,實現參數的自我更新
        b1.assign_sub(lr*grads[1])
        w2.assign_sub(lr*grads[2])
        b2.assign_sub(lr*grads[3])
        w3.assign_sub(lr*grads[4])
        b3.assign_sub(lr*grads[5])
        #觀察學習情況
        if step%500 == 0:
            print(epoch, step, 'loss:', float(loss))
    #將每epoch的loss情況儲存起來,最后觀察
    losses.append(float(loss))

plt.plot(losses, marker='s', label='training')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.legend()
plt.savefig('exam_mnist_forward.png') plt.show()

觀察結果:

可由注釋理解代碼的含義!下一次更新mnist數據集訓練的進階!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM