python使用训练好的模型进行特征提取

文章目录


前言

因为最近可用的开源代码库里找到了一个网络,因为要提高网络训练速度,网络的输入就改成了一组组特征,所以就需要提前利用特征提取模型进行特征提取,并将特征存为数组。

代码

import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.utils.data
def main(device, in_jit_model, out_feat_dir):



    model = torch.jit.load(in_jit_model, 'cpu')
    model.eval()
    model = model.to(device)
    f = open('.../xxx.txt', 'r')

    files = f.readlines()  # 读取整个文件所有行,保存在 list 列表中

    feats = []
    tmp = '0'
    n = 0
    for filename in files:
            filename = filename[:-1]
            image = Image.open('.../' + filename)
            image = image.convert("RGB")
            image = np.array(image)
            # image = np.expand_dims(image, axis=0)



            image = torch.tensor(image)
            image = image.float()
            image /= 255
            image = image.permute(2,1,0)
            # image = data_transform(image)
            image = torch.tensor(np.expand_dims(image, axis=0))
            B, N = [1,1]
            image = image.contiguous().to(device)
            t_y = model(image)

            assert t_y.ndim == 2

            t_y = t_y.reshape(B, N, t_y.shape[-1])

            if filename[:12]==tmp :
                feats.extend(t_y.cpu().detach().numpy())
                n = n+1
                print(filename+' get {0} feat'.format(n))

            elif tmp =='0':
                feats.extend(t_y.cpu().detach().numpy())
                n = n + 1
                print(filename+' get {0} feat'.format(n))
            else:
                files2 = open('.../xxx.txt',
                    'r')
                files2 = files2.readlines()  # 读取整个文件所有行,保存在 list 列表中
                for k in files2:
                    k = k[:-1]
                    if tmp == k[:12]:
                        feats = np.stack(feats, 0)
                        out_feat_file = f'{out_feat_dir}/{k}.svs.npy'
                        os.makedirs(os.path.dirname(out_feat_file), exist_ok=True)
                        np.save(out_feat_file, feats, allow_pickle=False)
                        n = 0
                        print(k+"----------save 1 feats")
                        feats = []
                        feats.extend(t_y.cpu().detach().numpy())
                        break



            tmp = filename[:12]



if __name__ == '__main__':

    device = 'cuda:0'
    in_jit_model = '.../xxx.pt'
    out_feat_dir = '..../...'
    main(
        device=device,
        in_jit_model=in_jit_model,
        out_feat_dir=out_feat_dir
    )

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值