import torch


class Para(object):
    # Hyper-parameters read by STEQuantize. forward() only uses the two
    # encoder settings below; backward() also reads enc_clipping,
    # enc_grad_limit, train_channel_mode, fb_noise_snr and batch_size,
    # which get placeholder values here (assumptions for the demo below).
    enc_quantize_level = 2              # number of quantization levels (2 => sign)
    enc_value_limit = 2.0               # clamp inputs to [-limit, +limit]
    enc_clipping = 'both'               # assumed: clip saturated inputs and gradients
    enc_grad_limit = 0.01               # assumed gradient clamp magnitude
    train_channel_mode = 'block_norm'   # assumed: plain straight-through gradient
    fb_noise_snr = 0.0                  # assumed feedback-noise SNR in dB
    batch_size = 2

class STEQuantize(torch.autograd.Function):
    """Straight-through estimator (STE) quantizer."""

    @staticmethod
    def forward(ctx, inputs, args):
        ctx.save_for_backward(inputs)
        ctx.args = args

        x_lim_abs = args.enc_value_limit
        x_lim_range = 2.0 * x_lim_abs
        # torch.clamp(input, min, max) limits every element of `input` to the
        # interval [min, max] and returns a new tensor; here -x_lim_abs acts
        # as qmin and x_lim_abs as qmax.
        x_input_norm = torch.clamp(inputs, -x_lim_abs, x_lim_abs)

        if args.enc_quantize_level == 2:
            # Binary quantization: keep only the sign.
            outputs_int = torch.sign(x_input_norm)
        else:
            # Uniform quantization: shift to [0, x_lim_range], snap to one of
            # enc_quantize_level - 1 equal steps (torch.round rounds each
            # element to the nearest integer), then shift back.
            outputs_int = torch.round(
                (x_input_norm + x_lim_abs) * ((args.enc_quantize_level - 1.0) / x_lim_range)
            ) * x_lim_range / (args.enc_quantize_level - 1.0) - x_lim_abs

        return outputs_int
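
    # Worked example for the multi-level branch (assumed settings, not in the
    # original snippet): with enc_quantize_level = 4 and enc_value_limit = 2
    # the grid is {-2, -2/3, 2/3, 2}, so an input of 0.7 maps to
    # round((0.7 + 2) * 3/4) * 4/3 - 2 = round(2.025) * 4/3 - 2 ≈ 0.667,
    # i.e. the nearest grid point.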

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.args.enc_clipping in ['inputs', 'both']:
            # Zero the gradient wherever the forward input was saturated.
            input, = ctx.saved_tensors
            grad_output[input > ctx.args.enc_value_limit] = 0
            grad_output[input < -ctx.args.enc_value_limit] = 0

        if ctx.args.enc_clipping in ['gradient', 'both']:
            grad_output = torch.clamp(grad_output, -ctx.args.enc_grad_limit, ctx.args.enc_grad_limit)

        if ctx.args.train_channel_mode not in ['group_norm_noisy', 'group_norm_noisy_quantize']:
            grad_input = grad_output.clone()
        else:
            # Experimental: pass noisy, batch-averaged gradients to the encoder.
            grad_noise = snr_db2sigma(ctx.args.fb_noise_snr) * torch.randn(grad_output[0].shape, dtype=torch.float)
            ave_temp = grad_output.mean(dim=0) + grad_noise
            # torch.stack concatenates the copies along a new dimension;
            # .permute reorders dimensions to restore the (batch, ...) layout.
            ave_grad = torch.stack([ave_temp for _ in range(ctx.args.batch_size)], dim=2).permute(2, 0, 1)
            grad_input = ave_grad + grad_noise

        # forward() took (inputs, args); args receives no gradient.
        return grad_input, None
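
# backward() references snr_db2sigma, which the snippet does not define. A
# minimal sketch, assuming the usual convention sigma = 10^(-SNR_dB / 20):
def snr_db2sigma(snr_db):
    # Convert an SNR in dB to the matching noise standard deviation.
    return 10 ** (-snr_db / 20.0)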

# Quick demo: quantize a random 2x3 tensor with the binary default settings.
a = torch.randn(2, 3)
p = Para()
s = STEQuantize.apply(a, p)
print(s)  # entries are in {-1., 0., 1.} (torch.sign of the clamped input)
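
# A minimal sketch of the backward pass (assumed usage, not in the original
# snippet). Passing an explicit, freshly allocated gradient avoids in-place
# writes to an expanded view inside backward().
a = torch.randn(2, 3, requires_grad=True)
s = STEQuantize.apply(a, p)
s.backward(torch.ones_like(s))
print(a.grad)  # straight-through grads: zeroed where |a| > enc_value_limit,
               # then clamped to +/- enc_grad_limit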