最近完成了一個以圖搜圖的項目,項目總共用時三個多月。記錄一下項目中用到機器學習的地方,以及各種踩過的坑。總的來說,項目分為一下幾個部分:
一、訓練目標函數
項目是在預訓練模型 vgg16 的基礎上進行微調(fine_tune),並將特征的維度從原先的 2048 維降為 1024 維度。
模型的微調又分為以下幾個步驟:
1、設定基礎模型
本次采用預訓練的 VGG16基礎模型,利用其 bottleneck 特征
# 設定基礎模型
base_model = VGG16(weights='./model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', include_top=False)
#指定權重路徑
# include_top= False 不加載三層全連接層
2、添加新層
將自己要目標圖片,簡單分類,統計類別(在訓練模型時需要指定類別)
# 添加新層
def add_new_last_layer(base_model, nb_classes): ''' 添加最后的層 :param base_model: 預訓練模型 :param nb_classes: 分類數量 :return: 新的 model ''' x = base_model.output x = GlobalAveragePooling2D()(x) x = Dense(128, activation='relu')(x) #輸出的特征維度 88 predictions = Dense(nb_classes, activation='softmax')(x) model = Model(input=base_model.input, output=predictions) return model
3、凍結 base 層
以前的參數可以使用預訓練好的參數,不需要重新訓練,所以需要凍結,不讓其改變。
def freeze_base_layer(model, base_model): for layer in base_model.layers: layer.trainable = False
4、編譯模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics= ['accuracy']) # optimizer: 優化器 # loss: 損失函數,多類的對數損失需要將分類標簽轉換為(將標簽轉化為形如(nb_samples, nb_classes)的二值序列) # metrics: 列表,包含評估模型在訓練和測試時的網絡性能的指標准備訓練數據。
5、訓練
#數據准備 IM_WIDTH, IM_HEIGHT = 224,224 train_dir = './refine_img_data/train' val_dir = './refine_img_data/test' nb_classes = 5 np_epoch = 3 batch_size = 16 nb_train_samples = get_nb_files(train_dir) nb_classes = len(glob.glob(train_dir + '/*')) nb_val_samples = get_nb_files(val_dir) # 根據現有數據,設置新數據生成參數 train_datagen = ImageDataGenerator( preprocessing_function=preprocess_input, rotation_range=30, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True ) test_datagen = ImageDataGenerator( preprocessing_function=preprocess_input, rotation_range=30, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True ) # 從文件夾獲取數據 train_generator = train_datagen.flow_from_directory( train_dir, target_size=(IM_WIDTH, IM_HEIGHT), batch_size=batch_size, class_mode='categorical' ) validation_generator = test_datagen.flow_from_directory( val_dir, target_size=(IM_WIDTH, IM_HEIGHT), batch_size=batch_size, class_mode='categorical' ) # 訓練 history_t1 = model.fit_generator( train_generator, epochs=1, steps_per_epoch=10, validation_data=validation_generator, validation_steps=10, class_weight='auto' )
6、保存模型
將模型保存到指定路徑一般保存為 .h5 格式
model.save('/model/test_model.h5')
二、特征提取
加載我們訓練好的模型,根據需要,取指定層的特征。
# 可用 model.summary() 查看模型結構 #根據模型提取圖片特征 target_size = (224,224) def my_feature(mod, path): img = image.load_img(path,target_size=target_size) img = image.img_to_array(img) img = np.expand_dims(img, axis=0) img = preprocess_input(img) return mod.predict(img) # 創建模型,獲取指定層特征 model_path = './model/my_model.h5' base_model = load_model(model_path) model = Model(inputs=base_model.input, outputs=base_model.get_layer('dense_1').output) # 提取特征 img_path = './my_img/bus.jpg' feat = my_feature(model,img_path) # shape 為 (1,128) print(feat) print(feat.shape) #注意, 當需要提取的圖片特征數量較大,比如千萬以上,需要的時間是比較長的,這時我們可以采用多核與批處理來進行 (python 由於 GIL 的問題對多線程不友好)。 def pre_processs_image(path): if path is not None and os.path.exists(path) and len(path) > 10: try: img = cv2.imread(path, cv2.IMREAD_COLOR) img = cv2.resize(img, (224, 224)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = img.transpose(2, 0, 1) return [material_id,img, flag] except Exception as err: traceback.print_exc() return None else: logging.error('could not find path: ' + path) return None #cpu 部分,調用多核處理函數,指定核數為 20 with ProcessPoolExecutor(max_workers=20) as executor: feat_paras = list(executor.map(pre_processs_image,, material_batch)) # GPU 部分采用批處理 # TODO
三、創建索引
此處我們使用 Facebook 開源的近鄰索引框架 faiss 。
# create index d = 128 nlist = 100 # 切分數量 nprobe = 8 # 每次查找分片數量 quantizer_img = faiss.IndexFlatL2(d) #根據歐式距離創建索引 image_index = None model_index = None if image_feat_array is not None and len(img_feat_list) > 100: image_index = faiss.IndexIVFFlat(quantizer_img, d, nlist, faiss.METRIC_L2) image_index.train(image_feat_array) image_index.add_with_ids(image_feat_array,image_id_array) image_index.nprobe = nprobe image_index.dont_dealloc_me = quantizer_img # 保存當前索引到指定路徑 faiss.write_index(img_index,path) # 測試當前索引 temp_feat = img_feat_list[1] res_2 = image_index.search(temp_feat, k=5) logging.info('image search result is:' + str(res_2))
四、構建服務
采用Flask 框架, gunicorn為 wsgi 容器。supervisor 管理進程。
1、flask 開發
參考文檔 http://docs.jinkan.org/docs/flask/quickstart.html#a-minimal-application
2、Gunicorn 異步,增加服務穩健性
基礎語法:
Gunicorn –w process_num –b ip:port –k 'gevent' fileName:app
# 注意:此處不選擇 –k 'gevent' 則為同步運行
同步部署:
gunicorn -b 0.0.0.0:9090 my_service:app
異步部署:
gunicorn -b 0.0.0.0:9090 -k gevent my_service:app
用了 Gunicorn 來部署應用后, 對比 flask , qps 提升了一倍。原 flask 框架中由於我的接口中 request 了其他的接口,線程在此處會阻塞,導致程序非常容易假死。改用后,穩定又了極大的提升。
3、Supervisor 部署監控服務
可參考以下文檔 https://www.cnblogs.com/gjack/p/8076419.html
五、總結
項目到這個地方,基本的服務框架已經有了。許多地方只說了大體思路,但是結構是完整。文中的許多用了許多方法工具,如 gunicorn 的異步等, 但是原理卻不甚了解,還需要花功夫去學習。由於上線壓力大,時間緊,許多地方來不及仔細琢磨,肯定有不少紕漏,后面再查漏補缺吧。
