Encoder-Decoer模型共享embedding矩阵,embedding矩阵的参数更新问题

最近做生成式问答,尝试用bert做encoder,transformer-decoder做decoder框架来做。遇到一个问题,就是我想让decoder共享bert的embedding矩阵,但是由于设置了decoder和encoder学习速率不同,因此,我不知道embedding矩阵参数如何更新?会不会收到decoder端的影响,于是做了下面的实验。

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, ):
        super(Encoder, self).__init__()
        self.embeddings = nn.Embedding(100, 50)
        self.fc = nn.Linear(50, 1)

    def forward(self, input):

        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.embeddings = None
        self.fc = nn.Linear(50, 1)

    def forward(self, input):
        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        self.decoder.embeddings = self.encoder.embeddings

    def forward(self, enc_input, dec_input):
        enc_ = self.encoder(enc_input)
        dec_ = self.decoder(dec_input)

        return enc_.sum() + dec_.sum()


model = myModel()

enc_param = []
dec_param = []
for n,p in list(model.named_parameters()):
    if n.split('.')[0] == 'encoder':
        enc_param.append((n, p))
    else:
        dec_param.append((n, p))

optimizer_grouped_parameters = [
            # bert other module
            {"params": [p for n, p in enc_param],
             'lr': 0.01},
            {"params": [p for n, p in dec_param],
             'lr': 0.001},
        ]


optim = torch.optim.SGD(optimizer_grouped_parameters)

enc_input = torch.arange(0, 10).unsqueeze(0)
dec_input = torch.arange(5, 15).unsqueeze(0)

loss = model(enc_input, dec_input)

optim.zero_grad()
loss.backward()
optim.step()


print(id(model.encoder.embeddings))
print(id(model.decoder.embeddings))

print([n for (n, p) in dec_param])
print([n for (n, p) in enc_param])

'''输出
140206391178048
140206391178048
['decoder.fc.weight', 'decoder.fc.bias']
['encoder.embeddings.weight', 'encoder.fc.weight', 'encoder.fc.bias']

'''


根据打印结果,发现embedding只在encoder的参数组中,而且decoder的embedding与encoder的embedding在内存中地址一样,说明是共享的,所以我的担心是多虑的。

  • 5
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值