一、L2范数介绍
二、L2范数计算
2.1 导入包
import torch.nn as nn
import torch
2.2 L2范数计算
L2_norms = {}
for name, param in local_model_paras.items():
L2_norms[name] = torch.norm(param, p=2)
2.3 注意的地方
- 如果参数里面没有.long()形式的,上述代码就可以实现
- 对于 torch.norm 函数,它要求输入的张量数据类型应为浮点型或复数型,而不支持 long 类型
- 如果张量是 long 类型,并且无法直接转换为浮点型或复数型,需要手动计算 L2 范数
- be like下面这样
L2_norms = {}
for name, param in local_model_paras.items():
if name == 'xxxx':
param_squared = param.float().pow(2)
sum_squared = param_squared.sum()
L2_norms[name] = torch.sqrt(sum_squared)
else:
L2_norms[name] = torch.norm(param, p=2)