接触 PyTorch 一天，发现它确实比 TensorFlow 更容易上手，用预训练网络提取特征也更加方便。
以下是提取一张 JPG 图像特征的程序：
# -*- coding: utf-8 -*-
import os.path
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
# Input image and the path where its extracted feature vector is meant to go.
features_dir = './features'
img_path = "hymenoptera_data/train/ants/0013035.jpg"
# basename() honours the OS path separator, unlike a hard-coded split('/').
file_name = os.path.basename(img_path)
feature_path = os.path.join(features_dir, file_name + '.txt')

# Preprocessing for ImageNet-pretrained torchvision models: resize the shorter
# side to 256, center-crop to 224x224, convert to a [0, 1] float tensor, then
# normalize with the ImageNet channel statistics the pretrained weights expect.
# transforms.Resize replaces transforms.Scale, which was deprecated and is not
# available in newer torchvision releases.
transform1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Force 3-channel RGB: grayscale, RGBA or palette images would otherwise yield
# a tensor with the wrong number of channels for the network. The `with` block
# closes the underlying file; convert() returns an independent in-memory image.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')
img1 = transform1(img)  # float tensor of shape (3, 224, 224)
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extr