这篇博文相当于是对上一篇博文Pytorch:利用预训练好的VGG16网络提取图片特征 的补充,本文中提到的提取方式与前文中的不同。
另外,因为TorchVision提供的训练好了的ResNet效果不好,所以本文中将会使用由ruotianluo提供的从Caffe转换过来的ResNet模型(具体可以看这个repo,如果好奇怎么转换的话)。
代码
以下代码节选自pytorch-vqa的preprocess-images.py
,作者是Cyanogenoid。
import h5py
from torch.autograd import Variable
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.models as models
from tqdm import tqdm
import config
import data
import utils
from resnet import resnet as caffe_resnet
class Net(nn.Module):
def __init__