TensorRT Sample (Python): fc_plugin_caffe_mnist



This article is based on TensorRT 5.0.2 and walks through its fc_plugin_caffe_mnist sample.
Unlike the earlier samples, this one also contains C++ code and has quite a few build dependencies. It demonstrates how a plugin written in C++ is used together with the TensorRT Python bindings and the Caffe parser. The sample implements a fully connected layer with cuBLAS and cuDNN, wraps it as a TensorRT plugin, generates Python bindings for it with pybind11, and then registers those bindings with the Caffe parser.

1 Introduction

Assume the current path is:

TensorRT-5.0.2.6/samples

The directory tree for this sample is:

# tree python

python
├── common.py
├── fc_plugin_caffe_mnist
│   ├── CMakeLists.txt
│   ├── __init__.py
│   ├── plugin
│   │   ├── FullyConnected.h
│   │   └── pyFullyConnected.cpp
│   ├── README.md
│   ├── requirements.txt
│   └── sample.py

Where:

  • plugin contains the plugin for the FullyConnected layer:
  • FullyConnected.h implements the plugin on top of CUDA, cuDNN, and cuBLAS;
  • pyFullyConnected.cpp generates the Python bindings for the FCPlugin and FCPluginFactory classes;
  • sample.py runs the MNIST network using the provided FullyConnected layer plugin.

2 Install the dependencies

  • Clone pybind11 (v2.2.3):
git clone -b v2.2.3 https://github.com/pybind/pybind11.git
  • Install the Python packages listed in requirements.txt:
Pillow
pycuda
numpy
argparse

3 Build the plugin

  • Create a build folder and enter it:
mkdir build && pushd build
  • Run cmake to generate the Makefile. A number of variables can be set here; if some dependencies are not in their default locations, pass their paths to cmake explicitly (see the CMake documentation for details). For example:
cmake .. -DCUDA_ROOT=/usr/local/cuda-9.0 \
         -DPYBIND11_DIR=/root/pybind11/ \
         -DPYTHON3_INC_DIR=/usr/local/python3/include/python3.5m/ \
         -DNVINFER_LIB=/TensorRT-5.0.2.6/lib/libnvinfer.so  \
         -D_NVINFER_PLUGIN_LIB=/TensorRT-5.0.2.6/lib/libnvinfer_plugin.so \
         -D_NVPARSERS_LIB=/TensorRT-5.0.2.6/lib/libnvparsers.so \
         -DTRT_INC_DIR=/TensorRT-5.0.2.6/include/

Watch for any VARIABLE_NAME-NOTFOUND entries in the cmake output: they mean a dependency was not found and its path must be supplied explicitly, as above.

  • Compile:
make -j32
  • Leave the build folder:
popd
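
If the build succeeded, the bindings module can be imported from Python. A minimal sanity check (just a sketch: it assumes it is run from the sample directory and that the module name was left at the default fcplugin):

import sys
sys.path.insert(0, "build")    # the shared library produced by make lives in build/
import fcplugin                # module name set via PY_MODULE_NAME in CMakeLists.txt
print(fcplugin.FCPlugin, fcplugin.FCPluginFactory)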

4 Code walkthrough

First, as described in the build steps above, cmake is invoked from inside the build folder and reads the CMakeLists.txt one directory up.
For find_library, include_directories, and add_subdirectory, see the cmake-commands documentation.

cmake_minimum_required(VERSION 3.2 FATAL_ERROR) # minimum required cmake version
project(FCPlugin LANGUAGES CXX C) # project name and its languages

# Define a macro set_ifndef(var val): if var is not already set (e.g. a dependency was not found), set it to val
macro(set_ifndef var val)
    if(NOT ${var})
        set(${var} ${val})
    endif()
    message(STATUS "Configurable variable ${var} set to ${${var}}")
endmacro()

# -------- CONFIGURATION --------
# Set module name here. MUST MATCH the module name specified in the .cpp
set_ifndef(PY_MODULE_NAME fcplugin) 
set(CMAKE_CXX_STANDARD 11) # use the C++11 standard
set(PYBIND11_CPP_STANDARD -std=c++11) # pybind11 defaults to c++14.

set_ifndef(PYBIND11_DIR $ENV{HOME}/pybind11/)
set_ifndef(CUDA_VERSION 10.0)
set_ifndef(CUDA_ROOT /usr/local/cuda-${CUDA_VERSION})
set_ifndef(CUDNN_ROOT ${CUDA_ROOT})
set_ifndef(PYTHON_ROOT /usr/include)
set_ifndef(TRT_LIB_DIR /usr/lib/x86_64-linux-gnu)
set_ifndef(TRT_INC_DIR /usr/include/x86_64-linux-gnu)

# Find dependencies
message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n")

find_path(_CUDA_INC_DIR cuda_runtime_api.h HINTS ${CUDA_ROOT} PATH_SUFFIXES include)
set_ifndef(CUDA_INC_DIR ${_CUDA_INC_DIR})

find_library(_CUDA_LIB cudart HINTS ${CUDA_ROOT} PATH_SUFFIXES lib lib64)
set_ifndef(CUDA_LIB ${_CUDA_LIB})

find_library(_CUBLAS_LIB cublas HINTS ${CUDA_ROOT} PATH_SUFFIXES lib lib64)
set_ifndef(CUBLAS_LIB ${_CUBLAS_LIB})

find_path(_CUDNN_INC_DIR cudnn.h HINTS ${CUDNN_ROOT} PATH_SUFFIXES include x86_64-linux-gnu)
set_ifndef(CUDNN_INC_DIR ${_CUDNN_INC_DIR})

find_library(_CUDNN_LIB cudnn HINTS ${CUDNN_ROOT} PATH_SUFFIXES lib lib64 x86_64-linux-gnu)
set_ifndef(CUDNN_LIB ${_CUDNN_LIB})

find_path(_TRT_INC_DIR NvInfer.h HINTS ${TRT_INC_DIR} PATH_SUFFIXES include x86_64-linux-gnu)
set_ifndef(TRT_INC_DIR ${_TRT_INC_DIR})

find_library(_NVINFER_LIB nvinfer HINTS ${TRT_LIB_DIR} PATH_SUFFIXES lib lib64 x86_64-linux-gnu)
set_ifndef(NVINFER_LIB ${_NVINFER_LIB})

find_library(_NVPARSERS_LIB nvparsers HINTS ${TRT_LIB_DIR} PATH_SUFFIXES lib lib64 x86_64-linux-gnu)
set_ifndef(NVPARSERS_LIB ${_NVPARSERS_LIB})

find_library(_NVINFER_PLUGIN_LIB nvinfer_plugin HINTS ${TRT_LIB_DIR} PATH_SUFFIXES lib lib64 x86_64-linux-gnu)
set_ifndef(NVINFER_PLUGIN_LIB ${_NVINFER_PLUGIN_LIB})

find_path(_PYTHON2_INC_DIR Python.h HINTS ${PYTHON_ROOT} PATH_SUFFIXES python2.7)
set_ifndef(PYTHON2_INC_DIR ${_PYTHON2_INC_DIR})

find_path(_PYTHON3_INC_DIR Python.h HINTS ${PYTHON_ROOT} PATH_SUFFIXES python3.7 python3.6 python3.5 python3.4)
set_ifndef(PYTHON3_INC_DIR ${_PYTHON3_INC_DIR})

# -------- BUILDING --------

# Add the include directories
include_directories(${TRT_INC_DIR} ${CUDA_INC_DIR} ${CUDNN_INC_DIR} ${PYBIND11_DIR}/include/)

# CMAKE_BINARY_DIR is the root of the build tree; this adds pybind11 as a subdirectory under build/
add_subdirectory(${PYBIND11_DIR} ${CMAKE_BINARY_DIR}/pybind11)

# CMAKE_SOURCE_DIR is the root of the source tree (the sample directory)
file(GLOB_RECURSE SOURCE_FILES ${CMAKE_SOURCE_DIR}/plugin/*.cpp)

# Bindings library. The module name MUST MATCH the module name specified in the .cpp
# Build the Python 3 bindings if a Python 3 include directory was found
if(PYTHON3_INC_DIR AND NOT (${PYTHON3_INC_DIR} STREQUAL "None"))
    pybind11_add_module(${PY_MODULE_NAME} SHARED THIN_LTO ${SOURCE_FILES})
    target_include_directories(${PY_MODULE_NAME} BEFORE PUBLIC ${PYTHON3_INC_DIR})
    target_link_libraries(${PY_MODULE_NAME} PRIVATE ${CUDNN_LIB} ${CUDA_LIB} ${CUBLAS_LIB} ${NVINFER_LIB} ${NVPARSERS_LIB} ${NVINFER_PLUGIN_LIB})
endif()

# Build the Python 2 bindings if a Python 2 include directory was found
if(PYTHON2_INC_DIR AND NOT (${PYTHON2_INC_DIR} STREQUAL "None"))
    # Suffix the cmake target name with a 2 to differentiate from the Python 3 bindings target.
    pybind11_add_module(${PY_MODULE_NAME}2 SHARED THIN_LTO ${SOURCE_FILES})
    target_include_directories(${PY_MODULE_NAME}2 BEFORE PUBLIC ${PYTHON2_INC_DIR})
    target_link_libraries(${PY_MODULE_NAME}2 PRIVATE ${CUDNN_LIB} ${CUDA_LIB} ${CUBLAS_LIB} ${NVINFER_LIB} ${NVPARSERS_LIB} ${NVINFER_PLUGIN_LIB})
    # Rename to remove the .cpython-35... extension.
    set_target_properties(${PY_MODULE_NAME}2 PROPERTIES OUTPUT_NAME ${PY_MODULE_NAME} SUFFIX ".so")
    # Python 2 requires an empty __init__ file to be able to import.
    file(WRITE ${CMAKE_BINARY_DIR}/__init__.py "")
endif()

The cmake and make output is shown as a screenshot in the original post (not reproduced here).

Now let's look at FullyConnected.h. (It has been a long time since I last wrote C++, so the code feels a bit unfamiliar.)

#ifndef _FULLY_CONNECTED_H_
#define _FULLY_CONNECTED_H_

#include <cassert>
#include <cstring>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <cublas_v2.h>
#include <stdexcept>
#include <string>

#include "NvInfer.h" // located under /TensorRT-5.0.2.6/include/
#include "NvCaffeParser.h" // located under /TensorRT-5.0.2.6/include/

#define CHECK(status) { if (status != 0) throw std::runtime_error(std::string{__FILE__} + ":" + std::to_string(__LINE__) + " CUDA Error: " + std::to_string(status)); }

// Copy weights from host memory to device memory
nvinfer1::Weights copyToDevice(const void* hostData, int count)
{
	void* deviceData;
	CHECK(cudaMalloc(&deviceData, count * sizeof(float)));
	CHECK(cudaMemcpy(deviceData, hostData, count * sizeof(float), cudaMemcpyHostToDevice));
	return nvinfer1::Weights{nvinfer1::DataType::kFLOAT, deviceData, count};
}

// Copy weights from device memory back into a host buffer (the count is written first, followed by the float data)
int copyFromDevice(char* hostBuffer, nvinfer1::Weights deviceWeights)
{
	*reinterpret_cast<int*>(hostBuffer) = deviceWeights.count;
	CHECK(cudaMemcpy(hostBuffer + sizeof(int), deviceWeights.values, deviceWeights.count * sizeof(float), cudaMemcpyDeviceToHost));
	return sizeof(int) + deviceWeights.count * sizeof(float);
}
//-----------------------------

/* The FCPlugin class */
class FCPlugin: public nvinfer1::IPluginExt
{
public:
	// In this simple case we're going to infer the number of output channels from the bias weights.
	// The knowledge that the kernel weights are weights[0] and the bias weights are weights[1] was
	// divined from the caffe innards
	FCPlugin(const nvinfer1::Weights* weights, int nbWeights)
	{
		assert(nbWeights == 2);
		mKernelWeights = copyToDevice(weights[0].values, weights[0].count);
		mBiasWeights = copyToDevice(weights[1].values, weights[1].count);
	}

	// Deserialization constructor: recreate the plugin from the byte stream written by serialize()
	FCPlugin(const void* data, size_t length)
	{
		const char* d = reinterpret_cast<const char*>(data), *a = d;
		mKernelWeights = copyToDevice(d + sizeof(int), *reinterpret_cast<const int*>(d));
		d += sizeof(int) + mKernelWeights.count * sizeof(float);
		mBiasWeights = copyToDevice(d + sizeof(int), *reinterpret_cast<const int*>(d));
		d += sizeof(int) + mBiasWeights.count * sizeof(float);
		assert(d == a + length);
	}

	virtual int getNbOutputs() const override { return 1; }

	virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override
	{
		assert(index == 0 && nbInputDims == 1 && inputs[0].nbDims == 3);
		return nvinfer1::DimsCHW{static_cast<int>(mBiasWeights.count), 1, 1};
	}

	virtual int initialize() override
	{
		CHECK(cudnnCreate(&mCudnn));
		CHECK(cublasCreate(&mCublas));
		// Create cudnn tensor descriptors for bias addition.
		CHECK(cudnnCreateTensorDescriptor(&mSrcDescriptor));
		CHECK(cudnnCreateTensorDescriptor(&mDstDescriptor));
		return 0;
	}

	virtual void terminate() override
	{
		CHECK(cudnnDestroyTensorDescriptor(mSrcDescriptor));
		CHECK(cudnnDestroyTensorDescriptor(mDstDescriptor));
		CHECK(cublasDestroy(mCublas));
		CHECK(cudnnDestroy(mCudnn));
	}

    // This plugin requires no workspace memory during build time.
	virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }

	virtual int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override
	{
		int nbOutputChannels = mBiasWeights.count;
		int nbInputChannels = mKernelWeights.count / nbOutputChannels;
		constexpr float kONE = 1.0f, kZERO = 0.0f;
		// Do matrix multiplication.
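		// The single SGEMM below computes output = input * W^T for the whole batch at once;
		// the bias is added afterwards with cudnnAddTensor.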
		cublasSetStream(mCublas, stream);
		cudnnSetStream(mCudnn, stream);
		CHECK(cublasSgemm(mCublas, CUBLAS_OP_T, CUBLAS_OP_N, nbOutputChannels, batchSize, nbInputChannels, &kONE,
				reinterpret_cast<const float*>(mKernelWeights.values), nbInputChannels,
				reinterpret_cast<const float*>(inputs[0]), nbInputChannels, &kZERO,
				reinterpret_cast<float*>(outputs[0]), nbOutputChannels));
        // Add bias.
		CHECK(cudnnSetTensor4dDescriptor(mSrcDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, nbOutputChannels, 1, 1));
		CHECK(cudnnSetTensor4dDescriptor(mDstDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batchSize, nbOutputChannels, 1, 1));
		CHECK(cudnnAddTensor(mCudnn, &kONE, mSrcDescriptor, mBiasWeights.values, &kONE, mDstDescriptor, outputs[0]));
		return 0;
	}

	// For this sample, we'll only support float32 with NCHW.
	virtual bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const override
	{
		return (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kNCHW);
	}

	void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs, const nvinfer1::Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize)
	{
		assert(nbInputs == 1 && inputDims[0].d[1] == 1 && inputDims[0].d[2] == 1);
		assert(nbOutputs == 1 && outputDims[0].d[1] == 1 && outputDims[0].d[2] == 1);
		assert(mKernelWeights.count == inputDims[0].d[0] * inputDims[0].d[1] * inputDims[0].d[2] * mBiasWeights.count);
	}

	virtual size_t getSerializationSize() override
	{
		return sizeof(int) * 2 + mKernelWeights.count * sizeof(float) + mBiasWeights.count * sizeof(float);
	}

	virtual void serialize(void* buffer) override
	{
		char* d = reinterpret_cast<char*>(buffer), *a = d;
		d += copyFromDevice(d, mKernelWeights);
		d += copyFromDevice(d, mBiasWeights);
		assert(d == a + getSerializationSize());
	}

	// Destructor: free the device-side weight buffers.
	virtual ~FCPlugin()
	{
		cudaFree(const_cast<void*>(mKernelWeights.values));
		mKernelWeights.values = nullptr;
		cudaFree(const_cast<void*>(mBiasWeights.values));
		mBiasWeights.values = nullptr;
	}

private:
	cudnnHandle_t mCudnn;
	cublasHandle_t mCublas;
	nvinfer1::Weights mKernelWeights{nvinfer1::DataType::kFLOAT, nullptr}, mBiasWeights{nvinfer1::DataType::kFLOAT, nullptr};
	cudnnTensorDescriptor_t mSrcDescriptor, mDstDescriptor;
};


/* The FCPluginFactory class */
class FCPluginFactory : public nvcaffeparser1::IPluginFactoryExt, public nvinfer1::IPluginFactory
{
public:
	bool isPlugin(const char* name) override { return isPluginExt(name); }

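	// "ip2" is the name of the final InnerProduct (fully connected) layer in the sample's mnist.prototxt;
	// the Caffe parser asks this factory about every layer, and only "ip2" is handled as a plugin.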
	bool isPluginExt(const char* name) override { return !strcmp(name, "ip2"); }

    // Create a plugin using provided weights.
	virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights) override
	{
		assert(isPluginExt(layerName) && nbWeights == 2);
		assert(mPlugin == nullptr);
        // This plugin will need to be manually destroyed after parsing the network, by calling destroyPlugin.
		mPlugin = new FCPlugin{weights, nbWeights};
		return mPlugin;
	}

    // Create a plugin from serialized data.
	virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
	{
		assert(isPlugin(layerName));
        // This will be automatically destroyed when the engine is destroyed.
		return new FCPlugin{serialData, serialLength};
	}

    // User application destroys plugin when it is safe to do so.
    // Should be done after consumers of plugin (like ICudaEngine) are destroyed.
	void destroyPlugin() { delete mPlugin; }

    FCPlugin* mPlugin{ nullptr };
};

#endif //_FULLY_CONNECTED_H_

Now let's look at pyFullyConnected.cpp. It uses pybind11; see the pybind11 documentation for details:

#include "FullyConnected.h"
#include "NvInfer.h"
#include "NvCaffeParser.h"
#include <pybind11/pybind11.h>

PYBIND11_MODULE(fcplugin, m)
{
    namespace py = pybind11;

    // Import the tensorrt Python module so that its bindings (e.g. for nvinfer1::IPluginExt) are registered before ours.
    py::module::import("tensorrt");

    // Note that we only need to bind the constructors manually. Since all other methods override IPlugin functionality, they will be automatically available in the python bindings.
    // The `std::unique_ptr<FCPlugin, py::nodelete>` specifies that Python is not responsible for destroying the object. This is required because the destructor is private.
    py::class_<FCPlugin, nvinfer1::IPluginExt, std::unique_ptr<FCPlugin, py::nodelete>>(m, "FCPlugin")
        // Bind the normal constructor as well as the one which deserializes the plugin
        .def(py::init<const nvinfer1::Weights*, int>())
        .def(py::init<const void*, size_t>())
    ;

    // Since the createPlugin function overrides IPluginFactory functionality, we do not need to explicitly bind it here.
    // We specify py::multiple_inheritance because we have not explicitly specified nvinfer1::IPluginFactory as a base class.
    py::class_<FCPluginFactory, nvcaffeparser1::IPluginFactoryExt>(m, "FCPluginFactory", py::multiple_inheritance())
        // Bind the default constructor.
        .def(py::init<>())
        // The destroy_plugin function does not override the base class, so we must bind it explicitly.
        .def("destroy_plugin", &FCPluginFactory::destroyPlugin)
    ;
}
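
Once built, the bindings are used from Python roughly as follows (a sketch of the intended flow; sample.py below does the same thing in full):

from build import fcplugin                   # the module produced by the CMake build
import tensorrt as trt

fc_factory = fcplugin.FCPluginFactory()      # the C++ factory, exposed through pybind11
with trt.Builder(trt.Logger()) as builder, \
     builder.create_network() as network, \
     trt.CaffeParser() as parser:
    parser.plugin_factory_ext = fc_factory   # the parser will now call createPlugin() for the "ip2" layer
    # ... parser.parse(...), builder.build_cuda_engine(network), inference ...
fc_factory.destroy_plugin()                  # only after everything that uses the plugin has been destroyed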

I won't dig further into the C++ code here.
Next, let's walk through sample.py.

# This sample uses a Caffe model along with a custom plugin to create a TensorRT engine.
from random import randint
from PIL import Image
import numpy as np

import pycuda.driver as cuda
import pycuda.autoinit

import tensorrt as trt

try:
    from build import fcplugin
except ImportError as err:
    raise ImportError("""ERROR: Failed to import module ({})
Please build the FullyConnected sample plugin.
For more information, see the included README.md
Note that Python 2 requires the presence of `__init__.py` in the build folder""".format(err))

import sys, os
sys.path.insert(1, os.path.join(sys.path[0], ".."))
# import common
# The helpers from common.py (GiB, find_sample_data, do_inference, etc.) are inlined into this file
# so that it is self-contained; a minimal HostDeviceMem helper is included further below as well.
def GiB(val):
    '''Return val GiB expressed in bytes (a left shift by 30 bits gives GiB; 10 bits would be KiB, 20 bits MiB).'''
    return val * 1 << 30

def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
    '''Locate the sample data (a trimmed-down version of the argument parsing helper from common.py).
    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files ([str]): A list of filenames to find. Each filename will be replaced with an absolute path.
    Returns:
        str: Path of the data directory.
    Raises:
        FileNotFoundError
    '''
    # For simplicity, the data root is hard-coded here instead of being taken from the command line.
    data_root = kDEFAULT_DATA_ROOT = os.path.abspath("/TensorRT-5.0.2.6/python/data/")

    subfolder_path = os.path.join(data_root, subfolder)
    if not os.path.exists(subfolder_path):
        print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")
    data_path = subfolder_path if os.path.exists(subfolder_path) else data_root

    if not (os.path.exists(data_path)):
        raise FileNotFoundError(data_path + " does not exist.")

    for index, f in enumerate(find_files):
        find_files[index] = os.path.abspath(os.path.join(data_path, f))
        if not os.path.exists(find_files[index]):
            raise FileNotFoundError(find_files[index] + " does not exist. ")

    if find_files:
        return data_path, find_files
    else:
        return data_path
#-----------------
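
# allocate_buffers() and do_inference() below pair each pagelocked host buffer with its device
# buffer via a small helper class. The original script takes it from common.py; a minimal version
# (matching what common.py provides) is reproduced here so that this file stays self-contained.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem      # pagelocked numpy array on the host
        self.device = device_mem  # pycuda DeviceAllocation on the GPU

    def __repr__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)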

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    INPUT_NAME = "input"
    INPUT_SHAPE = (1, 28, 28)
    OUTPUT_NAME = "prob"
    OUTPUT_SHAPE = (10, )
    DTYPE = trt.float32


# Retrieve the mean data from a binary_proto file with a CaffeParser.
def retrieve_mean(mean_proto):
    with trt.CaffeParser() as parser:
        return parser.parse_binary_proto(mean_proto)

# The parser's plugin factory. It is a global so that it can be destroyed after the engine has been destroyed.
fc_factory = fcplugin.FCPluginFactory()


'''Step 2 of main: build the engine.'''
def build_engine(deploy_file, model_file):

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network() as network, \
         trt.CaffeParser() as parser:

        builder.max_workspace_size = GiB(1)

        # Set the parser's plugin factory. We keep a reference to it (the global fc_factory) so that
        # it can be destroyed manually later; parser.plugin_factory_ext is a write-only attribute.
        ''' plugin_factory_ext is a CaffeParser-specific interface for plugging in user-defined layers, see
        https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/parsers/Caffe/pyCaffe.html?highlight=plugin_factory_ext
        '''
        parser.plugin_factory_ext = fc_factory

        # Parse the model and build the engine.
        model_tensors = parser.parse(deploy=deploy_file, model=model_file, network=network, dtype=ModelData.DTYPE)

        # Mark the network output.
        network.mark_output(model_tensors.find(ModelData.OUTPUT_NAME))

        return builder.build_cuda_engine(network)


'''Step 3 of main: allocate buffers.'''
def allocate_buffers(engine):

    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()

    for binding in engine:

        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))

        # Allocate host and device buffers.
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)

        # Append the device buffer to the device bindings.
        bindings.append(int(device_mem))

        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))

    return inputs, outputs, bindings, stream


'''Step 4 of main: pick a test sample and load it.'''
def load_normalized_test_case(data_path, pagelocked_buffer, mean, case_num=randint(0, 9)):

    test_case_path = os.path.join(data_path, str(case_num) + ".pgm")

    # Flatten the image into a 1-D array, normalize it (subtract the mean), and copy it into the pagelocked buffer.
    img = np.array(Image.open(test_case_path)).ravel()
    np.copyto(pagelocked_buffer, img - mean)

    return case_num


'''Step 5 of main: run inference.'''
# This function works with an arbitrary number of inputs/outputs; inputs and outputs are lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):

    # Transfer the input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]

    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)

    # Transfer the results from the GPU back to the host.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]

    # Synchronize the stream.
    stream.synchronize()

    # Return the host-side outputs.
    return [out.host for out in outputs]


def main():

    ''' 1 - Locate the model files produced by Caffe.'''
    data_path, [deploy_file, model_file, mean_proto] = find_sample_data(
        description="Runs an MNIST network using a Caffe model file",
        subfolder="mnist",
        find_files=["mnist.prototxt",
                    "mnist.caffemodel",
                    "mnist_mean.binaryproto"])

    ''' 2 - Build the engine with build_engine.'''
    with build_engine(deploy_file, model_file) as engine:

        ''' 3 - Allocate buffers and create a CUDA stream.'''
        inputs, outputs, bindings, stream = allocate_buffers(engine)
        mean = retrieve_mean(mean_proto)

        with engine.create_execution_context() as context:

            ''' 4 - Load a test case and normalize it.'''
            case_num = load_normalized_test_case(data_path, inputs[0].host, mean)

            ''' 5 - Run inference. do_inference returns a list; here it has a single element.'''
            [output] = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
            pred = np.argmax(output)

            print("Test Case: " + str(case_num))
            print("Prediction: " + str(pred))

    ''' 6 - Destroy the plugin manually, now that the engine has been destroyed.'''
    fc_factory.destroy_plugin()


if __name__ == "__main__":
    main()

The output of a run is shown as a screenshot in the original post (not reproduced here); it prints the test case digit and the predicted digit.

