TensorRT 使用指南(3):自定义算子插件
自定义插件开发基本流程
TensorRT 本身只是一个推理优化框架,它需要从其他训练框架的模型中解析出网络结构,因此常常会遇到一些 TensorRT 不支持的算子,这时我们就需要用插件扩展功能来自己实现这些算子。
简单来说,开发一个 TensorRT 自定义算子插件的流程是:
- 首先明确要开发的算子,最好是先用 CUDA 实现;
- 开发插件类,实现
IPluginV2DynamicExt
接口,在这里调用前面实现的算子; - 开发插件 Creator 类,实现
IPluginCreator
接口,用于创建插件实例,然后注册该 Creator 类; - 编译插件项目,生成动态链接库;
- 在构造 engine 之前,首先加载上一步编译出来的插件动态链接库,在构造 engine 时 TensorRT 会自动找到先前注册的插件创建器。
项目结构
在介绍具体的开发过程之前,首先还是来了解一下项目结构,这有助于我们对 TensorRT 插件开发有个全局的认知
+---src
| |---custom_op.cu
| |---custom_op.cuh
| |---custom_op.h
| |---custom_op_plugin.cpp
| |---custom_op_plugin.h
+---test
| |---CMakeLists.txt
| |---test.cpp
|---CMakeLists.txt
其中 src/
目录下的 custom_op.cu
就是我们实现 CUDA 算子的地方,custom_op.h
里包含方法声明,custom_op.cuh
包含一些工具函数定义。而 customer_op_plugin.h
和 customer_op_plugin.cpp
则实现了 TensorRT 插件的接口,并在其中调用 CUDA 算子。
IPluginV2DynamicExt 接口
IPluginV2DynamicExt
只是 TensorRT 提供的插件接口之一,不过由于它具有 Dynamic Shape 功能,所以目前是最实用的接口。IPluginV2DynamicExt
本身继承自 IPluginV2Ext
,而 IPluginV2Ext
又继承自 IPluginV2
,后两者在 TensorRT 8.5 版本之后已经被标记为过时,因此我们开发插件时应该尽量使用 IPluginV2DynamicExt
接口。
继承 IPluginV2DynamicExt
接口需要实现大量方法,这里我们挑选几个比较重要的方法介绍一下,更具体的可以参考官方示例,开源grid_sample算子trt插件项目onnxparser-trt-plugin-sample
- getOutputDimensions
DimsExprs getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, IExprBuilder& exprBuilder) noexcept override;
该方法根据输入维度信息计算输出张量的维度,其中 outputIndex
表示当前计算的 output 序号。比如一个卷积算子,卷积核以及 padding
,stride
等信息已知并作为类变量,于是根据这里的输入维度信息就可以计算输出维度信息。
- supportsFormatCombination
bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override;
该方法光从名称来看不是很好理解,首先来看一下它的四个参数
-
pos 表示当前查询张量序号,注意这里输入和输出是合在一起排序的,也就是说
0 < pos < nbInputs + nbOutputs
,其中nbInputs
表示输入张量的个数,nbOutputs
表示输出张量的个数。当pos < nbInputs
时,表示当前查询的是输入张量,否则表示当前查询的是输出张量。 -
inOut 表示输入或输出张量的描述信息,其中包含了张量的维度信息,数据类型
type
,数据布局格式format
等。
TensorRT 通过这个方法来查询 pos
所指定张量的 type
和 format
的组合是否是被当前插件所支持的。type
无非就单精度、半精度、整型等等,而 format
则是指张量的布局方式,在 TensorRT 中有多种布局方式,具体可见官方开发者指南的附录: Data Format Descriptions 部分。
- configurePlugin
void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override;
该方法用于对插件配置输入输出相关参数,且在 engine 构建阶段和执行阶段都会被调用,原因是构建阶段和执行阶段输入输出张量的维度信息可能不同(因为是 dynamic shape 的),因此需要在每次执行前都重新配置一下。
- getOutputDataType
DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override;
该方法用于查询输出张量的数据类型,其中
index
表示当前查询的输出序号,inputTypes
表示输入张量的数据类型,nbInputs
表示输入张量的个数。一般来说,输出张量的数据类型和输入张量的相同,因此只需要返回inputTypes[0]
即可。
- getSerializationSize
size_t getSerializationSize() const noexcept override;
该方法用于查询本插件序列化需要的内存大小,实际上就是对所有当前类变量数据的字节大小求和。
- serialize
void serialize(void* buffer) const noexcept override;
该方法用于将当前类变量数据序列化到指定内存中,也就是参数 buffer
所指向的内存。通常需要用到以下工具函数
template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
举个例子,假如我们有如下类变量
private:
int32_t m_kernel_size;
int32_t m_padding;
int32_t m_stride;
那么在 serialize
方法中就可以这样写
char* data = reinterpret_cast<char*>(buffer);
char* start = data;
writeToBuffer(data, m_kernel_size);
writeToBuffer(data, m_padding);
writeToBuffer(data, m_stride);
assert(data == start + getSerializationSize());
- enqueue
int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
该方法是具体的插件执行方法,在这里调用 CUDA 算子。
IPluginCreator 接口
实现 IPluginCreator
接口的类负责创建插件实例,它的主要方法如下
- getPluginName
const char* getPluginName() const noexcept override;
该方法用于返回插件的名称,需要注意的是,这个名称必须和 IPluginV2DynamicExt
接口的 getPluginType
方法的返回值保持一致。
- createPlugin
IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override;
这是创建插件的主要方法,其中 name
表示插件名称,fc
表示插件类的字段集合,通过 fc -> fields
方法我们可以拿到 PluginField
指针数组,每个 PluginField
对象包含了字段名称,字段类型,字段数据等信息,通过类型转换可以得到具体的字段数据并创建插件实例。
- getFieldNames
const PluginFieldCollection* getFieldNames() noexcept override;
该方法用于返回插件类的字段集合,它将被传给 createPlugin
方法。
- deserializePlugin
IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override;
该方法用于反序列化插件,其中 name
表示插件名称,serialData
表示序列化数据,serialLength
表示序列化数据的字节大小。反序列化的过程和序列化的过程相反,可以借助以下工具函数
template <typename T>
T readFromBuffer(const char*& buffer)
{
T val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
return val;
}
在完成 IPluginCreator
接口的实现之后,还需要将其注册到 TensorRT 中,具体注册有好几种方式,这里我们使用最简单的,一行语句搞定
REGISTER_TENSORRT_PLUGIN(CustomeOpPluginCreator);
项目编译以及 Python 端插件加载
编译过程依然是 CMake 三件套
mkdir build && cd build
cmake ..
make ## linux
cmake --build . ## windows
当然,编译过程可能会遇到各种问题,这时候就需要在实际操作中不断排查了。编译出 .so
文件之后,就可以在 Python 端加载了,使用如下方法
def load_plugin(logger: trt.Logger):
success = ctypes.CDLL("build/libcustom_op_plugin.so", mode = ctypes.RTLD_GLOBAL)
if not success:
print("load custom_op plugin error")
raise Exception()
之后在构造 engine 时,TensorRT 将自动寻找插件创建器。
总结
本文详细介绍了 TensorRT 自定义插件开发的基本流程,总体来说就是使用 CUDA 实现具体的算法,然后在继承了插件接口的类中调用 CUDA 算子,最后将插件创建器注册到 TensorRT 中。这里我们并没有使用具体的示例来讲解,不是因为示例不重要,而是因为示例程序势必涉及到具体的算法,对算法的讲解会分散本文的重点,因此阅读本文只是一个开始,要想完整地学习 TensorRT 插件开发,还需要进一步研究官方文档和示例程序。