pytorch 将模型作为特征提取器(提取中间层特征)

3 篇文章 0 订阅

目的

需要加载自己训练好的最好模型作为一个特征提取器,也就是说需要提取最后一层全连接层输出的内容。

解决方法

参考了两个方法(详见文末)

设参数直接提取

准备一个toy model来说明。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cl1 = nn.Linear(25, 60)
        self.cl2 = nn.Linear(60, 16)

    def forward(self, x):
        x = F.relu(self.cl1(x))
        x = self.cl2(x)
        ################################
        self.last_feature = x.detach()
        ################################
        x =F.relu(x)
        return x

x = torch.randn(1, 25)
model = MyModel()
output = model(x)
print(model.last_feature)
"""
tensor([[-0.0670,  0.1209,  0.5386, -0.0052, -0.2690, -0.0397, -0.0492,  0.0916,
          0.3837, -0.5325,  0.3419, -0.3190,  0.0589, -0.1058, -0.1944, -0.0929]])
"""

在两个注释条中间,通过设置了一个self.last_feature来保存cl2层的输出结果。

设hook函数提取

同样使用上方的toy model,但是需要额外增加几行代码。

activation = {}
# 告诉模型在哪一层需要detach
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
    
model.cl2.register_forward_hook(get_activation('cl2'))
print(activation['cl2'])
"""
tensor([[-0.0670,  0.1209,  0.5386, -0.0052, -0.2690, -0.0397, -0.0492,  0.0916,
          0.3837, -0.5325,  0.3419, -0.3190,  0.0589, -0.1058, -0.1944, -0.0929]])
"""

总结

hook函数相对通用,举个例子,还是以上面的model为例,但是稍微修改了一下forward函数的表达形式:

def forward(self, x):
        x = F.relu(self.cl1(x))
        x =F.relu(self.cl2(x))
        self.last_feature = x.detach()
        return x
 # print(model.last_feature)
"""
tensor([[0.0000, 0.1209, 0.5386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0916, 0.3837,
         0.0000, 0.3419, 0.0000, 0.0589, 0.0000, 0.0000, 0.0000]])
"""

如果习惯性将激活函数连写,那么hook函数还是能够提取出正确的特征值,但是人为设置参数的方法则需要放在cl2层直接输出的后面。

Reference:

  1. https://www.zhihu.com/question/68384370
  2. https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值