【Pytorch】自己的工具类:TensorList

class TensorList(object):
    def __init__(self, tensors):
        self.tensors = tensors
        self.len = [len(t) for t in tensors]

    def __len__(self):
        return len(self.tensors)

    def tensors_len(self):
        return torch.LongTensor([len(t) for t in self.tensors]).to(self.tensors[0].device)

    def to(self, *args, **kwargs):
        tensors = [tensor.to(*args, **kwargs) for tensor in self.tensors]
        return TensorList(tensors)

    def cat(self):
        return torch.cat(self.tensors)

    def split(self, inst_embeddings):
        return torch.split(inst_embeddings, self.len, dim=0)

使用

def xyxy2xcycwh(xyxy):
	"""
	xyxy: shape=(n, 4)
	"""
    x1, y1, x2, y2 = xyxy
    xc, yc = (x1+x2)/2, (y1+y2)/2
    w, h = x2-x1, y2-y1
    return torch.stack([xc, yc, w, h], dim=0)

import torch
img1_bbox = torch.zeros((1, 4)) + 1
img2_bbox = torch.zeros((3, 4)) + 2
img3_bbox = torch.zeros((2, 4)) + 3

imgs_bbox = TensorList([img1_bbox, img2_bbox , img3_bbox ])

cat_imgs_bbox = imgs_bbox.cat()  # for loop too slow when too many list
cat_imgs_bbox = xyxy2xcycwh(cat_imgs_bbox)
print(cat_imgs_bbox.shape)
img1_bbox,img2_bbox,img3_bbox = imgs_bbox.split(cat_imgs_bbox) 
print(img1_bbox)
print(img2_bbox)
print(img3_bbox)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值