pytorch-obtain feature maps from network

1.

import torch 
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

import time

class toyNet(nn.Module):

	def __init__(self, pretrained_model, layers):
		super(toyNet, self).__init__()
		
		
		self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
		self.net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]])
		self.net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]])
	

	def forward(self, x):
		
		out1 = self.net1(x)
		out2 = self.net2(out1)
		out3 = self.net3(out2)
	
		return out1, out2, out3


def get_features(pretrained_model, x, layers = [3, 4, 7]):

	net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
#	print net1
	out1 = net1(x)

	net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]])
#	print net2
	out2 = net2(out1)

	net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]])
	out3 = net3(out2)

	return out1, out2, out3


x = Variable(torch.rand(1,3,224,224))
net = models.resnet18(pretrained=True)

if torch.cuda.is_available():
	x = x.cuda()
	net = net.cuda()

start = time.time()

o1, o2, o3 = get_features(net, x)

print time.time() - start


print o1.data.size()
print o2.data.size()
print o3.data.size()

print '----------------------------------------------'

start = time.time()

toynet = toyNet(net, [3,4,7])
y1, y2, y3 = toynet(x)

print time.time() - start

print y1.data.size()
print y2.data.size()
print y3.data.size()



In [50]: class getfea(nn.Module):     
    ...:     def __init__(self, submoduls):
    ...:         super(fea, self).__init__()
    ...:         self.submoduls = submoduls
    ...:     def forward(self,x):
    ...:         output = []
    ...:         for i,m in enumerate(self.submoduls.children()):
    ...:             x = m(x)
    ...:             if i == 4 or i==7:
    ...:                 output += [x]
    ...:         return output+[x]



-------------------------------------------------reference---------------------------------
1. https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/23


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值