相似圖像檢索


相似圖像檢測

VGGNet特征提取

利用VGGnet的預訓練模型來實現圖像的檢索,先用預訓練模型來抽取圖片的特征,然后把待檢索的圖像和數據庫中的所有圖像進行匹配,找出相似度最高的

在jupyter notebook上實現

文件路徑設置:
root|____ code
       |____ images|____ img_class_1
                           |____ img_class_2
                           |____ img_class_3
                           |.... .....
                           |____ img_class_n
       |____models
       |____queryimg

  • root: 根目錄
  • images: 存放各類別的圖片文件夾
  • img_class_i: 存放相應類別的圖片
  • database: 用於存放數據
  • queryimg: 存放待檢索圖片

Step 1. 構造特征提取器

這里用了Keras的應用模塊(Keras.applications)提供的帶有預訓練權值的模型

初始化一個模型的時候,會自動下載權重到~/.keras/models/目錄下

詳細參考🔗

這里用VGG16預訓練模型構造一個特征提取器

import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
from numpy import linalg as LA
class VGGNet:
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
                               pooling=self.pooling, include_top=False)
    # 提取vgg16最后一層卷積特征
    def vgg_extract_feat(self, img_path):
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input_vgg(img)
        feat = self.model_vgg.predict(img)
        norm_feat = feat[0] / LA.norm(feat[0])
        return norm_feat
keras.applications.vgg16.VGG16()

參數設置:
include_top: 是否包括頂層的全連接層。
weights: None 代表隨機初始化, 'imagenet' 代表加載在 ImageNet 上預訓練的權值。
input_tensor: 可選,Keras tensor 作為模型的輸入(即 layers.Input() 輸出的 tensor)。
input_shape: 可選,輸入尺寸元組,僅當 include_top=False 時有效,否則輸入形狀必須是 (244, 244, 3)(對於 channels_last 數據格式),或者 (3, 244, 244)(對於 channels_first 數據格式)。它必須擁有 3 個輸入通道,且寬高必須不小於 32。例如 (200, 200, 3) 是一個合法的輸入尺寸。
pooling: 可選,當 include_top 為 False 時,該參數指定了特征提取時的池化方式。

  • None 代表不池化,直接輸出最后一層卷積層的輸出,該輸出是一個四維張量。
  • 'avg' 代表全局平均池化(GlobalAveragePooling2D),相當於在最后一層卷積層后面再加一層全局平均池化層,輸出是一個二維張量。
  • 'max' 代表全局最大池化

classes: 可選,圖片分類的類別數,僅當 include_top 為 True 並且不加載預訓練權值時可用。

Step 2. 保存圖片數據特征

用VGGnet提取圖片特征
把圖片的特征向量和文件路徑存到文件中

import os
import h5py
import numpy as np

root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')

print("--------------------------------------------------")
print("         feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')

imgpaths = []
for subdir in os.listdir(imgdir)[:]:
    curpath = os.path.join(imgdir,subdir)
    for imgname in os.listdir(curpath):
        imgpaths += [os.path.join(curpath,imgname)]             # 添加圖片路徑

feats = []          # 保存圖片特征向量

model = VGGNet()
for i, img_path in enumerate(imgpaths):
    norm_feat = model.vgg_extract_feat(img_path)
    feats.append(norm_feat)
    print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))

feats = np.array(feats)
print("--------------------------------------------------")
print("      writing feature extraction results ...")
print("--------------------------------------------------")

h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data = feats)
h5f.create_dataset('dataset_2', data = np.string_(imgpaths))
h5f.close()
print("             writing has ended.            ")

Step 3. 圖片檢索

把待檢索圖片存到queryimg中, 進行檢索,輸出前maxres張匹配度最高的圖片


import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

from extract_cnn_vgg16_keras import VGGNet

root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()

querydir = os.path.join(root,'queryimg')

# init VGGNet16 model
model = VGGNet()


# 待檢索圖片名
imgname = 'xxx.jpg'

print("--------------------------------------------------")
print("               searching starts")
print("--------------------------------------------------")

# 待檢索圖片地址
querypath = os.path.join(querydir,imgname)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()


# 提取待檢索圖片的特征
queryVec = model.vgg_extract_feat(querypath)

# 和數據庫中的每張圖片的特征匹配,計算匹配分數
scores = np.dot(queryVec, feats.T)
# 按匹配分數從大到小排序
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]

maxres = 3  # 檢索出三張相似度最高的圖片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
    imlist.append(imgpaths[index])
    print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)

# 輸出檢索到的圖片
for i, im in enumerate(imlist):
    impath = str(im)[2:-1]        # 得到的im是一個byte型的數據格式,需要轉換成字符串
    print(impath)
    image = cv2.imread(impath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.title("search output %d" % (i + 1))
    plt.imshow(image)
    plt.show()

RESNet50進行特征提取

RESNet50的計算量比VGG16低一點,跑得更快,同時內存使用量也更小

import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from numpy import linalg as LA
from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
class RESNet:
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model_resnet = ResNet50(weights=self.weight,
                               input_shape=self.input_shape,
                               pooling=self.pooling, include_top=False)
        
    # 提取resnet50最后一層卷積特征
    def resnet_extract_feat(self, img_path):
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input_resnet(img)
        feat = self.model_resnet.predict(img)
        norm_feat = feat[0]/LA.norm(feat[0])
        return norm_feat


import os
import h5py
import numpy as np
root = os.path.abspath('..')
model = RESNet()

save_path = os.path.join(root,'models','resnet_featureCNN.h5')

print("--------------------------------------------------")
print("         feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')
imgpaths = []
for subdir in os.listdir(imgdir)[:3]:
    curpath = os.path.join(imgdir,subdir)
    for imgname in os.listdir(curpath):
        imgpaths += [os.path.join(curpath,imgname)]

feats = []

model = RESNet()
for i, img_path in enumerate(imgpaths):
    norm_feat = model.resnet_extract_feat(img_path)
    feats.append(norm_feat)
    print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))

feats = np.array(feats)
print("--------------------------------------------------")
print("      writing feature extraction results ...")
print("--------------------------------------------------")

h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data=feats)
h5f.create_dataset('dataset_2', data=np.string_(imgpaths))
h5f.close()

print("             writing has done.            ")



import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

root = os.path.abspath('..')
save_path = os.path.join(root,'models','resnet_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()

querydir = os.path.join(root,'queryimg')

model = RESNet()


queryList = ['AK47', "american-flag", 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat'
            , 'bathtub', 'bear', 'beer-mug']
imgname = queryList[0] + '.jpg'


print("--------------------------------------------------")
print("               searching starts")
print("--------------------------------------------------")

# read and show query image
querypath = os.path.join(querydir,imgname)
# queryImg = mpimg.imread(querypath)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()


# extract query image's feature, compute simlarity score and sort
queryVec = model.resnet_extract_feat(querypath)  # 修改此處改變提取特征的網絡
scores = np.dot(queryVec, feats.T)
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]

# number of top retrieved images to show
maxres = 3  # 檢索出三張相似度最高的圖片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
    imlist.append(imgpaths[index])
    print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)
# show top #maxres retrieved result one by one
for i, im in enumerate(imlist):
    impath = str(im)[2:-1]
    print(impath)
    image = cv2.imread(impath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.title("search output %d" % (i + 1))
    plt.imshow(image)
    plt.show()

參考:🔗


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM