本文是最近研究Prompt Tuning的时候,想使用 torchsummary 来打印一下模型及可训练参数。
于是有了下面的这个例子。
import torch
import torch.nn as nn
from torchsummary import summary
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.embedding = nn.Embedding(50,32)
def forward(self, input):
# print(input)
drop_1 = self.embedding(input)
return drop_1
class MyModel1(nn.Module):
def __init__(self):
super(MyModel1, self).__init__()
self.model = MyModel()
""""""
for param in self.model.parameters():
param.requires_grad = False
self.embedding = nn.Embedding(50,32)
def forward(self, input):
out0 = self.model(input)
out1 = self.embedding(input)
a = out0[:, :25].clone()
b = out0[:, 25:].clone()
bools1 = torch.cat((a.fill_(False), b.fill_(True)), axis = 1)
bools2 = torch.cat((a.fill_(True), b.fill_(False)), axis = 1)
print(bools1.shape)
print(out0.shape)
out2 = out0 * bools1 + out1 * bools2
return out2
if __name__ == '__main__':
i = torch.randint(0, 50, (32, 50))
m = MyModel1()
summary(m, (3, 32, 50), dtypes=[torch.long])
m.forward(i)
结果如下:
上面的代码段包括了两部分,一是网络层的冻结与替换,另一个则是使用 summary 打印网络模型及参数。
用于 Prompt 的话,清华的 P-tuning 便是调整了输入的 embedding ,不训练 bert 模型本身参数,而仅仅训练这些新的 embedding 即可。