Pytorch torch.save() 保存特征向量

文章目录

1 需求

在这里插入图片描述

存取上述特征向量

2 实现

  • 数据结构: 使用list存储这些向量,[(r_emb, query), ...]
  • 工具: torch.save()tensor保存为.pth,存取对象是字典
"""
保存特征向量,推荐使用torch保存,直接保存为tensor
"""
import torch


def save_feature(feature_list, feature_path):
    feature = {

    }
    for i, (r_emb, query) in enumerate(feature_list):
        feature[f"r_emb_{i}"] = r_emb
        feature[f"query_{i}"] = query

    torch.save(feature, feature_path)
    pass

def load_feature(feature_path):
    feature = torch.load(feature_path)
    feature_list = []
    for i in range(len(feature.keys()) // 2):
        r_emb = feature[f"r_emb_{i}"]
        query = feature[f"query_{i}"]
        feature_list.append((r_emb, query))
        ...
    return feature_list
    ...

if __name__ == "__main__":
    r_emb_1 = torch.randn((32, 75, 512))
    query_1 = torch.randn((32, 22, 512))

    r_emb_2 = torch.randn((32, 75, 512))
    query_2 = torch.randn((32, 26, 512))

    feature_list = [(r_emb_1, query_1), (r_emb_2, query_2)]
    feature_path = "./save_feature.pth"

    # save_feature(feature_list, feature_path)
    feature = load_feature(feature_path)
    print("query_1 shape:", feature[0][1].shape)
    pass

在这里插入图片描述


  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值