AlexNet網絡解決經典貓狗分類問題


1. 下載貓狗數據集,解壓到 ./train 目錄(文件名形如 cat.0.jpg、dog.0.jpg)

2. 劃分並預處理數據集:將第0–1999張貓/狗圖片縮放為 227×227 後存入 ./train1 作為訓練集;將第2000–2999張縮放後從0重新編號,存入 ./test1 作為測試集

代碼:

'''
準備數據集
'''

import os
import tensorflow as tf
import cv2
import glob
import numpy as np
from PIL import Image


# Hide TensorFlow's C++ INFO and WARNING log lines ('2' filters both).
# The variable TensorFlow reads is TF_CPP_MIN_LOG_LEVEL; TF_CPP_MIN_LEVEL has no effect.
# NOTE: it is read when TensorFlow's native library loads, so it only fully
# applies when set before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# This script targets TensorFlow 2.x; fail loudly (even under `python -O`) otherwise.
if not tf.__version__.startswith('2.'):
    raise RuntimeError('TensorFlow 2.x is required, found ' + tf.__version__)

def prepare_train(start=0, stop=2000, src_dir='train', dst_dir='train1', size=(227, 227)):
    """Resize the cat/dog training images and save them into ``dst_dir``.

    For every index ``i`` in ``range(start, stop)`` the files ``cat.<i>.jpg`` and
    ``dog.<i>.jpg`` are read from ``src_dir``, resized to ``size`` (width, height)
    and saved under the same name in ``dst_dir``. The defaults reproduce the
    original behaviour: images 0-1999 of ./train are written to ./train1 at 227x227.
    """
    # Saving into a missing directory would raise FileNotFoundError.
    os.makedirs(dst_dir, exist_ok=True)
    for i in range(start, stop):
        for kind in ('cat', 'dog'):
            name = '%s.%d.jpg' % (kind, i)
            # os.path.join instead of hard-coded '\' separators so this also works
            # outside Windows. Image.ANTIALIAS was removed in Pillow 10;
            # Image.LANCZOS is the same resampling filter.
            with Image.open(os.path.join(src_dir, name)) as img:
                img.resize(size, Image.LANCZOS).save(os.path.join(dst_dir, name))
        print('第{}張圖片保存完成!'.format(i))
# prepare_train()
def prepare_test(start=2000, stop=3000, src_dir='train', dst_dir='test1', size=(227, 227)):
    """Resize a held-out slice of the cat/dog images into a test directory.

    For every index ``i`` in ``range(start, stop)`` the files ``cat.<i>.jpg`` and
    ``dog.<i>.jpg`` are read from ``src_dir``, resized to ``size`` (width, height)
    and saved into ``dst_dir`` renumbered from 0 (as ``cat.<i - start>.jpg`` /
    ``dog.<i - start>.jpg``). The defaults reproduce the original behaviour:
    images 2000-2999 of ./train become test1/cat.0.jpg ... test1/dog.999.jpg.
    """
    # Saving into a missing directory would raise FileNotFoundError.
    os.makedirs(dst_dir, exist_ok=True)
    for i in range(start, stop):
        j = i - start  # index in the output directory
        for kind in ('cat', 'dog'):
            src = os.path.join(src_dir, '%s.%d.jpg' % (kind, i))
            dst = os.path.join(dst_dir, '%s.%d.jpg' % (kind, j))
            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter.
            with Image.open(src) as img:
                img.resize(size, Image.LANCZOS).save(dst)
        print('第{}張圖片保存完成!'.format(j))


prepare_test()

3. 將數據集序列化,得到 TFRecord 文件

代碼:

'''
將數據集序列化為 TFRecord(訓練集部分已註釋,當前只寫入測試集 test.tfrecords)

'''

import os
import tensorflow as tf

# Hide TensorFlow's C++ INFO and WARNING log lines ('2' filters both).
# The variable TensorFlow reads is TF_CPP_MIN_LOG_LEVEL; TF_CPP_MIN_LEVEL has no effect.
# NOTE: it is read when TensorFlow's native library loads, so it only fully
# applies when set before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# This script targets TensorFlow 2.x; fail loudly (even under `python -O`) otherwise.
if not tf.__version__.startswith('2.'):
    raise RuntimeError('TensorFlow 2.x is required, found ' + tf.__version__)

def _label_from_filename(filename):
    """Return the class label encoded in an image file name: 0 for 'cat.*', 1 for 'dog.*'.

    Raises:
        ValueError: if the name starts with neither 'cat.' nor 'dog.'.
    """
    if filename.startswith('cat.'):
        return 0  # cat is 0
    if filename.startswith('dog.'):
        return 1  # dog is 1
    raise ValueError('cannot infer label from file name: %r' % filename)


def _write_tfrecord(image_dir, tfrecord_path):
    """Serialize every cat/dog JPEG in ``image_dir`` into the TFRecord file ``tfrecord_path``.

    Each record is a tf.train.Example with an 'image' bytes feature (the raw
    file contents) and an int64 'label' feature (cat = 0, dog = 1).

    Labels come from each file's name, not from its list position: os.listdir()
    returns entries in arbitrary order, so a positional ``[0] * n + [1] * n``
    label list can silently mislabel images.

    Returns:
        (filenames, labels): the serialized file paths (sorted) and their labels.
    """
    # Skip anything that is not a cat/dog image (e.g. OS thumbnail files).
    names = sorted(n for n in os.listdir(image_dir) if n.startswith(('cat.', 'dog.')))
    filenames = [os.path.join(image_dir, n) for n in names]
    labels = [_label_from_filename(n) for n in names]

    # The `with` block closes the writer; no explicit close() is needed.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_file, label in zip(filenames, labels):
            with open(image_file, 'rb') as f:
                image = f.read()
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
    return filenames, labels


# train_data: uncomment to serialize the training split.
# train_filenames, train_labels = _write_tfrecord('./train1/', './train.tfrecords')

# -------------------------------------------------------------------------------------------------------------------
# test_data
test_dir = './test1/'
tfrecord_file = './test.tfrecords'
test_filenames, test_labels = _write_tfrecord(test_dir, tfrecord_file)

 

4.對數據集反序列化,得到Tensor格式的原始數據

代碼:

'''
Deserialize the TFRecord files back into (image, label) tensor pairs.
'''
tfrecord_file1 = './train.tfrecords'
tfrecord_file2 = './test.tfrecords'
raw_dataset1 = tf.data.TFRecordDataset(tfrecord_file1)
raw_dataset2 = tf.data.TFRecordDataset(tfrecord_file2)


# Schema of each serialized tf.train.Example: the raw image bytes plus an int64 label.
feature_description = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64)
}


def _parse_example(example_string):
    """Decode one serialized tf.train.Example from the TFRecord file into (image, label).

    The JPEG is decoded with ``channels=3``. Without it, decode_jpeg yields a
    tensor whose channel dimension is statically unknown (and is 1 for grayscale
    JPEGs), while the Keras Conv2D layers downstream require a defined channel
    dimension.
    """
    feature_dict = tf.io.parse_single_example(example_string, feature_description)
    image = tf.io.decode_jpeg(feature_dict['image'], channels=3)  # uint8 RGB tensor
    return image, feature_dict['label']


train_data = raw_dataset1.map(_parse_example)
test_data = raw_dataset2.map(_parse_example)

5.構建網絡,進行訓練測試調參

代碼:

# Input pipeline: shuffle, batch, then scale pixels and one-hot encode labels per batch.
def preprocess(x, y):
    """Scale a batch of images to float32 in [0, 1] and one-hot encode its labels.

    Labels are cast to int32 and expanded to depth 2 (cat = 0, dog = 1).
    """
    images = tf.cast(x, dtype=tf.float32) / 255.
    one_hot_labels = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=2)
    return images, one_hot_labels


batch_size = 128
epochs = 200  # NOTE(review): main() passes epochs=15 to fit() and does not read this

train_db = (train_data
            .shuffle(4000)
            .batch(batch_size)
            .map(preprocess))
test_db = (test_data
           .shuffle(2000)
           .batch(batch_size)
           .map(preprocess))

# Network definition.
class AlexNet(keras.Model):
    """AlexNet-style CNN for 227x227x3 images that outputs 2 unnormalized class logits.

    BatchNormalization is used in place of AlexNet's local response normalization (LRN).
    """

    def __init__(self):
        super(AlexNet, self).__init__()

        # Feature extractor. Spatial size for a 227x227 input:
        # 227 -conv11/4 valid-> 55 -pool/2 same-> 28 -conv5 same-> 28 -pool/2-> 14
        # -conv3 same x2-> 14 -conv3 valid-> 12 -pool/2-> 6, i.e. 6x6x256 at the end.
        self.conv = Sequential([
            # unit1
            layers.Conv2D(96, (11, 11), padding='valid', strides=4, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same'),
            layers.BatchNormalization(),  # BN instead of LRN

            # unit2
            layers.Conv2D(256, (5, 5), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same'),
            layers.BatchNormalization(),

            # unit3
            layers.Conv2D(384, (3, 3), padding='same', strides=1, activation=tf.nn.relu),

            # unit4
            layers.Conv2D(384, (3, 3), padding='same', strides=1, activation=tf.nn.relu),

            # unit5
            layers.Conv2D(256, (3, 3), padding='valid', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same'),
            layers.BatchNormalization(),
        ])

        # Classifier head.
        self.fc = Sequential([
            # fc1
            layers.Dense(4096, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc2
            layers.Dense(4096, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc3
            layers.Dense(4096, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc4: raw logits, no activation. main() compiles the loss with
            # from_logits=True; a ReLU here would clamp every negative logit to 0
            # and cripple training.
            layers.Dense(2)
        ])

    def call(self, inputs, training=None):
        """Return logits of shape (batch, 2) for a batch of images."""
        # Forward `training` explicitly so Dropout/BatchNormalization switch mode.
        out = self.conv(inputs, training=training)
        # Flatten the final feature map; 6*6*256 assumes 227x227 inputs (see __init__).
        out = tf.reshape(out, (-1, 6 * 6 * 256))
        out = self.fc(out, training=training)
        return out


model = AlexNet()

# 檢查網絡輸出的shape
# x = tf.random.normal((1, 227, 227, 3))
# out = model(x)
# print(out.shape)

# 輸出網絡模型的結構
# model.build(input_shape=(None, 227, 227, 3))
# model.summary()

def main():
    """Train the global model on train_db, save its weights, then evaluate on test_db."""
    # `lr` is a deprecated alias that newer Keras versions reject; use learning_rate.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.00001),
                  # from_logits=True: the loss expects unnormalized model outputs.
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['acc'])
    # Run validation on test_db after every 2nd epoch.
    model.fit(train_db, epochs=15, validation_data=test_db, validation_freq=2)
    model.save_weights('./checkpoint/weights.ckpt')
    print('save weights')

    model.evaluate(test_db)


if __name__ == '__main__':
    # main()  # uncomment to train, save the weights and evaluate
    model.load_weights('./checkpoint/weights.ckpt')

    # Manual accuracy check over the whole test set. Predictions and labels are
    # compared element-wise per batch: `if y == label` on a batch tensor is
    # ambiguous, and counting loop iterations would count batches, not images.
    # correct = 0
    # total = 0
    # for images, labels in test_db:
    #     predictions = tf.argmax(model(images, training=False), axis=1)
    #     targets = tf.argmax(labels, axis=1)
    #     correct += int(tf.reduce_sum(tf.cast(predictions == targets, tf.int32)))
    #     total += int(tf.shape(targets)[0])
    # print('准確率:', correct / total)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM