Tensorflow2(預課程)---8.1、cifar100分類-層方式
一、總結
一句話總結:
全連接神經網絡做cifar100分類不行,簡單測試一下,准確率才20%,需要換別的神經網絡
二、cifar100分類-層方式
博客對應課程的視頻位置:
步驟
1、讀取數據集
2、拆分數據集(拆分成訓練數據集和測試數據集)
3、構建模型
4、訓練模型
5、檢驗模型
需求
cifar100(物品分類)
cifar100這個數據集就像CIFAR-10,除了它有100個類,每個類包含600個圖像。,每類各有500個訓練圖像和100個測試圖像。CIFAR-100中的100個類被分成20個超類。每個圖像都帶有一個“精細”標簽(它所屬的類)和一個“粗糙”標簽(它所屬的超類) 以下是CIFAR-100中的類別列表:
超類 | 類別 |
---|---|
水生哺乳動物 | 海狸,海豚,水獺,海豹,鯨魚 |
魚 | 水族館的魚,比目魚,射線,鯊魚,鱒魚 |
花卉 | 蘭花,罌粟花,玫瑰,向日葵,郁金香 |
食品容器 | 瓶子,碗,罐子,杯子,盤子 |
水果和蔬菜 | 蘋果,蘑菇,橘子,梨,甜椒 |
家用電器 | 時鍾,電腦鍵盤,台燈,電話機,電視機 |
家用家具 | 床,椅子,沙發,桌子,衣櫃 |
昆蟲 | 蜜蜂,甲蟲,蝴蝶,毛蟲,蟑螂 |
大型食肉動物 | 熊,豹,獅子,老虎,狼 |
大型人造戶外用品 | 橋,城堡,房子,路,摩天大樓 |
大自然的戶外場景 | 雲,森林,山,平原,海 |
大雜食動物和食草動物 | 駱駝,牛,黑猩猩,大象,袋鼠 |
中型哺乳動物 | 狐狸,豪豬,負鼠,浣熊,臭鼬 |
非昆蟲無脊椎動物 | 螃蟹,龍蝦,蝸牛,蜘蛛,蠕蟲 |
人 | 寶貝,男孩,女孩,男人,女人 |
爬行動物 | 鱷魚,恐龍,蜥蜴,蛇,烏龜 |
小型哺乳動物 | 倉鼠,老鼠,兔子,母老虎,松鼠 |
樹木 | 楓樹,橡樹,棕櫚,松樹,柳樹 |
車輛1 | 自行車,公共汽車,摩托車,皮卡車,火車 |
車輛2 | 割草機,火箭,有軌電車,坦克,拖拉機 |
Superclass | Classes |
---|---|
aquatic | mammals beaver, dolphin, otter, seal, whale |
fish | aquarium fish, flatfish, ray, shark, trout |
flowers | orchids, poppies, roses, sunflowers, tulips |
food | containers bottles, bowls, cans, cups, plates |
fruit | and vegetables apples, mushrooms, oranges, pears, sweet peppers |
household | electrical devices clock, computer keyboard, lamp, telephone, television |
household | furniture bed, chair, couch, table, wardrobe |
insects | bee, beetle, butterfly, caterpillar, cockroach |
large carnivores | bear, leopard, lion, tiger, wolf |
large man-made outdoor things | bridge, castle, house, road, skyscraper |
large natural outdoor scenes | cloud, forest, mountain, plain, sea |
large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
non-insect invertebrates | crab, lobster, snail, spider, worm |
people | baby, boy, girl, man, woman |
reptiles | crocodile, dinosaur, lizard, snake, turtle |
small mammals | hamster, mouse, rabbit, shrew, squirrel |
trees | maple, oak, palm, pine, willow |
vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |
In [1]:
import pandas as pd import numpy as np import tensorflow as tf import matplotlib.pyplot as plt
1、讀取數據集
直接從tensorflow的dataset來讀取數據集即可
In [2]:
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar100.load_data() print(train_x.shape, train_y.shape)
這是32*32的彩色圖,rgb三個通道如何處理呢
In [3]:
plt.imshow(train_x[0]) plt.show()
In [4]:
plt.figure() plt.imshow(train_x[1]) plt.figure() plt.imshow(train_x[2]) plt.show()
In [5]:
print(test_y)
In [6]:
# 像素值 RGB
np.max(train_x[0])
Out[6]:
2、拆分數據集(拆分成訓練數據集和測試數據集)
上一步做了拆分數據集的工作
In [7]:
# 圖片數據如何歸一化
# 直接除255即可 train_x = train_x/255 test_x = test_x/255
In [8]:
# 像素值 RGB
np.max(train_x[0])
Out[8]:
In [9]:
train_y=train_y.flatten() test_y=test_y.flatten() train_y = tf.one_hot(train_y, depth=100) test_y = tf.one_hot(test_y, depth=100) print(test_y.shape)
3、構建模型
應該構建一個怎么樣的模型:
輸入是32*32*3維,輸出是一個label,是一個10分類問題,
需要one_hot編碼么,如果是one_hot編碼,那么輸出是10維
也就是 32*32*3->n->10,可以試下3072->1024->512->256->128->10
In [10]:
# 構建容器
model = tf.keras.Sequential() # 輸入層 # 將多維數據(60000, 32, 32, 3)變成一維 # 把圖像扁平化成一個向量 model.add(tf.keras.layers.Flatten(input_shape=(32,32,3))) # 中間層 model.add(tf.keras.layers.Dense(1024,activation='relu')) model.add(tf.keras.layers.Dense(512,activation='relu')) model.add(tf.keras.layers.Dense(256,activation='relu')) model.add(tf.keras.layers.Dense(128,activation='relu')) # 輸出層 model.add(tf.keras.layers.Dense(100,activation='softmax')) # 模型的結構 model.summary()
太玄學了,增加層(比如在128和10之間增加32)並不能使准確率增加
4、訓練模型
In [11]:
# 配置優化函數和損失器
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['acc']) # 開始訓練 history = model.fit(train_x,train_y,epochs=50,validation_data=(test_x,test_y))
In [12]:
plt.plot(history.epoch,history.history.get('loss')) plt.title("train data loss") plt.show()
In [13]:
plt.plot(history.epoch,history.history.get('val_loss')) plt.title("test data loss") plt.show()
In [14]:
plt.plot(history.epoch,history.history.get('acc')) plt.title("train data acc") plt.show()
In [15]:
plt.plot(history.epoch,history.history.get('val_acc')) plt.title("test data acc") plt.show()
5、檢驗模型
In [16]:
# 看一下模型的預測能力
pridict_y=model.predict(test_x) print(pridict_y) print(test_y)
In [17]:
# 在pridict_y中找最大值的索引,橫向
pridict_y = tf.argmax(pridict_y, axis=1) print(pridict_y) # test_y = tf.argmax(test_y, axis=1) print(test_y)
In [18]:
plt.figure() plt.imshow(test_x[0]) plt.figure() plt.imshow(test_x[1]) plt.figure() plt.imshow(test_x[2]) plt.figure() plt.imshow(test_x[3]) plt.show()
In [ ]: