# features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dimension)
# labels (torch long tensor): a 1D torch long tensor with shape (batch_size)
# alpha (float): weight for the center loss
方法1 (Method 1: use a separate optimizer for the center parameters)
# Method 1: the center parameters have their own optimizer (optimizer_centloss),
# so their update is decoupled from the main model optimizer.
loss = center_loss(features, labels) * alpha + other_loss
optimizer_centloss.zero_grad()
loss.backward()
# Multiply by (1. / alpha) to remove the effect of alpha on updating the centers:
# backward() scaled the center gradients by alpha, so this restores them.
for param in center_loss.parameters():
    param.grad.data *= (1. / alpha)
optimizer_centloss.step()
方法2 (Method 2: one optimizer updates both the model and the centers)
# Method 2: a single optimizer updates both the model and the center parameters.
loss = center_loss(features, labels) * alpha + other_loss
optimizer.zero_grad()
loss.backward()
# lr_cent is the learning rate for the center loss, e.g. lr_cent = 0.5.
# Rescale the center gradients so their effective learning rate becomes
# lr_cent instead of alpha * lr (backward() scaled them by alpha, and the
# shared optimizer would apply lr).
for param in center_loss.parameters():
    param.grad.data *= (lr_cent / (alpha * lr))
optimizer.step()
以上是 loss 分两部分时更新梯度的两种方法 (two ways to update gradients when the loss has two parts). Variables introduced: features (torch tensor) — a 2D torch float tensor with shape (batch_size, feat_dimension); labels (torch long tensor) — a 1D torch long tensor with shape (batch_size).