26 Star 163 Fork 148

Ascend / MindSpeed

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
rms_norm.md 1.13 KB
一键复制 编辑 原始数据 按行查看 历史
邓佳 提交于 2024-05-27 09:13 . !359change ascendspeed to mindspeed

rms_norm对外接口

前向接口:

输入:

  • x:必选输入,数据类型float16, bfloat16, float32
  • gamma:必选输入,数据类型float16, bfloat16, float32

输出:

  • rstd:中间变量,数据类型float32
  • out:必选输出,数据类型float16, bfloat16, float32

属性:

  • epsilon:可选属性,数据类型float,缺省1e-6。

反向接口:

输入:

  • dy:必选输入,数据类型float16, bfloat16, float32
  • x:必选输入,数据类型float16, bfloat16, float32
  • rstd:必选输入,数据类型float32
  • gamma:必选输入,数据类型float16, bfloat16, float32

输出:

  • dx:必选输出,数据类型float16, bfloat16, float32
  • dgammma:必选输出,数据类型float32

案例

import os
import torch
import torch_npu
import numpy as np
import math
from mindspeed.op_builder import RmsNormOpBuilder

x = torch.rand([2, 4]).to(torch.float).npu()
gamma = torch.rand([4]).to(torch.float).npu()
# 正向接口案例
mindspeed_ops = RmsNormOpBuilder().load()
out = mindspeed_ops.rms_norm(x, gamma)
# 反向接口案例
out.backward(torch.ones(out.shape).npu())
1
https://gitee.com/ascend/MindSpeed.git
git@gitee.com:ascend/MindSpeed.git
ascend
MindSpeed
MindSpeed
master

搜索帮助