【踩坑记录】pytorch 自定义嵌套网络时部分网络有梯度但参数不更新

问题描述

使用如下的自定义的多层嵌套网络进行训练:

class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)

def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		pred = self.model(data)
		loss = torch.nn.CrossEntropyLoss(pred, labels)
		
		# zero grad for all optimizers
        optimizer_entire.zero_grad()

		loss.backward()
		
		# update parameters for all optimizers
        optimizer_entire.step()

解决办法

需要给所有用到的模型参数都设置optimizer,否则只有top部分的参数在训练,底层的会得到gradient,但parameter不会更新。

num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_top = torch.optim.SGD(model.top.parameters(), lr=0.01)
optimizer_bot = []
for i in range(num):
	optimizer_passive.append(torch.optim.SGD(model.passive[i].parameters(), lr=0.01))

def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		pred = self.model(data)
		loss = torch.nn.CrossEntropyLoss(pred, labels)
		
		# zero grad for all optimizers
        optimizer_entire.zero_grad()
        optimizer_top.zero_grad()
        for i in range(num):
            optimizer_bot[i].zero_grad()

        loss.backward()
		
		# update parameters for all optimizers
        optimizer_entire.step()
        optimizer_top.step()
        for i in range(num):
            optimizer_bot[i].step()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

D-A-X

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值