Fenrier Lab

TensorRT 使用指南(3):自定义算子插件

自定义插件开发基本流程

TensorRT 本身只是一个推理优化框架,它需要从其他训练框架的模型中解析出网络结构,因此常常会遇到一些 TensorRT 不支持的算子,这时我们就需要用插件扩展功能来自己实现这些算子。

简单来说,开发一个 TensorRT 自定义算子插件的流程是:

  1. 首先明确要开发的算子,最好是先用 CUDA 实现;
  2. 开发插件类,实现 IPluginV2DynamicExt 接口,在这里调用前面实现的算子;
  3. 开发插件 Creator 类,实现 IPluginCreator 接口,用于创建插件实例,然后注册该 Creator 类;
  4. 编译插件项目,生成动态链接库;
  5. 在构造 engine 之前,首先加载上一步编译出来的插件动态链接库,在构造 engine 时 TensorRT 会自动找到先前注册的插件创建器。

项目结构

在介绍具体的开发过程之前,首先还是来了解一下项目结构,这有助于我们对 TensorRT 插件开发有个全局的认知

+---src
|   |---custom_op.cu
|   |---custom_op.cuh
|   |---custom_op.h
|   |---custom_op_plugin.cpp
|   |---custom_op_plugin.h
+---test
|   |---CMakeLists.txt
|   |---test.cpp
|---CMakeLists.txt

其中 src/ 目录下的 custom_op.cu 就是我们实现 CUDA 算子的地方,custom_op.h 里包含方法声明,custom_op.cuh 包含一些工具函数定义。而 customer_op_plugin.hcustomer_op_plugin.cpp 则实现了 TensorRT 插件的接口,并在其中调用 CUDA 算子。

IPluginV2DynamicExt 接口

IPluginV2DynamicExt 只是 TensorRT 提供的插件接口之一,不过由于它具有 Dynamic Shape 功能,所以目前是最实用的接口。IPluginV2DynamicExt 本身继承自 IPluginV2Ext,而 IPluginV2Ext 又继承自 IPluginV2,后两者在 TensorRT 8.5 版本之后已经被标记为过时,因此我们开发插件时应该尽量使用 IPluginV2DynamicExt 接口。

继承 IPluginV2DynamicExt 接口需要实现大量方法,这里我们挑选几个比较重要的方法介绍一下,更具体的可以参考官方示例,开源grid_sample算子trt插件项目onnxparser-trt-plugin-sample


  1. getOutputDimensions
    DimsExprs getOutputDimensions(int32_t outputIndex, 
                                         DimsExprs const* inputs, 
                                         int32_t nbInputs, 
                                         IExprBuilder& exprBuilder) noexcept override;
    

该方法根据输入维度信息计算输出张量的维度,其中 outputIndex 表示当前计算的 output 序号。比如一个卷积算子,卷积核以及 paddingstride 等信息已知并作为类变量,于是根据这里的输入维度信息就可以计算输出维度信息。


  1. supportsFormatCombination
    bool supportsFormatCombination(int32_t pos, 
                                 PluginTensorDesc const* inOut, 
                                 int32_t nbInputs, 
                                 int32_t nbOutputs) noexcept override;
    

该方法光从名称来看不是很好理解,首先来看一下它的四个参数

  • pos 表示当前查询张量序号,注意这里输入和输出是合在一起排序的,也就是说 0 < pos < nbInputs + nbOutputs,其中 nbInputs 表示输入张量的个数,nbOutputs 表示输出张量的个数。当 pos < nbInputs 时,表示当前查询的是输入张量,否则表示当前查询的是输出张量。

  • inOut 表示输入或输出张量的描述信息,其中包含了张量的维度信息,数据类型type,数据布局格式format等。

TensorRT 通过这个方法来查询 pos 所指定张量的 typeformat 的组合是否是被当前插件所支持的。type 无非就单精度、半精度、整型等等,而 format 则是指张量的布局方式,在 TensorRT 中有多种布局方式,具体可见官方开发者指南的附录: Data Format Descriptions 部分。


  1. configurePlugin
    void configurePlugin(DynamicPluginTensorDesc const* in, 
                       int32_t nbInputs,
                       DynamicPluginTensorDesc const* out, 
                       int32_t nbOutputs) noexcept override;
    

    该方法用于对插件配置输入输出相关参数,且在 engine 构建阶段和执行阶段都会被调用,原因是构建阶段和执行阶段输入输出张量的维度信息可能不同(因为是 dynamic shape 的),因此需要在每次执行前都重新配置一下。


  1. getOutputDataType
    DataType getOutputDataType(int32_t index, 
                             nvinfer1::DataType const* inputTypes, 
                             int32_t nbInputs) const noexcept override;
    

    该方法用于查询输出张量的数据类型,其中 index 表示当前查询的输出序号,inputTypes 表示输入张量的数据类型,nbInputs 表示输入张量的个数。一般来说,输出张量的数据类型和输入张量的相同,因此只需要返回 inputTypes[0] 即可。


  1. getSerializationSize
    size_t getSerializationSize() const noexcept override;
    

该方法用于查询本插件序列化需要的内存大小,实际上就是对所有当前类变量数据的字节大小求和。


  1. serialize
    void serialize(void* buffer) const noexcept override;
    

该方法用于将当前类变量数据序列化到指定内存中,也就是参数 buffer 所指向的内存。通常需要用到以下工具函数

template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
    *reinterpret_cast<T*>(buffer) = val;
    buffer += sizeof(T);
}

举个例子,假如我们有如下类变量

private:
    int32_t m_kernel_size;
    int32_t m_padding;
    int32_t m_stride;

那么在 serialize 方法中就可以这样写

char* data = reinterpret_cast<char*>(buffer);
char* start = data;
writeToBuffer(data, m_kernel_size);
writeToBuffer(data, m_padding);
writeToBuffer(data, m_stride);
assert(data == start + getSerializationSize());

  1. enqueue
    int32_t enqueue(PluginTensorDesc const* inputDesc, 
                 PluginTensorDesc const* outputDesc,
                 void const* const* inputs, 
                 void* const* outputs, 
                 void* workspace, 
                 cudaStream_t stream) noexcept override;
    

该方法是具体的插件执行方法,在这里调用 CUDA 算子。

IPluginCreator 接口

实现 IPluginCreator 接口的类负责创建插件实例,它的主要方法如下


  1. getPluginName
    const char* getPluginName() const noexcept override;
    

该方法用于返回插件的名称,需要注意的是,这个名称必须和 IPluginV2DynamicExt 接口的 getPluginType 方法的返回值保持一致。


  1. createPlugin
    IPluginV2* createPlugin(const char* name, 
                         const PluginFieldCollection* fc) noexcept override;
    

这是创建插件的主要方法,其中 name 表示插件名称,fc 表示插件类的字段集合,通过 fc -> fields 方法我们可以拿到 PluginField 指针数组,每个 PluginField 对象包含了字段名称,字段类型,字段数据等信息,通过类型转换可以得到具体的字段数据并创建插件实例。


  1. getFieldNames
    const PluginFieldCollection* getFieldNames() noexcept override;
    

该方法用于返回插件类的字段集合,它将被传给 createPlugin 方法。


  1. deserializePlugin
    IPluginV2* deserializePlugin(const char* name, 
                               const void* serialData, 
                               size_t serialLength) noexcept override;
    

该方法用于反序列化插件,其中 name 表示插件名称,serialData 表示序列化数据,serialLength 表示序列化数据的字节大小。反序列化的过程和序列化的过程相反,可以借助以下工具函数

template <typename T>
T readFromBuffer(const char*& buffer)
{
    T val = *reinterpret_cast<const T*>(buffer);
    buffer += sizeof(T);
    return val;
}

在完成 IPluginCreator 接口的实现之后,还需要将其注册到 TensorRT 中,具体注册有好几种方式,这里我们使用最简单的,一行语句搞定

REGISTER_TENSORRT_PLUGIN(CustomeOpPluginCreator);

项目编译以及 Python 端插件加载

编译过程依然是 CMake 三件套

mkdir build && cd build
cmake ..
make            ## linux
cmake --build . ## windows

当然,编译过程可能会遇到各种问题,这时候就需要在实际操作中不断排查了。编译出 .so 文件之后,就可以在 Python 端加载了,使用如下方法

def load_plugin(logger: trt.Logger):
    success = ctypes.CDLL("build/libcustom_op_plugin.so", mode = ctypes.RTLD_GLOBAL)
    if not success:
        print("load custom_op plugin error")
        raise Exception()

之后在构造 engine 时,TensorRT 将自动寻找插件创建器。

总结

本文详细介绍了 TensorRT 自定义插件开发的基本流程,总体来说就是使用 CUDA 实现具体的算法,然后在继承了插件接口的类中调用 CUDA 算子,最后将插件创建器注册到 TensorRT 中。这里我们并没有使用具体的示例来讲解,不是因为示例不重要,而是因为示例程序势必涉及到具体的算法,对算法的讲解会分散本文的重点,因此阅读本文只是一个开始,要想完整地学习 TensorRT 插件开发,还需要进一步研究官方文档和示例程序。

本文遵守 CC-BY-NC-4.0 许可协议。

Creative Commons License

欢迎转载,转载需注明出处,且禁止用于商业目的。