Pytorch 实现tf.gather()函数的功能

目的:

大尺寸医学图像[3, 897, 1196]放到一个batch中

通过循环进行分割[12, 3, 299, 299],叠加分割之后为[24, 3, 299, 299]

 

上图左侧的顺序有问题,现需要一些操作对其进行上图的转换,编写代码。

        for i, (inputs, labels,img_path) in enumerate(train_loader):
            D_L = len(inputs)
            # labels = labels.to(device)
            img_tensor = Variable(inputs.to(device), volatile=True)
            crop_img = torch.zeros((4, 3))
            print("第",i,"个")

            #大尺寸图片[3, 897, 1196]分割成[12, 3, 299, 299]
            for j in range(3):
                for ii in range(4):
                    if ii == 0 and j == 0:
                        crop_img = img_tensor[:, :, 299 * j:299 * (j + 1), 299 * ii:299 * (ii + 1)]
                    else:
                        crop_img = torch.cat(
                            (crop_img, img_tensor[:, :, 299 * j:299 * (j + 1), 299 * ii:299 * (ii + 1)]), 0)
            print("crop_img:", crop_img.shape)

            # 按索引重新排列,将[24*batch_size,3,299,299],按大尺寸图片以次排列
            trans_image = torch.zeros([12, 3, 299, 299]).to(device)
            for ind in range(D_L):
                index = np.arange(ind, 12 * D_L, D_L)
                index = torch.from_numpy(index).long().to(device)
                index = index.unsqueeze(1).unsqueeze(1).unsqueeze(1)
                index = index.expand(12, 3, 299, 299)
                t = torch.gather(crop_img, 0, index)
                if ind == 0:
                    trans_image = t
                else:
                    trans_image = torch.cat([trans_image, t], 0)
            print("233", trans_image.shape)

 

Tensorflow中的gather与pytorch中的gather的功能不同,如果pytorch有和tf. Gather功能相同的函数就可以直接使用。

 

tf.gather()

 一维

多维按行

多维按列

tf.gather_nd()

torch.gather()

 torch.gather(input, dim, index, out=None)

作用:收集输入的特定维度指定位置的数值

参数:

input(tensor):   待操作数。不妨设其维度为(x1, x2, …, xn)
dim(int):   待操作的维度。
index(LongTensor):   如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y, …,xn),既是将input的第i维的大小更改为y,且要满足y>=1(除了第i维之外的其他维度,大小要和input保持一致)。
out:   注意输出和index的维度是一致的

使用torch.gather()构造和tf.gather()相同功能的函数

按第0维

 

按第1维

按第2维

 

 

torch.gather()实现tf.gather()

 

 


 

 

 

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东城西阙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值