Tensorflow函數式API的使用


在我們使用tensorflow時,如果不能使用函數式api進行編程,那么一些復雜的神經網絡結構就不會實現出來,只能使用簡單的單向模型進行一層一層地堆疊。如果稍微復雜一點,遇到了Resnet這種帶有殘差模塊的神經網絡,那么用簡單的神經網絡堆疊的方式則不可能把這種網絡堆疊出來。下面我們來使用函數式API來編寫一個簡單的全連接神經網絡:
首先導包:

from tensorflow import keras
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

導入圖片數據集:mnist

(train_image,train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()

歸一化:

train_image=train_image/255
test_image=test_image/255#進行數據的歸一化,加快計算的進程

搭建全連接神經網絡:

input=keras.Input(shape=(28,28))
x=keras.layers.Flatten()(input)#調用input
x=keras.layers.Dense(32,activation="relu")(x)
x=keras.layers.Dropout(0.5)(x)#一層一層的進行調用上一層的結果
output=keras.layers.Dense(10,activation="softmax")(x)
model=keras.Model(inputs=input,outputs=output)
model.summary()

輸出:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 32)                25120     
_________________________________________________________________
dropout (Dropout)            (None, 32)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                330       
=================================================================
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0

擬合模型:

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc']
)
history=model.fit(train_image,
                  train_label,
                  epochs=15,
                  validation_data=(test_image,test_label))

輸出:

Train on 60000 samples, validate on 10000 samples
Epoch 1/15
60000/60000 [==============================] - 4s 64us/sample - loss: 0.8931 - acc: 0.6737 - val_loss: 0.5185 - val_acc: 0.8160
Epoch 2/15
60000/60000 [==============================] - 3s 57us/sample - loss: 0.6757 - acc: 0.7508 - val_loss: 0.4805 - val_acc: 0.8230
Epoch 3/15
60000/60000 [==============================] - 3s 50us/sample - loss: 0.6336 - acc: 0.7647 - val_loss: 0.4587 - val_acc: 0.8369
Epoch 4/15
60000/60000 [==============================] - 3s 49us/sample - loss: 0.6174 - acc: 0.7689 - val_loss: 0.4712 - val_acc: 0.8294
Epoch 5/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.6080 - acc: 0.7732 - val_loss: 0.4511 - val_acc: 0.8404
Epoch 6/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.5932 - acc: 0.7773 - val_loss: 0.4545 - val_acc: 0.8407
Epoch 7/15
60000/60000 [==============================] - 3s 47us/sample - loss: 0.5886 - acc: 0.7772 - val_loss: 0.4394 - val_acc: 0.8428
Epoch 8/15
60000/60000 [==============================] - 3s 52us/sample - loss: 0.5820 - acc: 0.7788 - val_loss: 0.4338 - val_acc: 0.8506
Epoch 9/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.5742 - acc: 0.7839 - val_loss: 0.4393 - val_acc: 0.8454
Epoch 10/15
60000/60000 [==============================] - 3s 49us/sample - loss: 0.5713 - acc: 0.7847 - val_loss: 0.4422 - val_acc: 0.8477
Epoch 11/15
60000/60000 [==============================] - 3s 47us/sample - loss: 0.5642 - acc: 0.7858 - val_loss: 0.4325 - val_acc: 0.8488
Epoch 12/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.5582 - acc: 0.7873 - val_loss: 0.4294 - val_acc: 0.8492
Epoch 13/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.5574 - acc: 0.7882 - val_loss: 0.4263 - val_acc: 0.8523
Epoch 14/15
60000/60000 [==============================] - 3s 48us/sample - loss: 0.5524 - acc: 0.7888 - val_loss: 0.4350 - val_acc: 0.8448
Epoch 15/15
60000/60000 [==============================] - 3s 47us/sample - loss: 0.5486 - acc: 0.7901 - val_loss: 0.4297 - val_acc: 0.8493

最后驗證集的精度達到了84%,這是一個僅僅使用全連接神經網絡和softmax就能夠得到的一個很不錯的結果了!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM