VGGNet特征提取
利用VGGnet的預訓練模型來實現圖像的檢索,先用預訓練模型來抽取圖片的特征,然后把待檢索的圖像和數據庫中的所有圖像進行匹配,找出相似度最高的
在jupyter notebook上實現
文件路徑設置:
root|____ code
|____ images|____ img_class_1
|____ img_class_2
|____ img_class_3
|.... .....
|____ img_class_n
|____models
|____queryimg
- root: 根目錄
- images: 存放各類別的圖片文件夾
- img_class_i: 存放相應類別的圖片
- database: 用於存放數據
- queryimg: 存放待檢索圖片
Step 1. 構造特征提取器
這里用了Keras的應用模塊(Keras.applications)提供的帶有預訓練權值的模型
初始化一個模型的時候,會自動下載權重到~/.keras/models/目錄下
詳細參考🔗
這里用VGG16預訓練模型構造一個特征提取器
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
from numpy import linalg as LA
class VGGNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_vgg = VGG16(weights=self.weight,
input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
pooling=self.pooling, include_top=False)
# 提取vgg16最后一層卷積特征
def vgg_extract_feat(self, img_path):
img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = self.model_vgg.predict(img)
norm_feat = feat[0] / LA.norm(feat[0])
return norm_feat
keras.applications.vgg16.VGG16()
參數設置:
include_top: 是否包括頂層的全連接層。
weights: None 代表隨機初始化, 'imagenet' 代表加載在 ImageNet 上預訓練的權值。
input_tensor: 可選,Keras tensor 作為模型的輸入(即 layers.Input() 輸出的 tensor)。
input_shape: 可選,輸入尺寸元組,僅當 include_top=False 時有效,否則輸入形狀必須是 (244, 244, 3)(對於 channels_last 數據格式),或者 (3, 244, 244)(對於 channels_first 數據格式)。它必須擁有 3 個輸入通道,且寬高必須不小於 32。例如 (200, 200, 3) 是一個合法的輸入尺寸。
pooling: 可選,當 include_top 為 False 時,該參數指定了特征提取時的池化方式。
- None 代表不池化,直接輸出最后一層卷積層的輸出,該輸出是一個四維張量。
- 'avg' 代表全局平均池化(GlobalAveragePooling2D),相當於在最后一層卷積層后面再加一層全局平均池化層,輸出是一個二維張量。
- 'max' 代表全局最大池化
classes: 可選,圖片分類的類別數,僅當 include_top 為 True 並且不加載預訓練權值時可用。
Step 2. 保存圖片數據特征
用VGGnet提取圖片特征
把圖片的特征向量和文件路徑存到文件中
import os
import h5py
import numpy as np
root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')
print("--------------------------------------------------")
print(" feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')
imgpaths = []
for subdir in os.listdir(imgdir)[:]:
curpath = os.path.join(imgdir,subdir)
for imgname in os.listdir(curpath):
imgpaths += [os.path.join(curpath,imgname)] # 添加圖片路徑
feats = [] # 保存圖片特征向量
model = VGGNet()
for i, img_path in enumerate(imgpaths):
norm_feat = model.vgg_extract_feat(img_path)
feats.append(norm_feat)
print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))
feats = np.array(feats)
print("--------------------------------------------------")
print(" writing feature extraction results ...")
print("--------------------------------------------------")
h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data = feats)
h5f.create_dataset('dataset_2', data = np.string_(imgpaths))
h5f.close()
print(" writing has ended. ")
Step 3. 圖片檢索
把待檢索圖片存到queryimg中, 進行檢索,輸出前maxres張匹配度最高的圖片
import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from extract_cnn_vgg16_keras import VGGNet
root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()
querydir = os.path.join(root,'queryimg')
# init VGGNet16 model
model = VGGNet()
# 待檢索圖片名
imgname = 'xxx.jpg'
print("--------------------------------------------------")
print(" searching starts")
print("--------------------------------------------------")
# 待檢索圖片地址
querypath = os.path.join(querydir,imgname)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()
# 提取待檢索圖片的特征
queryVec = model.vgg_extract_feat(querypath)
# 和數據庫中的每張圖片的特征匹配,計算匹配分數
scores = np.dot(queryVec, feats.T)
# 按匹配分數從大到小排序
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
maxres = 3 # 檢索出三張相似度最高的圖片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
imlist.append(imgpaths[index])
print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)
# 輸出檢索到的圖片
for i, im in enumerate(imlist):
impath = str(im)[2:-1] # 得到的im是一個byte型的數據格式,需要轉換成字符串
print(impath)
image = cv2.imread(impath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.title("search output %d" % (i + 1))
plt.imshow(image)
plt.show()
RESNet50進行特征提取
RESNet50的計算量比VGG16低一點,跑得更快,同時內存使用量也更小
import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from numpy import linalg as LA
from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
class RESNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_resnet = ResNet50(weights=self.weight,
input_shape=self.input_shape,
pooling=self.pooling, include_top=False)
# 提取resnet50最后一層卷積特征
def resnet_extract_feat(self, img_path):
img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_resnet(img)
feat = self.model_resnet.predict(img)
norm_feat = feat[0]/LA.norm(feat[0])
return norm_feat
import os
import h5py
import numpy as np
root = os.path.abspath('..')
model = RESNet()
save_path = os.path.join(root,'models','resnet_featureCNN.h5')
print("--------------------------------------------------")
print(" feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')
imgpaths = []
for subdir in os.listdir(imgdir)[:3]:
curpath = os.path.join(imgdir,subdir)
for imgname in os.listdir(curpath):
imgpaths += [os.path.join(curpath,imgname)]
feats = []
model = RESNet()
for i, img_path in enumerate(imgpaths):
norm_feat = model.resnet_extract_feat(img_path)
feats.append(norm_feat)
print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))
feats = np.array(feats)
print("--------------------------------------------------")
print(" writing feature extraction results ...")
print("--------------------------------------------------")
h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data=feats)
h5f.create_dataset('dataset_2', data=np.string_(imgpaths))
h5f.close()
print(" writing has done. ")
import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
root = os.path.abspath('..')
save_path = os.path.join(root,'models','resnet_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()
querydir = os.path.join(root,'queryimg')
model = RESNet()
queryList = ['AK47', "american-flag", 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat'
, 'bathtub', 'bear', 'beer-mug']
imgname = queryList[0] + '.jpg'
print("--------------------------------------------------")
print(" searching starts")
print("--------------------------------------------------")
# read and show query image
querypath = os.path.join(querydir,imgname)
# queryImg = mpimg.imread(querypath)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()
# extract query image's feature, compute simlarity score and sort
queryVec = model.resnet_extract_feat(querypath) # 修改此處改變提取特征的網絡
scores = np.dot(queryVec, feats.T)
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
# number of top retrieved images to show
maxres = 3 # 檢索出三張相似度最高的圖片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
imlist.append(imgpaths[index])
print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)
# show top #maxres retrieved result one by one
for i, im in enumerate(imlist):
impath = str(im)[2:-1]
print(impath)
image = cv2.imread(impath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.title("search output %d" % (i + 1))
plt.imshow(image)
plt.show()
參考:🔗