A Large-Scale Image Retrieval System Based on VGG-16 (an Upgraded Search-by-Image)


 How the retrieval system works:

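  Put simply, image retrieval extracts a feature (usually a feature vector) from every image in the image database and stores it in a database. For a query image, the same kind of feature vector is extracted, its distance (or similarity) to each vector in the database is computed, and the closest vectors are selected. The images those vectors belong to are the retrieval results. [1]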

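    [Paper overview] In the figure, the first part shows the convolutional neural network used in the ImageNet competition. The middle part shows the modified network, with a hidden layer (assumed here to have 128 neurons) inserted between layer 7 and the output layer. We reuse the weights of the first 7 layers of the final ImageNet model and fine-tune, which yields the weights between layer 7, layer 8, and the output layer. The bottom part shows the actual retrieval process. Every image is passed forward through the network to obtain the 4096-dimensional feature vector of layer 7 and the 128-dimensional output of layer 8 (thresholding at 0.5 turns it into a 0/1 binary retrieval vector). For a query image, the same 4096-dimensional feature vector and 128-dimensional binary retrieval vector are computed, the images in the database "bucket" matching the binary vector are looked up, the distances between the 4096-dimensional feature vectors are compared, and re-ranking gives the final result. The example in the figure is fairly intuitive: for the query "eagle" image, the binary retrieval vector comes out as 101010; the images in that bucket are fetched (they are, for the most part, also eagles), the distances between the 4096-dimensional feature vectors are compared, and re-sorting yields the final retrieval result.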
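
The bucket lookup and re-ranking described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the paper's implementation: the helper names (`binarize`, `build_buckets`, `search`), the toy 3-bit codes, the 4-dimensional "features", and the choice of `>=` at the 0.5 threshold are all hypothetical, and plain Python is used so the sketch runs without a model.

```python
import math
from collections import defaultdict

def binarize(activations, threshold=0.5):
    """Turn hidden-layer activations into a tuple of 0/1 bits (assumed: >= threshold -> 1)."""
    return tuple(1 if a >= threshold else 0 for a in activations)

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def build_buckets(database):
    """Group (name, feature) pairs by binary code.

    database: iterable of (name, hidden_activations, feature_vector).
    """
    buckets = defaultdict(list)
    for name, hidden, feature in database:
        buckets[binarize(hidden)].append((name, feature))
    return buckets

def search(buckets, query_hidden, query_feature, top_k=3):
    """Fetch the bucket matching the query's code, then re-rank by feature distance."""
    candidates = buckets.get(binarize(query_hidden), [])
    ranked = sorted(candidates, key=lambda item: euclidean(query_feature, item[1]))
    return [name for name, _ in ranked[:top_k]]

# Toy data standing in for the 128-d hidden output and the 4096-d feature vector.
database = [
    ("eagle_1.jpg", [0.9, 0.1, 0.8], [1.0, 0.0, 0.0, 0.0]),
    ("eagle_2.jpg", [0.7, 0.2, 0.6], [0.8, 0.2, 0.0, 0.0]),
    ("piano_1.jpg", [0.1, 0.9, 0.2], [0.0, 0.0, 1.0, 0.0]),
]
buckets = build_buckets(database)
results = search(buckets, [0.8, 0.3, 0.9], [0.95, 0.05, 0.0, 0.0])
print(results)
```

Only images whose code equals the query's code are compared in full, which is what keeps the expensive distance computation small on a large database. Note that the code listed in the rest of this article does not use buckets; it compares the query against every stored feature vector.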

See the paper for details of the theory. The code implementation follows:

 

Development environment:

  # windows 10
  # tensorflow-gpu 1.8 + keras
  # python 3.6 

Example usage:

# Extract features from the images in the database folder and build the index file featureCNN.h5
python index.py -database database -index featureCNN.h5

# Use 001_accordion_image_0001.jpg from the database folder as the test image, search database for similar images using featureCNN.h5, and show the 3 most similar images
python query_online.py -query database/001_accordion_image_0001.jpg -index featureCNN.h5 -result database

 

1. Feature extraction: extract_cnn_vgg16_keras.py

# -*- coding: utf-8 -*-

import numpy as np
from numpy import linalg as LA

from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input

class VGGNet:
    def __init__(self):
        # weights: 'imagenet'
        # pooling: 'max' or 'avg'
        # input_shape: (width, height, 3), width and height should >= 48
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model = VGG16(weights = self.weight, input_shape = (self.input_shape[0], self.input_shape[1], self.input_shape[2]), pooling = self.pooling, include_top = False)
        self.model.predict(np.zeros((1, 224, 224 , 3)))

    def extract_feat(self, img_path):
        '''
        Use vgg16 model to extract features
        Output normalized feature vector
        '''
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model.predict(img)
        norm_feat = feat[0]/LA.norm(feat[0])
        return norm_feat
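
Because extract_feat returns an L2-normalized vector, the plain dot product used later in query_online.py equals the cosine similarity of the raw features. The following is a minimal sketch of that relationship, assuming small hand-made vectors in place of real VGG-16 features; the helper names are hypothetical and plain Python is used so it runs without Keras.

```python
import math

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(vec):
    """Scale a vector to unit Euclidean length, as extract_feat does."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]

def cosine(u, v):
    """Cosine similarity computed directly from the raw vectors."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]

score_normalized = dot(l2_normalize(a), l2_normalize(b))  # what the query script computes
score_cosine = cosine(a, b)                               # explicit cosine similarity

print(score_normalized, score_cosine)
```

Normalizing once at indexing time means each query needs only dot products, with no per-comparison division.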

2. Building the index: index.py

# -*- coding: utf-8 -*-
import os
import h5py
import numpy as np
import argparse

from extract_cnn_vgg16_keras import VGGNet


ap = argparse.ArgumentParser()
ap.add_argument("-database", required = True,
    help = "Path to database which contains images to be indexed")
ap.add_argument("-index", required = True,
    help = "Name of index file")
args = vars(ap.parse_args())

def get_imlist(path):
    '''
    Returns a list of filenames for all jpg images in a directory.
    '''
    return [os.path.join(path,f) for f in os.listdir(path) if f.endswith('.jpg')]

# Extract features and index the images
if __name__ == "__main__":

    db = args["database"]
    img_list = get_imlist(db)
    
    print ("--------------------------------------------------")
    print ("         feature extraction starts")
    print ("--------------------------------------------------")
    
    feats = []
    names = []

    model = VGGNet()
    for i, img_path in enumerate(img_list):
        norm_feat = model.extract_feat(img_path)
        img_name = os.path.split(img_path)[1]
        feats.append(norm_feat)
        names.append(img_name.encode())
        print ("extracting feature from image No. %d , %d images in total" %((i+1), len(img_list)))

    feats = np.array(feats)
    # path of the index file that stores the extracted features
    output = args["index"]
    
    print ("--------------------------------------------------")
    print ("      writing feature extraction results ...")
    print ("--------------------------------------------------")
    
    h5f = h5py.File(output, 'w')
    h5f.create_dataset('dataset_1', data = feats)
    h5f.create_dataset('dataset_2', data = names)
    h5f.close()

3. Online search: query_online.py

# -*- coding: utf-8 -*-
from extract_cnn_vgg16_keras import VGGNet

import numpy as np
import h5py

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import argparse


ap = argparse.ArgumentParser()
ap.add_argument("-query", required = True,
    help = "Path to query which contains image to be queried")
ap.add_argument("-index", required = True,
    help = "Path to index")
ap.add_argument("-result", required = True,
    help = "Path for output retrieved images")
args = vars(ap.parse_args())


# read in indexed images' feature vectors and corresponding image names
h5f = h5py.File(args["index"],'r')
feats = h5f['dataset_1'][:]
imgNames = h5f['dataset_2'][:]
h5f.close()
        
print ("--------------------------------------------------")
print ("               searching starts")
print ("--------------------------------------------------")
    
# read and show query image
queryDir = args["query"]
queryImg = mpimg.imread(queryDir)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()

# init VGGNet16 model
model = VGGNet()

# extract query image's feature, compute similarity score and sort
# (all features are L2-normalized, so the dot product is the cosine similarity)
queryVec = model.extract_feat(queryDir)
scores = np.dot(queryVec, feats.T)
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
#print rank_ID
#print rank_score


# number of top retrieved images to show
maxres = 3
imlist = [imgNames[index] for index in rank_ID[0:maxres]]
print ("top %d images in order are: " %maxres, imlist)
 

# show top #maxres retrieved result one by one
for i,im in enumerate(imlist):
    image = mpimg.imread(args["result"]+"/"+str(im,encoding='utf-8'))
    plt.title("search output %d" %(i+1))
    plt.imshow(image)
    plt.show()
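
The script above sorts every score and then keeps only the first maxres entries. For a large index, selecting just the top few is enough. The sketch below is one possible alternative, not part of the original code: `top_k` is a hypothetical helper, the scores and file names are made up, and the standard-library `heapq` module is used so the example runs on its own.

```python
import heapq

def top_k(scores, names, k=3):
    """Return the k (name, score) pairs with the highest scores, best first."""
    best = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])
    return [(names[i], scores[i]) for i in best]

# Made-up similarity scores for five indexed images.
scores = [0.12, 0.95, 0.40, 0.87, 0.33]
names = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]

top = top_k(scores, names, k=3)
print(top)
```

With NumPy arrays, `np.argpartition` followed by a sort of the selected entries serves the same purpose.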

 

 

References:

Extracting features with VGG16: https://keras-cn.readthedocs.io/en/latest/other/application/

Image retrieval methods: https://github.com/willard-yuan

Recommended papers: https://github.com/willard-yuan/awesome-cbir-papers

Paper: http://www.iis.sinica.edu.tw/~kevinlin311.tw/cvprw15.pdf

[1] : https://blog.csdn.net/han_xiaoyang/article/details/50856583

