代码拉取完成,页面将自动刷新
CLASS MatmulAllReduceAddRmsNorm()
计算逻辑: $$ mmOut = allReduce(x1*x2 + bias) $$ $$ y = mmOut + residual $$ $$ normOut = \frac{y}{RMS(y)}*gamma, RMS(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} y_{i}^{2} + epsilon} $$
输入:
输出:
输入:
输出:
输入:
输出:
x2
仅支持最后两轴转置情况下的非连续tensor传入,x1
、residual
、gamma
等输入仅支持连续tensorx1
支持两维或者三维,其维度为 (b, s, k)
或者 (s, k)
x2
仅支持两维,其维度为 (k, n)
,x1
和 x2
的轴满足matmul算子入参要求,k轴相等bias
在非空情况下为1维,其维度为 (n)
residual
仅支持三维,其维度为 (b, s, n)
,当 x1
为两维时,residual
的 (b * s)
等于 x1
的 s
,当 x1
为三维时,residual
的 (b * s)
等于 x1
的 (b * s)
;residual
的最后一维与x2
的最后一维相等gamma
仅支持一维,其维度为 (n)
,gamma
的最后一维与 residual
的最后一维相等reduce_op
仅支持 sum
x1
、x2
、bias
(若支持)、residual
、gamma
计算输入的数据类型要一致residual
类型为 FLOAT16
,dequant_scale
的类型为 INT64
、UINT64
(需通过 torch_npu.npu_trans_quant_param()
接口对 dequant_scale
进行处理);若输出 residual
类型为 BFLOAT16
,dequant_scale
的类型为 BFLOAT16
。dequant_scale
满足两种模式:
per_tensor
模式:(1,)
per_channel
模式:(1, n)
或 (n,)
x1
、x2
数据类型为 int8
,bias
(若支持)数据类型为 int32
,residual
、gamma
计算输入的数据类型要一致。k
、n
的范围为[1,65535]
antiquant_scale
满足三种模式:
per_tensor
模式:(1,)
per_channel
模式:(1, n)
或 (n,)
per_group
模式:(ceil(k,antiquant_group_size),n)
antiquantOffset
若非空,shape 与 antiquant_scale
一致。x2
的数据类型需为 int8
,x1
、bias
(若支持)、residual
、gamma
、antiquant_scale
、antiquant_offset
计算输入的数据类型要一致。antiquant_group_size
取值满足取值范围[32, min(k-1, INT_MAX)]
且为32倍数。import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
from mindspeed.ops.npu_mm_all_reduce_add_rms_norm import npu_mm_all_reduce_add_rms_norm
def run_mm_all_reduce_add_rms_norm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, residual_shape,
gamma_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
x1 = torch.randn(x1_shape, dtype=dtype).npu()
x2 = torch.randn(x2_shape, dtype=dtype).npu()
residual = torch.randn(residual_shape, dtype=dtype).npu()
gamma = torch.randn(gamma_shape, dtype=dtype).npu()
epsilon = 0.000001
y, norm_out = npu_mm_all_reduce_add_rms_norm(x1=x1, x2=x2, residual=residual, gamma=gamma, hcom=hcom_info,
reduce_op='sum', epsilon=epsilon)
print("y:", y)
print("norm_out:", norm_out)
if __name__ == "__main__":
worksize = 8
master_ip = "127.0.0.1"
master_port = '50001'
b, s, k, n = 4, 1024, 1024, 8192
x1_shape = (b, s, k)
x2_shape = (k, n)
residual_shape = (b, s, n)
gamma_shape = (n)
dtype = torch.float16
mp.spawn(run_mm_all_reduce_add_rms_norm,
args=(worksize, master_ip, master_port, x1_shape, x2_shape, residual_shape, gamma_shape, dtype),
nprocs=worksize)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。