修改特征图类型tuple转Tensor

文章介绍了在替换主干网络时遇到的特征图类型不一致的问题,主干网络A输出为tuple,而网络B输出为torch.Tensor。为了解决这个问题,可以通过torch.cat()函数将tuple类型的特征图转换为torch.Tensor,具体方法是根据特征图的数量进行单个或多个特征图的拼接。这允许将轻量级网络的输出适应原有的处理流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

在修改模型结构时,本来想着简单替换主干网络,用轻量级结构的替换原来的复杂模型,但是过程没想象中的顺利;其中比较关键的一点是两个主干网络输出的特征图类型不一致。

问题描述

主干网络A(轻量级),它输出特征图的类型是tuple,输出维度是[1, 3, 640, 640];

主干网络B(复杂的),它输出特征图的类型是torch.Tensor,输出维度也是[1, 3, 640, 640];

但是如果直接把主干网络B替换为主干网络A,后面接着原来的特征提取结构和任务头,会报错的。

tuple 转 torch.Tensor

把主干网络B替换为主干网络A后,加多一步操作,将输出特征图从tuple 转 torch.Tensor即可。

转换的基本思路是:使用 torch.cat( ) 把特征图进行拼接起来,通常是在维度 dim=0 进行拼接的。

A、当特征图的tuple数量为1

import torch

# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple

# 获取特征图个数
num_maps = len(feature_map)

# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
    print(out.size())
print("len feature_raw:", num_maps)

# 按第 0 维度拼接特征图
feature_map = torch.cat([fm for fm in feature_map], dim=0)

# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出: <class 'torch.Tensor'>

# 检查特征图维度
print("size feature_map:", feature_map.size())

示例输出:

type feature_raw: <class 'tuple'>
torch.Size([8, 32, 640, 640])
len feature_raw: 1


type feature_map: <class 'torch.Tensor'>
feature_map: torch.Size([8, 32, 640, 640])

B、当特征图的tuple数量为多个

如果主干网络输出的特征图类型为tuple,而且它包含多个特征图。我们想把它们变为一个torch.Tensor,可以使用torch.cat函数把它们拼接在一起。 

import torch

# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple

# 获取特征图个数
num_maps = len(feature_map)

# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
    print(out.size())
print("len feature_raw:", num_maps)

# 按第 0 维度拼接特征图
feature_map = torch.cat([fm.unsqueeze(0) for fm in feature_map], dim=0)

# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出: <class 'torch.Tensor'>

# 检查特征图维度
print("size feature_map:", feature_map.size())

这样就可以将输出的特征图类型由tuple变为torch.Tensor了。拼接时,通过unsqueeze(0)把每个特征图在第0维度上增加一维,这样才能用torch.cat进行拼接。 

示例输出:

type feature_raw: <class 'tuple'>
torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])

torch.Size([8, 32, 640, 640])
len feature_raw: 1


type feature_map: <class 'torch.Tensor'>
feature_map: torch.Size([4, 8, 32, 640, 640])

分享完成,欢迎交流~

 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一颗小树x

您的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值