目的:
大尺寸医学图像[3, 897, 1196]放到一个batch中
通过循环进行分割[12, 3, 299, 299],叠加分割之后为[24, 3, 299, 299]
上图左侧的顺序有问题,现需要一些操作对其进行上图的转换,编写代码。
for i, (inputs, labels,img_path) in enumerate(train_loader):
D_L = len(inputs)
# labels = labels.to(device)
img_tensor = Variable(inputs.to(device), volatile=True)
crop_img = torch.zeros((4, 3))
print("第",i,"个")
#大尺寸图片[3, 897, 1196]分割成[12, 3, 299, 299]
for j in range(3):
for ii in range(4):
if ii == 0 and j == 0:
crop_img = img_tensor[:, :, 299 * j:299 * (j + 1), 299 * ii:299 * (ii + 1)]
else:
crop_img = torch.cat(
(crop_img, img_tensor[:, :, 299 * j:299 * (j + 1), 299 * ii:299 * (ii + 1)]), 0)
print("crop_img:", crop_img.shape)
# 按索引重新排列,将[24*batch_size,3,299,299],按大尺寸图片以次排列
trans_image = torch.zeros([12, 3, 299, 299]).to(device)
for ind in range(D_L):
index = np.arange(ind, 12 * D_L, D_L)
index = torch.from_numpy(index).long().to(device)
index = index.unsqueeze(1).unsqueeze(1).unsqueeze(1)
index = index.expand(12, 3, 299, 299)
t = torch.gather(crop_img, 0, index)
if ind == 0:
trans_image = t
else:
trans_image = torch.cat([trans_image, t], 0)
print("233", trans_image.shape)
Tensorflow中的gather与pytorch中的gather的功能不同,如果pytorch有和tf. Gather功能相同的函数就可以直接使用。
tf.gather()
一维
多维按行
多维按列
tf.gather_nd()
torch.gather(input, dim, index, out=None)
作用:收集输入的特定维度指定位置的数值
input(tensor): 待操作数。不妨设其维度为(x1, x2, …, xn)
dim(int): 待操作的维度。
index(LongTensor): 如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y, …,xn),既是将input的第i维的大小更改为y,且要满足y>=1(除了第i维之外的其他维度,大小要和input保持一致)。
out: 注意输出和index的维度是一致的
使用torch.gather()构造和tf.gather()相同功能的函数
按第0维
按第1维
按第2维
torch.gather()实现tf.gather()