//注册算子 RELAY_REGISTER_OP("nn.RMS_norm") .describe(R"code( RMSNorm: It is a replacement of LayerNorm. Zhang B, Sennrich R. Root mean square layer normalization[J]. Advances in Neural Information Processing Systems, 2019, 32 )code" TVM_ADD_FILELINE) .set_attrs_type<RMSNormAttrs>() .set_num_inputs(2) .add_argument("data","Tensor","Input to which RMS_norm will be applied") .add_argument("gamma","Tensor","The gamma scale factor.") .set_attr<FInferCorrectLayout>("FInferCorrectLayout", NormalizationInferCorrectLayout<RMSNormAttrs>) .set_support_level(1) .add_type_rel("RMSNorm",RMSNormRel);
定义算子的计算
TVM的TOPI算子库包含多个后端的算子计算与调度定义,算子在Python端的注册也在这里完成。
新增tvm/python/tvm/topi/nn/RMS_norm.py文件,编写如下代码:
利用TVM的跨语言调用机制,将RMSNorm的计算定义编写在CPP端,并在Python端提供调用接口。
"""RMS normalization operator.""" from .. import cpp
/*!
 * \brief RMS normalization.
 * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
 * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
 *              d_{axis_k} == r_k
 * \param axis The axis to normalize over.
 * \param epsilon The epsilon value to avoid division by zero.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 * \return The normalized tensor, with the same shape as data.
 */
inline Tensor RMS_norm(const Tensor& data, const Tensor& gamma, const Array<Integer>& axis,
                       double epsilon, std::string name = "T_RMS_norm",
                       std::string tag = kInjective) {
  // sum x^2
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape =
      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
  auto func = MakeTupleSumReducer();

  // Reduction body: for every output position, build the full index into `data`
  // by taking reduce axes for the normalized dimensions and the output indices
  // for the remaining ones, then feed (x, x^2) to the tuple reducer.
  auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;
    int arg_counter = 0;
    int red_counter = 0;

    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // real_axis contains i
        eval_range.push_back(reduce_axes[red_counter]);
        red_counter++;
      } else {
        eval_range.push_back(indices[arg_counter]);
        arg_counter++;
      }
    }
    auto square = [](const PrimExpr& x) { return x * x; };
    return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
  };

  auto temp_x_x2 =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

  // Take the sum of squares (second output of the tuple reduction).
  auto temp_x2 = temp_x_x2[1];

  // Number of elements reduced over; the sum of squares is divided by this to get the mean.
  auto reduce_extent = make_const(data->dtype, 1);
  for (int i : real_axis) {
    reduce_extent *= data->shape[i];
  }

  auto RMS_norm_func = [&](const Array<Var>& indices) {
    // Split the output indices into those along normalized axes (used to index
    // gamma) and the rest (used to index the reduced sum of squares).
    Array<Var> reduce_indices, non_reduce_indices;
    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        reduce_indices.push_back(indices[i]);
      } else {
        non_reduce_indices.push_back(indices[i]);
      }
    }

    // Mean of x^2 over the normalized axes.
    auto mean_square = temp_x2(non_reduce_indices) / reduce_extent;
    // tvm::rsqrt is 1 / tvm::sqrt.
    auto normalized =
        data(indices) * tvm::rsqrt(mean_square + make_const(mean_square->dtype, epsilon));
    normalized = topi::multiply(normalized, gamma(reduce_indices));
    // NOTE(review): the snippet was cut off after the gamma multiply. The closing
    // below (return the expression, then build the output with data's shape,
    // `name` and `tag`) follows from the documented return value and the otherwise
    // unused `name`/`tag` parameters -- confirm against the original file.
    return normalized;
  };
  return tvm::te::compute(data->shape, RMS_norm_func, name, tag);
}
"""Test code for RMS_norm.""" import numpy as np import pytest import tvm from tvm import te from tvm import topi from tvm.topi.utils import get_const_tuple import tvm.topi.testing