Tensorflow2(预课程)---8.1、cifar100分类-层方式
一、总结
一句话总结:
全连接神经网络做cifar100分类不行,简单测试一下,准确率才20%,需要换别的神经网络
二、cifar100分类-层方式
博客对应课程的视频位置:
步骤
1、读取数据集
2、拆分数据集(拆分成训练数据集和测试数据集)
3、构建模型
4、训练模型
5、检验模型
需求
cifar100(物品分类)
cifar100这个数据集就像CIFAR-10,除了它有100个类,每个类包含600个图像。,每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类) 以下是CIFAR-100中的类别列表:
超类 | 类别 |
---|---|
水生哺乳动物 | 海狸,海豚,水獭,海豹,鲸鱼 |
鱼 | 水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼 |
花卉 | 兰花,罂粟花,玫瑰,向日葵,郁金香 |
食品容器 | 瓶子,碗,罐子,杯子,盘子 |
水果和蔬菜 | 苹果,蘑菇,橘子,梨,甜椒 |
家用电器 | 时钟,电脑键盘,台灯,电话机,电视机 |
家用家具 | 床,椅子,沙发,桌子,衣柜 |
昆虫 | 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂 |
大型食肉动物 | 熊,豹,狮子,老虎,狼 |
大型人造户外用品 | 桥,城堡,房子,路,摩天大楼 |
大自然的户外场景 | 云,森林,山,平原,海 |
大杂食动物和食草动物 | 骆驼,牛,黑猩猩,大象,袋鼠 |
中型哺乳动物 | 狐狸,豪猪,负鼠,浣熊,臭鼬 |
非昆虫无脊椎动物 | 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫 |
人 | 宝贝,男孩,女孩,男人,女人 |
爬行动物 | 鳄鱼,恐龙,蜥蜴,蛇,乌龟 |
小型哺乳动物 | 仓鼠,老鼠,兔子,母老虎,松鼠 |
树木 | 枫树,橡树,棕榈,松树,柳树 |
车辆1 | 自行车,公共汽车,摩托车,皮卡车,火车 |
车辆2 | 割草机,火箭,有轨电车,坦克,拖拉机 |
Superclass | Classes |
---|---|
aquatic | mammals beaver, dolphin, otter, seal, whale |
fish | aquarium fish, flatfish, ray, shark, trout |
flowers | orchids, poppies, roses, sunflowers, tulips |
food | containers bottles, bowls, cans, cups, plates |
fruit | and vegetables apples, mushrooms, oranges, pears, sweet peppers |
household | electrical devices clock, computer keyboard, lamp, telephone, television |
household | furniture bed, chair, couch, table, wardrobe |
insects | bee, beetle, butterfly, caterpillar, cockroach |
large carnivores | bear, leopard, lion, tiger, wolf |
large man-made outdoor things | bridge, castle, house, road, skyscraper |
large natural outdoor scenes | cloud, forest, mountain, plain, sea |
large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
non-insect invertebrates | crab, lobster, snail, spider, worm |
people | baby, boy, girl, man, woman |
reptiles | crocodile, dinosaur, lizard, snake, turtle |
small mammals | hamster, mouse, rabbit, shrew, squirrel |
trees | maple, oak, palm, pine, willow |
vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |
In [1]:
import pandas as pd import numpy as np import tensorflow as tf import matplotlib.pyplot as plt
1、读取数据集
直接从tensorflow的dataset来读取数据集即可
In [2]:
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar100.load_data() print(train_x.shape, train_y.shape)
这是32*32的彩色图,rgb三个通道如何处理呢
In [3]:
plt.imshow(train_x[0]) plt.show()
In [4]:
plt.figure() plt.imshow(train_x[1]) plt.figure() plt.imshow(train_x[2]) plt.show()
In [5]:
print(test_y)
In [6]:
# 像素值 RGB
np.max(train_x[0])
Out[6]:
2、拆分数据集(拆分成训练数据集和测试数据集)
上一步做了拆分数据集的工作
In [7]:
# 图片数据如何归一化
# 直接除255即可 train_x = train_x/255 test_x = test_x/255
In [8]:
# 像素值 RGB
np.max(train_x[0])
Out[8]:
In [9]:
train_y=train_y.flatten() test_y=test_y.flatten() train_y = tf.one_hot(train_y, depth=100) test_y = tf.one_hot(test_y, depth=100) print(test_y.shape)
3、构建模型
应该构建一个怎么样的模型:
输入是32*32*3维,输出是一个label,是一个10分类问题,
需要one_hot编码么,如果是one_hot编码,那么输出是10维
也就是 32*32*3->n->10,可以试下3072->1024->512->256->128->10
In [10]:
# 构建容器
model = tf.keras.Sequential() # 输入层 # 将多维数据(60000, 32, 32, 3)变成一维 # 把图像扁平化成一个向量 model.add(tf.keras.layers.Flatten(input_shape=(32,32,3))) # 中间层 model.add(tf.keras.layers.Dense(1024,activation='relu')) model.add(tf.keras.layers.Dense(512,activation='relu')) model.add(tf.keras.layers.Dense(256,activation='relu')) model.add(tf.keras.layers.Dense(128,activation='relu')) # 输出层 model.add(tf.keras.layers.Dense(100,activation='softmax')) # 模型的结构 model.summary()
太玄学了,增加层(比如在128和10之间增加32)并不能使准确率增加
4、训练模型
In [11]:
# 配置优化函数和损失器
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['acc']) # 开始训练 history = model.fit(train_x,train_y,epochs=50,validation_data=(test_x,test_y))
In [12]:
plt.plot(history.epoch,history.history.get('loss')) plt.title("train data loss") plt.show()
In [13]:
plt.plot(history.epoch,history.history.get('val_loss')) plt.title("test data loss") plt.show()
In [14]:
plt.plot(history.epoch,history.history.get('acc')) plt.title("train data acc") plt.show()
In [15]:
plt.plot(history.epoch,history.history.get('val_acc')) plt.title("test data acc") plt.show()
5、检验模型
In [16]:
# 看一下模型的预测能力
pridict_y=model.predict(test_x) print(pridict_y) print(test_y)
In [17]:
# 在pridict_y中找最大值的索引,横向
pridict_y = tf.argmax(pridict_y, axis=1) print(pridict_y) # test_y = tf.argmax(test_y, axis=1) print(test_y)
In [18]:
plt.figure() plt.imshow(test_x[0]) plt.figure() plt.imshow(test_x[1]) plt.figure() plt.imshow(test_x[2]) plt.figure() plt.imshow(test_x[3]) plt.show()
In [ ]: