相似圖像搜索從訓練到服務全過程


最近完成了一個以圖搜圖的項目,項目總共用時三個多月。記錄一下項目中用到機器學習的地方,以及各種踩過的坑。總的來說,項目分為一下幾個部分:

 一、訓練目標函數 

1、    設定基礎模型

2、    添加新層

3、    凍結 base 層

4、    編譯模型

5、    訓練

6、    保存模型

二、特征提取

三、創建索引

四、構建服務

1、flask 開發 

2、Gunicorn 異步,增加服務穩健性

3、Supervisor 部署監控服務

五、總結  

 

 

一、訓練目標函數

項目是在預訓練模型 vgg16 的基礎上進行微調(fine_tune),並將特征的維度從原先的 2048 維降為 1024 維度。

模型的微調又分為以下幾個步驟:

1、設定基礎模型

本次采用預訓練的 VGG16基礎模型,利用其 bottleneck 特征

 # 設定基礎模型

base_model = VGG16(weights='./model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', include_top=False)

 #指定權重路徑

# include_top= False 不加載三層全連接層

 

2、添加新層

將自己要目標圖片,簡單分類,統計類別(在訓練模型時需要指定類別)

# 添加新層

 

def add_new_last_layer(base_model, nb_classes):

    '''
    添加最后的層
    :param base_model: 預訓練模型
    :param nb_classes: 分類數量
    :return: 新的 model
    '''
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x) #輸出的特征維度 88
    predictions = Dense(nb_classes, activation='softmax')(x)
    model = Model(input=base_model.input, output=predictions)
    return model

 

 

3、凍結 base 層

以前的參數可以使用預訓練好的參數,不需要重新訓練,所以需要凍結,不讓其改變。

 

def freeze_base_layer(model, base_model):

        for layer in base_model.layers:

        layer.trainable = False

 

 

 4、編譯模型

model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics= ['accuracy'])

# optimizer: 優化器

# loss: 損失函數,多類的對數損失需要將分類標簽轉換為(將標簽轉化為形如(nb_samples, nb_classes)的二值序列)

# metrics: 列表,包含評估模型在訓練和測試時的網絡性能的指標准備訓練數據。

 

5、訓練

#數據准備
IM_WIDTH, IM_HEIGHT = 224,224
train_dir = './refine_img_data/train'
val_dir = './refine_img_data/test'
nb_classes = 5
np_epoch = 3
batch_size = 16
nb_train_samples = get_nb_files(train_dir)
nb_classes = len(glob.glob(train_dir + '/*'))
nb_val_samples = get_nb_files(val_dir)

# 根據現有數據,設置新數據生成參數
train_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)

test_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)

# 從文件夾獲取數據
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size,
class_mode='categorical'
)

validation_generator = test_datagen.flow_from_directory(
val_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size,
class_mode='categorical'
)

# 訓練
history_t1 = model.fit_generator(
train_generator,
epochs=1,
steps_per_epoch=10,
validation_data=validation_generator,
validation_steps=10,
class_weight='auto'
)

6、保存模型

將模型保存到指定路徑一般保存為 .h5 格式

 model.save('/model/test_model.h5')

 

     

二、特征提取

加載我們訓練好的模型,根據需要,取指定層的特征。

 

# 可用 model.summary() 查看模型結構

#根據模型提取圖片特征

target_size = (224,224)

def my_feature(mod, path):
    img = image.load_img(path,target_size=target_size)
    img = image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    return mod.predict(img)

 

# 創建模型,獲取指定層特征
model_path = './model/my_model.h5'
base_model = load_model(model_path)
model = Model(inputs=base_model.input, outputs=base_model.get_layer('dense_1').output)

 

# 提取特征
img_path = './my_img/bus.jpg'
feat = my_feature(model,img_path) # shape 為 (1,128)
print(feat)
print(feat.shape)

#注意, 當需要提取的圖片特征數量較大,比如千萬以上,需要的時間是比較長的,這時我們可以采用多核與批處理來進行 (python 由於 GIL 的問題對多線程不友好)。
def pre_processs_image(path):
    if path is not None and os.path.exists(path) and len(path) > 10:
      try:
          img = cv2.imread(path, cv2.IMREAD_COLOR)
          img = cv2.resize(img, (224, 224))
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          img = img.transpose(2, 0, 1)
          return [material_id,img, flag]
      except Exception as err:
          traceback.print_exc()
          return None
    else:
    logging.error('could not find path: ' + path)
    return None

 

#cpu 部分,調用多核處理函數,指定核數為 20
with ProcessPoolExecutor(max_workers=20) as executor:
feat_paras = list(executor.map(pre_processs_image,, material_batch))


# GPU 部分采用批處理
# TODO

 

 

三、創建索引

此處我們使用 Facebook 開源的近鄰索引框架 faiss 。

 

# create index
d = 128
nlist = 100 # 切分數量
nprobe = 8 # 每次查找分片數量
quantizer_img = faiss.IndexFlatL2(d) #根據歐式距離創建索引

 
image_index = None
model_index = None

if image_feat_array is not None and len(img_feat_list) > 100:
  image_index = faiss.IndexIVFFlat(quantizer_img, d, nlist, faiss.METRIC_L2)
  image_index.train(image_feat_array)
  image_index.add_with_ids(image_feat_array,image_id_array)
  image_index.nprobe = nprobe
  image_index.dont_dealloc_me = quantizer_img

# 保存當前索引到指定路徑
faiss.write_index(img_index,path)

# 測試當前索引
temp_feat = img_feat_list[1]
res_2 = image_index.search(temp_feat, k=5)
logging.info('image search result is:' + str(res_2))

 

 

四、構建服務

采用Flask 框架, gunicorn為 wsgi 容器。supervisor 管理進程。

1、flask 開發

參考文檔 http://docs.jinkan.org/docs/flask/quickstart.html#a-minimal-application

2、Gunicorn 異步,增加服務穩健性

基礎語法:

Gunicorn –w process_num –b ip:port –k 'gevent' fileName:app

# 注意:此處不選擇 –k 'gevent' 則為同步運行

 

同步部署:

gunicorn -b 0.0.0.0:9090 my_service:app

 

異步部署:

gunicorn -b 0.0.0.0:9090 -k gevent my_service:app

用了 Gunicorn 來部署應用后, 對比 flask , qps 提升了一倍。原 flask 框架中由於我的接口中 request 了其他的接口,線程在此處會阻塞,導致程序非常容易假死。改用后,穩定又了極大的提升。

 

3、Supervisor 部署監控服務

可參考以下文檔 https://www.cnblogs.com/gjack/p/8076419.html

 

五、總結

項目到這個地方,基本的服務框架已經有了。許多地方只說了大體思路,但是結構是完整。文中的許多用了許多方法工具,如 gunicorn 的異步等, 但是原理卻不甚了解,還需要花功夫去學習。由於上線壓力大,時間緊,許多地方來不及仔細琢磨,肯定有不少紕漏,后面再查漏補缺吧。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM