基於milvus搭建“以圖搜圖”服務(附代碼)


“以圖搜圖”服務需要的關鍵功能和准備工作:

1 圖像向量化功能,可選的模型有很多,本例選用resnet網絡提取圖像特征;

2 milvus建表,用milvus存放圖像特征,通過唯一ID(此處稱:milvus_id)與圖像一一對應,sql建表將milvus_id作為唯一索引,存放圖像的其他信息(如url,來源等);

3 異步添加圖像,同步搜索圖像,添加圖像的量通常會很大,因此采用異步批量的方式將圖像特征加載到milvus,圖像添加服務會將每次的請求信息存到sql,寫個腳本專門用來定時批量加載圖像特征到milvus,由於是異步操作,可能會出現重復加載的情況,此處使用redis進行去重。圖像搜索的請求通常會比圖像添加少很多,因此圖像搜索使采用同步方式返回結果;

(總結:需建立三個表:milvus表1,存放圖像特征;sql表2,存放圖像信息,數據與milvus表1一一對應;sql表3,存放圖像添加請求信息,用於圖像特征異步批量加載到milvus)

“以圖搜圖”服務關鍵功能及代碼(代碼僅做參考)

1 圖像向量化

"""
功能:圖像向量化
"""
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
from numpy import linalg as LA
import time

model = ResNet50(weights='imagenet')
# model.summary()


def img2feature(img_path, input_dim=224):  # 圖像路徑???圖像數據
    img = image.load_img(img_path, target_size=(input_dim, input_dim))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    x = model.predict(x)
    x = x / LA.norm(x)
    return x


def main():
    img_path = '1.jpg'
    t0 = time.time()
    res = img2feature(img_path)
    print(time.time() - t0, res.shape)
    # print(res, type(res), res.shape)


if __name__ == "__main__":
    main()

 2 milvus表的操作

# coding:utf-8
from functools import reduce
import numpy as np
import time
from img2feature import img2feature
from pymilvus import (
    connections, list_collections,
    FieldSchema, CollectionSchema, DataType,
    Collection, utility
)


field_name = 'image_feature'
host = '***.***.***.***'
port = '19530'
dim = 1000
default_fields = [
    FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="create_time", dtype=DataType.INT64)
]


# create_table
def create_table():
    connections.connect(host=host, port=port)
    # create collection

    default_schema = CollectionSchema(fields=default_fields, description="test collection")

    print(f"\nCreate collection...")
    collection = Collection(name=field_name, schema=default_schema)
    print(f"\nCreate index...")
    default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    collection.create_index(field_name="feature", index_params=default_index)
    print(print(f"\nCreate index...is OKOKOKOKOK"))
    collection.load()


# insert data
def insert_data():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)
    vectors = img2feature('1.jpg').tolist()[0]
    print(type(vectors), len(vectors))
    data1 = [
        [123],
        [vectors],
        [int(time.time())]
    ]
    collection.insert(data1)
    print('insert compete')


# search data
def search_data():
    print('search')
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('連接成功')

    # 首次查詢建立索引和load()
    # default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    # print(f"\nCreate index...")
    # collection.create_index(field_name="feature", index_params=default_index)
    # print(print(f"\nCreate index...is OKOKOKOKOK"))
    # collection.load()
    # exit()

    vectors = img2feature('1.jpg').tolist()[0]

    topK = 10
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}

    res = collection.search(
        [vectors],
        "feature",
        search_params,
        topK,
        "create_time > {}".format(0),
        output_fields=["milvus_id"]
    )
    print('>>>', res)
    for hits in res:
        print(len(hits))
        for hit in hits:
            print(hit)
    print('查詢結束')


def show_nums():
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('ok')
    print(collection.num_entities)


# delete data
def delete_table():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)
    print('>>>', utility.has_collection(field_name))
    collection.drop()
    print('>>>', utility.has_collection(field_name))


if __name__ == "__main__":
    t1 = time.time()
    # create_table()
    # insert_data()
    # search_data()
    show_nums()
    # delete_table()
    print('time cost: {}'.format(time.time() - t1))

 3 創建sql表2、表3

 4 圖像添加、搜索服務

from rest_framework.views import APIView as View
from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
from kpmysql.base import Kpmysql
from core import search_image
import kplog
import logging
log = logging.getLogger("console")


class add_image(View):
    def post(self, requests):
        try:
            db = Kpmysql.connect("db168")
            cur = db.cursor()
            image_info = requests.POST.get('image_info')
            image_path = requests.POST.get('image_path')
            sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
            cur.execute(sql, (image_path, image_info))
            db.commit()
            log.info('添加圖像成功:{}-{}'.format(image_path, image_info))
            return SucessAPIResponse(msg="Success")
        except Exception as e:
            log.info('添加圖像失敗:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")


class search_image(View):
    def post(self, requests):
        try:
            image_path = requests.POST.get('image_path')
            res = search_image(image_path)
            log.info('查詢圖像成功:{}-{}'.format(image_path, res))
            return SucessAPIResponse(msg="Success", data={"data": res})
        except Exception as e:
            log.info('查詢圖像成功:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")

 5 圖像異步批量加載

import time, datetime
from kpmysql.base import Kpmysql
from core import insert_data_many
from concurrent.futures import ThreadPoolExecutor
import redis
from conf.setting import REDIS
from core import str2time
import kplog
import logging

log = logging.getLogger("console")
log_addimgs = logging.getLogger("console_addimgs")


def worker(datas):
    try:
        redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'), password=REDIS.get('password'),
                                db=REDIS.get('db'))
        dics = []
        ids = []
        for data in datas:
            if redis_cli.zscore('image_search', str(data[0])):  # 基於redis去重
                continue
            dics.append({'image_path': data[1], 'create_time': data[2]})
            ids.append((data[0]))
            redis_cli.zadd('image_search', {str(data[0]): str2time(data[2])})
        # 數據插入milvus
        insert_data_many(dics)
        # 更新 set t_image_search_image_add_log is_load=1
        sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
        db168 = Kpmysql.connect("db168")
        cur168 = db168.cursor()
        cur168.executemany(sql_update, ids)
        db168.commit()
    except Exception as e:
        print(e)


def main():
    max_workers = 20  # 最大線程數
    pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
    task_list = []
    init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
    create_time_init = '2020-2-22 00:00:00'
    while True:
        now = datetime.datetime.now()
        diff = now - init_time
        if diff.seconds > 3600:
            # 加載 t_image_search_image_add_log where is_load=0 數據
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
            cur168.execute(sql, create_time_init)
            datas = cur168.fetchall()
            create_time_init = datas[-1][2]


            while True:
                for _i, _n in enumerate(task_list):
                    if _n.done():
                        task_list.pop(_i)
                if len(task_list) < int(max_workers * 0.9):
                    break
            task_list.append(pool.submit(worker, datas))
            init_time = now
        time.sleep(600)


if __name__ == "__main__":
    main()

 優化(重點)

經過實際測試和使用的建議:

1. keras在調用GPU時並開啟多線程時不如pytorch方便,pytorch占用顯存更少;

2. 定時從數據庫拿數據,改成kafka生產消費模型,代碼更簡潔,邏輯更簡單;


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM