檢索系統原理:
圖像檢索的過程簡單來說,就是對圖片數據庫中的每張圖片抽取特徵(一般形式為特徵向量)並存儲於數據庫中;對於待檢索的圖片,抽取同樣的特徵向量,然後計算該向量與數據庫中各向量之間的距離(相似度計算),找出最接近的一些特徵向量,其對應的圖片即為檢索結果。[1]
【論文解析概述】下圖中,上方為ImageNet比賽中使用的卷積神經網絡;中間為調整後的卷積神經網絡,即在第7層和output層之間添加一個隱層(假設為128個神經元)。我們將複用在ImageNet上訓練得到的最終模型的前7層權重做fine-tuning,學習第7層、第8層和output層之間的權重。下方為實際的檢索過程:對所有圖片做卷積神經網絡前向運算,得到第7層的4096維特徵向量和第8層的128維輸出(設定閾值0.5之後可以轉成0/1二值檢索向量);對於待檢索的圖片,同樣得到4096維特徵向量和128維0/1二值檢索向量,在數據庫中查找與該二值檢索向量對應的『桶』內圖片,再比對4096維特徵向量之間的距離做重排,即得到最終結果。圖上的檢索例子比較直觀:對於待檢索的「鷹」圖像,算得二值檢索向量為101010,取出對應桶內的圖片(可以看到基本也都是鷹),比對4096維特徵向量之間的距離,重新排序後得到最後的檢索結果。
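原理部分詳見論文。注意:以下代碼只實現了簡化版本——直接使用ImageNet預訓練的VGG16(去掉全連接層,不做fine-tuning)提取特徵,並對整個數據庫做餘弦相似度排序,並未實現論文中的二值哈希分桶與重排步驟。代碼實現如下: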
開發環境:
# windows 10 # tensorflow-gpu 1.8 + keras # python 3.6
執行示例:
# Extract features for every image in the "database" folder and build the index file featureCNN.h5
python index.py -database database -index featureCNN.h5

# Use database/001_accordion_image_0001.jpg as the query image, search the images indexed
# in featureCNN.h5 (read from the "database" folder) and display the 3 most similar ones
python query_online.py -query database/001_accordion_image_0001.jpg -index featureCNN.h5 -result database
1、抽取特徵:extract_cnn_vgg16_keras.py
# -*- coding: utf-8 -*-
import numpy as np
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input


class VGGNet:
    """Global image descriptor built on an ImageNet-pretrained VGG16.

    The fully-connected top of VGG16 is dropped (``include_top=False``) and the
    last convolutional feature map is globally pooled (``pooling='max'``), so
    every image is mapped to a fixed-length 1-D vector. ``extract_feat``
    L2-normalises that vector, so the dot product of two such vectors is their
    cosine similarity.
    """

    def __init__(self):
        # weights: 'imagenet'
        # pooling: 'max' or 'avg'
        # input_shape: (height, width, 3); height and width should be >= 48
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model = VGG16(weights=self.weight,
                           input_shape=self.input_shape,
                           pooling=self.pooling,
                           include_top=False)
        # Dummy forward pass at construction time. Presumably this warms up /
        # finalises the predict function before real use (a common TF1-era
        # Keras workaround) -- NOTE(review): confirm it is still needed.
        # The shape is derived from input_shape so the two cannot drift apart.
        self.model.predict(np.zeros((1,) + tuple(self.input_shape)))

    def extract_feat(self, img_path):
        """Use the VGG16 model to extract an L2-normalised feature vector.

        Args:
            img_path: path of an image file readable by
                ``keras.preprocessing.image.load_img``.

        Returns:
            1-D numpy array with unit L2 norm. If the network output is all
            zeros, the zero vector is returned unchanged, because normalising
            it would divide by zero and produce NaNs.
        """
        # Resize to the network input size; Keras' target_size is (height, width).
        img = image.load_img(img_path, target_size=tuple(self.input_shape[:2]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)  # add the batch dimension
        img = preprocess_input(img)        # VGG16-specific input normalisation
        feat = self.model.predict(img)[0]  # single image -> take batch row 0
        norm = LA.norm(feat)
        if norm == 0:
            return feat
        return feat / norm
2、存儲索引:index.py
# -*- coding: utf-8 -*-
import os
import h5py
import numpy as np
import argparse
from extract_cnn_vgg16_keras import VGGNet


def get_imlist(path, exts=('.jpg',)):
    """Return the paths of all image files directly inside ``path``.

    Args:
        path: directory to scan (not recursive).
        exts: filename suffix or tuple of suffixes to accept. The default
            ``('.jpg',)`` keeps the original behaviour; matching is
            case-sensitive, as with ``str.endswith``.

    Returns:
        A list of ``os.path.join(path, name)`` strings, sorted by name.
        ``os.listdir`` order is arbitrary, so sorting is what makes the row
        order of the written index reproducible. Directories whose name
        happens to end in an accepted suffix are skipped.
    """
    if isinstance(exts, str):
        exts = (exts,)
    suffixes = tuple(exts)
    return [os.path.join(path, f) for f in sorted(os.listdir(path))
            if f.endswith(suffixes) and os.path.isfile(os.path.join(path, f))]


def _parse_args():
    """Parse the command-line options of the indexing script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-database", required = True, help = "Path to database which contains images to be indexed")
    ap.add_argument("-index", required = True, help = "Name of index file")
    return vars(ap.parse_args())


if __name__ == "__main__":
    # Extract features for every image and write them, with the image names,
    # to an HDF5 index file.
    # Arguments are parsed here rather than at import time, so importing
    # this module (e.g. for get_imlist) does not consume sys.argv.
    args = _parse_args()
    db = args["database"]
    img_list = get_imlist(db)
    if not img_list:
        # An empty index would only make query_online.py fail later.
        raise SystemExit("No .jpg images found in %r; nothing to index." % db)

    print ("--------------------------------------------------")
    print (" feature extraction starts")
    print ("--------------------------------------------------")

    feats = []
    names = []
    model = VGGNet()
    for i, img_path in enumerate(img_list):
        norm_feat = model.extract_feat(img_path)
        img_name = os.path.basename(img_path)
        feats.append(norm_feat)
        # Stored as UTF-8 bytes; query_online.py decodes them back.
        names.append(img_name.encode())
        print ("extracting feature from image No. %d , %d images in total" %((i+1), len(img_list)))

    feats = np.array(feats)
    # directory for storing extracted features
    output = args["index"]

    print ("--------------------------------------------------")
    print (" writing feature extraction results ...")
    print ("--------------------------------------------------")

    # dataset_1: one feature vector per row; dataset_2: matching file names.
    # The context manager closes the file even if a write fails.
    with h5py.File(output, 'w') as h5f:
        h5f.create_dataset('dataset_1', data = feats)
        h5f.create_dataset('dataset_2', data = names)
3、在線搜索部分query_online.py:
# -*- coding: utf-8 -*-
import argparse
import os

import h5py
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

from extract_cnn_vgg16_keras import VGGNet


def _parse_args():
    """Parse the command-line options of the query script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-query", required = True, help = "Path to query which contains image to be queried")
    ap.add_argument("-index", required = True, help = "Path to index")
    ap.add_argument("-result", required = True, help = "Directory containing the indexed images (retrieved images are read from here for display)")
    ap.add_argument("-maxres", type = int, default = 3, help = "Number of top retrieved images to show (default: 3)")
    args = vars(ap.parse_args())
    if args["maxres"] < 1:
        ap.error("-maxres must be at least 1")
    return args


def _decode_name(name):
    """Convert an image name read from the HDF5 index into ``str``.

    index.py stores names as UTF-8 bytes. ``numpy.bytes_`` is a ``bytes``
    subclass and is decoded; anything else goes through ``str()``.
    """
    return name.decode('utf-8') if isinstance(name, bytes) else str(name)


def _show(img, title):
    """Display one image in a blocking matplotlib window."""
    plt.title(title)
    plt.imshow(img)
    plt.show()


if __name__ == "__main__":
    # Rank every indexed image by similarity to the query image and display
    # the top matches.
    args = _parse_args()

    # Read the indexed images' feature vectors and the corresponding image
    # names; the context manager closes the file even if a read fails.
    with h5py.File(args["index"], 'r') as h5f:
        feats = h5f['dataset_1'][:]
        img_names = h5f['dataset_2'][:]
    if len(feats) == 0:
        raise SystemExit("Index %r contains no images." % args["index"])

    print ("--------------------------------------------------")
    print (" searching starts")
    print ("--------------------------------------------------")

    # Read and show the query image.
    query_path = args["query"]
    _show(mpimg.imread(query_path), "Query Image")

    # Init the VGG16 model and extract the query image's feature vector.
    model = VGGNet()
    query_vec = model.extract_feat(query_path)

    # index.py stores vectors produced by the same VGGNet.extract_feat, which
    # L2-normalises them, so the dot product is the cosine similarity.
    # Sort in descending order (most similar first).
    scores = np.dot(query_vec, feats.T)
    rank_ids = np.argsort(scores)[::-1]

    # Slicing caps the result at the number of indexed images.
    top_ids = rank_ids[:args["maxres"]]
    imlist = [_decode_name(img_names[idx]) for idx in top_ids]
    print ("top %d images in order are: " % len(imlist), imlist)
    print ("similarity scores: ", scores[top_ids])

    # Show the retrieved images one by one.
    for rank, name in enumerate(imlist, start=1):
        _show(mpimg.imread(os.path.join(args["result"], name)),
              "search output %d" % rank)
參考及引用:
利用VGG16提取特徵:https://keras-cn.readthedocs.io/en/latest/other/application/
圖片檢索方法:https://github.com/willard-yuan
論文推薦:https://github.com/willard-yuan/awesome-cbir-papers
論文:http://www.iis.sinica.edu.tw/~kevinlin311.tw/cvprw15.pdf
[1] : https://blog.csdn.net/han_xiaoyang/article/details/50856583