做过pytorch框架下机器学习的人,应该或多或少都知道pytorch框架下训练模型是基于tensor的。所以在冻结参数时通常需要两步操作,首先设置模型中想要冻结的参数的requires_grad属性为Fasle,然后在优化器中把想要冻结的参数去掉,或者说只将想训练的的权值提供给优化器。
但mindspore下是通过全图梯度计算实现训练的,所以如果希望冻结参数,只需要把想要训练的参数给到优化器以及求梯度的函数grad即可。
其中grad在当前的mindspore2.0.0alpha版本中是mindspore.grad()函数,一般是在TrainOneStepCell中使用。
一般在训练的时候如果希望模型能返回除了loss之外的其他值,需要使用到has_aux参数,然后把所有的loss加和到一个loss变量中,放在返回值的第一个位置。
weights就是放参数的地方,一般通过trainable_params()函数获取,返回的是一个参数列表,多个参数列表直接通过+运算符连接起来就可以了。