What is learned from books is always shallow; true understanding comes from practice. This article follows the TVM documentation on adding a new operator to Relay (TVM doc: https://tvm.apache.org/docs/dev/how_to/relay_add_op.html), which omits many details. Here we add a new operator to Relay in practice and spell out every detail, listing the files involved, the new code, and the test scripts.

Operator Definition

Because LayerNorm computes the mean and variance online at inference time, it incurs a significant runtime cost (see "LayerNorm计算特性和部署优化": https://zhuanlan.zhihu.com/p/587092648). One way to reduce this cost is to replace LayerNorm with a different normalization method, and RMSNorm is one practical line of work in this direction.

RMSNorm paper: Zhang B, Sennrich R. Root Mean Square Layer Normalization. Advances in Neural Information Processing Systems, 2019, 32.

A well-known explanation for LayerNorm's success is its invariance to re-centering and re-scaling. The former makes the model insensitive to shift noise in the inputs and weights; the latter keeps the output representation intact when both inputs and weights are randomly scaled. The RMSNorm paper hypothesizes that it is the re-scaling invariance, not the re-centering invariance, that accounts for LayerNorm's success.

RMSNorm therefore keeps only the re-scaling invariance and normalizes by the root mean square (RMS) statistic; compared with LayerNorm, the mean statistics are dropped entirely.

The RMSNorm computation is:

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2}$$

(In practice a small $\epsilon$ is added under the square root for numerical stability, matching the epsilon attribute used below.)
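As a quick sanity check of the re-scaling invariance discussed above (illustrative only, not part of the code to be added to TVM), a few lines of NumPy confirm that scaling the input by a constant leaves the output unchanged up to epsilon:

import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    # Normalize by the root mean square, then apply the gamma scale
    return x / np.sqrt(np.mean(x * x) + eps) * gamma

x = np.random.randn(8).astype("float32")
gamma = np.ones(8, dtype="float32")
# Re-scaling invariance: rms_norm(10 * x) matches rms_norm(x) up to epsilon
print(np.allclose(rms_norm(x, gamma), rms_norm(10.0 * x, gamma), atol=1e-3))  # True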

Defining the Attribute Node

Attributes are fixed parameters known at compile time. We define an attribute struct to describe the operator's attributes. For the Conv2d operator, for example, stride, padding, dilation, and kernel_size are all attributes.

Add the following code to tvm/include/tvm/relay/attrs/nn.h:

/*! \brief Attributes used in RMSNorm operator */
struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
  int axis;
  double epsilon;
  bool scale;

  TVM_DECLARE_ATTRS(RMSNormAttrs, "relay.attrs.RMSNormAttrs") {
    TVM_ATTR_FIELD(axis).set_default(-1).describe("Specify which shape axis denotes the channel.");
    TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe(
        "Small float added to variance to avoid dividing by zero");
    TVM_ATTR_FIELD(scale).set_default(true).describe(
        "If true, multiply by gamma; otherwise, gamma is ignored.");
  }
};  // struct RMSNormAttrs

Writing the Type Relation

At compile time the operator's input and output types must be checked, and the relationship between them must be expressed. These relations are represented as functions that take a list of input and output types (any of which may be incomplete) and return a list of input and output types that satisfies the relation. This includes shape information, which can be determined statically at compile time.

Add the following code to tvm/src/relay/op/nn/nn.cc:

// rms_norm: register the attribute node
TVM_REGISTER_NODE_TYPE(RMSNormAttrs);

// Type checking and shape inference
bool RMSNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);  // [data, gamma, output]: one entry per input plus the output
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const RMSNormAttrs* param = attrs.as<RMSNormAttrs>();
  // axis may be a Python-style negative index, e.g. -1
  int axis = param->axis >= 0 ? param->axis : param->axis + static_cast<int>(data->shape.size());
  ICHECK(axis >= 0 && axis < static_cast<int>(data->shape.size()));
  // gamma's shape equals the extent of data along axis
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  // output's shape is the same as data's shape
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
  return true;
}

Associating the Operator with Its Attributes

Register the operator's name and other descriptions, and annotate it with its calling interface.

Add the following code to tvm/src/relay/op/nn/nn.cc:

// Build a call to the operator from the inputs and attributes
Expr MakeRMSNorm(Expr data, Expr gamma, int axis, double epsilon, bool scale) {
  auto attrs = make_object<RMSNormAttrs>();
  attrs->axis = axis;
  attrs->epsilon = epsilon;
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.RMS_norm");
  return Call(op, {data, gamma}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.RMS_norm").set_body_typed(MakeRMSNorm);

// Register the operator
RELAY_REGISTER_OP("nn.RMS_norm")
    .describe(R"code(
RMSNorm: It is a replacement of LayerNorm.
Zhang B, Sennrich R. Root mean square layer normalization[J]. Advances in Neural Information Processing Systems, 2019, 32
)code" TVM_ADD_FILELINE)
    .set_attrs_type<RMSNormAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "Input to which RMS_norm will be applied")
    .add_argument("gamma", "Tensor", "The gamma scale factor.")
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", NormalizationInferCorrectLayout<RMSNormAttrs>)
    .set_support_level(1)
    .add_type_rel("RMSNorm", RMSNormRel);
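A quick way to confirm the registration took effect (a sketch; it assumes TVM has been rebuilt with the changes above) is to look the operator up in the Relay op registry from Python:

from tvm import relay

# Raises if "nn.RMS_norm" was not registered by RELAY_REGISTER_OP
print(relay.op.get("nn.RMS_norm"))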

Defining the Operator's Compute

TVM's TOPI operator library contains the compute and schedule definitions of operators for multiple backends. Here we register the operator on the Python side.

Create a new file tvm/python/tvm/topi/nn/RMS_norm.py with the following code:

Using TVM's cross-language calling mechanism, the compute definition of RMSNorm is written on the C++ side, and the Python side only provides the interface:

"""RMS normalization operator."""
from .. import cpp


def RMS_norm(data, gamma, axis, epsilon=1e-5):
"""RMS normalization operator.

Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})

gamma: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k


axis : list of int
Axis over the normalization applied

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.RMS_norm(data, gamma, axis, epsilon)

Note that the operator must also be imported in tvm/python/tvm/topi/nn/__init__.py, as shown below.
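The exact import line is an assumption that mirrors how the neighboring TOPI ops are re-exported:

from .RMS_norm import RMS_norm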

Create a new file tvm/include/tvm/topi/nn/RMS_norm.h with the following code:

This code describes the compute flow of RMSNorm.

/*!
 * \brief RMS normalization op constructions
 * \file nn/RMS_norm.h
 */
#ifndef TVM_TOPI_NN_RMS_NORM_H_
#define TVM_TOPI_NN_RMS_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>  // GetRealAxis, MakeReduceAxes, MakeTupleSumReducer, ...
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief RMS normalization.
 * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
 * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
 *              d_{axis_k} == r_k
 * \param axis The axis to normalize over.
 * \param epsilon The epsilon value to avoid division by zero.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 * \return The normalized tensor, with the same shape as data.
 */
inline Tensor RMS_norm(const Tensor& data, const Tensor& gamma, const Array<Integer>& axis,
                       double epsilon, std::string name = "T_RMS_norm",
                       std::string tag = kInjective) {
  // First pass: reduce over the normalization axes to get sum(x) and sum(x^2)
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape =
      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
  auto func = MakeTupleSumReducer();

  auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    int red_counter = 0;

    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // real_axis contains i: use a reduction iterator for this dimension
        eval_range.push_back(reduce_axes[red_counter]);
        red_counter++;
      } else {
        eval_range.push_back(indices[arg_counter]);
        arg_counter++;
      }
    }
    auto square = [](const PrimExpr& x) { return x * x; };
    return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
  };

  auto temp_x_x2 =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

  // Only the sum of squares is needed for RMSNorm; temp_x_x2[0] (the plain sum) is unused
  auto temp_x2 = temp_x_x2[1];

  // Number of elements to divide by when averaging the sum of squares
  auto reduce_extent = make_const(data->dtype, 1);
  for (int i : real_axis) {
    reduce_extent *= data->shape[i];
  }

  // Second pass: normalize each element by the RMS and apply the gamma scale
  auto RMS_norm_func = [&](const Array<Var>& indices) {
    Array<Var> reduce_indices, non_reduce_indices;
    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        reduce_indices.push_back(indices[i]);
      } else {
        non_reduce_indices.push_back(indices[i]);
      }
    }

    auto var = temp_x2(non_reduce_indices) / reduce_extent;
    // tvm::rsqrt(x) is 1 / tvm::sqrt(x)
    auto RMS_norm = data(indices) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
    RMS_norm = topi::multiply(RMS_norm, gamma(reduce_indices));

    return RMS_norm;
  };

  return tvm::te::compute(data->shape, RMS_norm_func, name, tag);
}

}  // namespace nn
}  // namespace topi
}  // namespace tvm

#endif  // TVM_TOPI_NN_RMS_NORM_H_

Add the following code to tvm/src/topi/nn.cc:

This registers the C++ side of the TOPI operator.

// Note: include the header below
#include <tvm/topi/nn/RMS_norm.h>

/* Ops from nn/RMS_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.RMS_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
  // Call RMS_norm (not layer_norm) from the header added above
  *rv = nn::RMS_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
});
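After rebuilding, the packed function registered above can be looked up from Python as a sanity check (a sketch; it assumes the rebuilt libtvm is the one being imported):

import tvm

# get_global_func raises if the PackedFunc was never registered
print(tvm.get_global_func("topi.nn.RMS_norm"))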

Providing the Python API

Add the following code to tvm/python/tvm/relay/op/nn/nn.py:

def RMS_norm(data, gamma, axis=-1, epsilon=1e-5, scale=True):
    return _make.RMS_norm(data, gamma, axis, epsilon, scale)
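With the Python API in place, the op can be exercised at the type level. The following sketch (assuming the rebuilt TVM) builds a small Relay function and runs InferType, which invokes the RMSNormRel relation defined earlier:

import tvm
from tvm import relay

data = relay.var("data", shape=(4, 16), dtype="float32")
gamma = relay.var("gamma", shape=(16,), dtype="float32")
out = relay.nn.RMS_norm(data, gamma, axis=-1)

mod = tvm.IRModule.from_expr(relay.Function([data, gamma], out))
mod = relay.transform.InferType()(mod)
print(mod)  # the output should be typed Tensor[(4, 16), float32]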

Writing the Test Files

Create a new file tvm/python/tvm/topi/testing/RMS_norm_python.py with the following code:

This file is a NumPy implementation used as the reference (ground truth) when testing the TOPI operator.

import numpy as np


def RMS_norm_python(data, gamma, axis, epsilon=1e-5):
    """RMS normalization operator in Python.

    Parameters
    ----------
    data : numpy.ndarray
        N-D with shape (d_0, d_1, ..., d_{N-1})

    gamma : numpy.ndarray
        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

    axis : int or tuple of ints
        Axis over which the normalization is applied

    epsilon : float
        The epsilon value to avoid division by zero.

    Returns
    -------
    result : np.ndarray
        N-D with shape (d_0, d_1, ..., d_{N-1})
    """
    # Mean of squares over the reduction axes; keepdims=True lets the result
    # broadcast against data without any manual repeat/reshape.
    mean_sq = np.mean(np.square(data), axis=tuple(axis), keepdims=True)
    result = data / np.sqrt(mean_sq + epsilon)
    # gamma covers the reduced (trailing) axes, so it broadcasts as well
    result *= gamma

    return result

Note: import it in tvm/python/tvm/topi/testing/__init__.py with the following line:

from .RMS_norm_python import RMS_norm_python

Create a new file tvm/tests/python/topi/python/test_topi_RMS_norm.py with the following code:

"""Test code for RMS_norm."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
from tvm.topi.utils import get_const_tuple
import tvm.topi.testing

import tvm.testing


# 使用通用的injective调度
_RMS_norm_schedule = {
"generic": topi.generic.schedule_injective,
}


# 对最后一维和最后两维分别测试
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5):
data = te.placeholder(shape, dtype=dtype, name="data")
scale_shape = [shape[dim] for dim in axis]
gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")
B = topi.nn.RMS_norm(data, gamma, axis, episilon)# 调用TOPI算子库中的RMSNorm

data_np = np.random.uniform(size=shape).astype(dtype)
gamma_np = np.random.uniform(size=scale_shape).astype(dtype)
beta_np = np.random.uniform(size=scale_shape).astype(dtype)
b_np = tvm.topi.testing.RMS_norm_python(data_np, gamma_np, axis, episilon) # 调用numpy编写的RMSNorm作为标准答案

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _RMS_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
gamma_tvm = tvm.nd.array(gamma_np, dev)
b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
f = tvm.build(s, [data, gamma, B], target)
f(data_tvm, gamma_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol)


if __name__ == "__main__":
tvm.testing.main()

Running the Test

Since the changes touch C++ code, the whole project must be rebuilt with make so that the updated shared library is regenerated.

Then run the test file tvm/tests/python/topi/python/test_topi_RMS_norm.py. All parameterized cases pass, indicating the operator was added successfully.


Summary

Registering an operator in TVM is not limited to the approach shown here. For BatchNorm, for example, the compute is not described on the C++ side but is defined in Python (tvm/topi/nn/batch_norm.py), and its schedule has multiple implementations for different backends.

This article only covers adding a new operator to the compiler. A follow-up will cover defining RMSNorm in PyTorch, exporting it to ONNX, and importing it through the Relay frontend. (A promise to be filled later :)