Pytorch 搭建 SearchTransfer

Pytorch 搭建 SearchTransfer

SearchTransfer源自论文Learning Texture Transformer Network for Image Super-Resolution的代码

[paper] [code]

本文记录了复现transformer module中遇到的一些用法

关键函数

  • torch.nn.functional.unfold
  • torch.nn.functional.fold
  • torch.expand
  • torch.gather

unfold展开方便做blocks间的attention,然后利用得到的相似图计算索引来提取ref_unfold中的信息,最后用fold还原

1. unfold

unfold用 与nn.Conv2d相同的滑动窗口 将输入划分为一个个blocks

import torch
import torch.nn.functional as F

x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)	# torch.Size([1, 3, 5, 5])
print(x_unfold.shape)	# torch.Size([1, 27, 25])

x的形状为(batch,channel,H,W),可以看到x_unfold的shape为(batch,k x k x channel, number_blocks)

k是kernel_size,k x k x channel表示一个blocks中的像素个数

number_blocks是在给定kernel_size, padding,stride的情况下,可以滑出几个block

2. fold

fold的用法与unfold相反,是将一个个blocks还原回(batch,channel,H,W)的样子

k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)	# torch.Size([1, 108, 2500])
print(x_fold.shape)		# torch.Size([1, 3, 10, 10])
print(x.mean())			# tensor(0.5012)
print(x_fold.mean())	# tensor(4.3924)

可以看到虽然形状是还原了,但x和x_fold的值域发生了变化,这是因为unfold的时候一个位置(1x1xchannel)可以出现在多个blocks中,因此fold的时候会求和这些重叠的位置,导致了数据不一致。因此得出x_fold后还需要除以重叠数才能得出原始数据范围。k=6,s=2时,一个位置会出现在3*3=9个blocks中(窗口上下左右滑动)。

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3.*3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())			# tensor(0.4998)
print(x_fold.mean())	# tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum()) # tensor(189)

由sum()可以看出只有部分数据被还原了。还有一种准确计算divisor(如3. x 3.)的方法是用torch.ones作输入。

k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

x_fold = x_fold / ones_fold
print(x.mean())			# tensor(0.5001)
print(x_fold.mean())	# tensor(0.5001)
print((x == x_fold).sum())	# tensor(30000) 每个点都被还原了

3. expand

用法Tensor.expand(*size),在size中可以用-1代表保持不变的维度

x = torch.rand((1, 4))	# x = torch.rand(4) 也可以得到同样的结果
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))

print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand1)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand2)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])

4. gather

用法torch.gather(input, dim, index, *, sparse_grad=False, out=None),效果如下

for i in range(dim0):
    for j in range(dim1):
        for k in range(dim2):
            out[i, j, k] = input[index[i][j][k], j, k]  # if dim == 0
			out[i, j, k] = input[i, index[i][j][k], k]  # if dim == 1
			out[i, j, k] = input[i, j, index[i][j][k]]  # if dim == 2

使用gather时首先用expand使index的size与input相等。

index.shape == [B, blocks],用expand将index.shape变为[B,c x c x k,blocks],这样index[i, :, k]是一个1D tensor,且每个元素值都等于expand之前的index[i, j]

如此,当 j 变化时index[i][j][k]就不会变,故循环中的out[i, j, k] = input[i, j, index[i][j][k]]就将 out中的第k个block 和 input中的第index[i][j][k]个block 的每个点一一对应(遍历j)起来。

5. 搭建 Features Transfer

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transfer(nn.Module):
    def __init__(self):
        super(Transfer, self).__init__()

    def bis(self, unfold, dim, index):
        """
        block index select
        args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: 哪个维度是blocks
            index: [B, H*W],  value range is [0, Hr*Wr-1]
            return: [B, k*k*C, H*W]
        """
        views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))]  # [B, 1, -1(H*W)]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1   # [-1, k*k*C, -1]
        index = index.view(views).expand(expanse)   # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        return torch.gather(unfold, dim, index)    # return[i][j][k] = unfold[i][j][index[i][j][k]]

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """
            args:
                lrsr_lv3: [B, C, H, W]
                refsr_lv3: [B, C, Hr, Wr]
                ref_lv1: [B, C, Hr*4, Wr*4]
                ref_lv2: [B, C, Hr*2, Wr*2]
                ref_lv3: [B, C, Hr, Wr]
        """
        H, W = lrsr_lv3.size()[-2:]

        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)    # [B, k*k*C, H*W]
        refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2)  # [B, Hr*Wr, k*k*C]

        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)  # [B, H*W]

        ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1)      # vgg19
        ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2)      # lv1->lv2, lv2->lv3有一次max pooling
        ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4)     # kernel_size没有按照真实的感受野计算

        # 被除数,记录fold(unfold)时的overlap
        divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)

        T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index)   # [B, k*k*C, H*W]
        T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
        T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)

        divisor_lv3 = self.bis(divisor_lv3, 2, index)  # [B, k*k*C, H*W]
        divisor_lv2 = self.bis(divisor_lv2, 2, index)
        divisor_lv1 = self.bis(divisor_lv1, 2, index)

        divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)

        T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
        T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
        T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1

        score = score.view(lrsr_lv3.size(0), 1, H, W)   # [B, 1, H, W]

        return score, T_lv1, T_lv2, T_lv3

**bis中gather的解释:**使用gather时首先用expand使index的size与input相等。

index.shape == [B, blocks],用expand将index.shape变为[B,c x c x k,blocks],这样index[i, :, k]是一个1D tensor,且每个元素值都等于expand之前的index[i, j]

如此,当 j 变化时index[i][j][k]就不会变,故循环中的out[i, j, k] = input[i, j, index[i][j][k]]就将 out中的第k个block 和 input中的第index[i][j][k]个block 的每个点一一对应(遍历j)起来。

参考

https://pytorch.org/docs/stable/index.html

https://github.com/researchmm/TTSR

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个开源的深度学习框架,可以用于搭建深度网络。下面是使用PyTorch搭建深度网络的一般步骤: 1. 导入必要的库和模块:首先,我们需要导入PyTorch库和模块,包括torch、torch.nn和torch.optim。 2. 创建网络模型:使用torch.nn模块定义一个自定义的网络模型类,在这个类中定义网络的结构,包括网络层、激活函数和其他运算。 3. 初始化网络模型:实例化上一步中定义的网络模型类,得到网络模型的对象。 4. 定义损失函数:根据任务的特点选择适当的损失函数,例如分类任务可以使用交叉熵损失函数。 5. 定义优化器:选择合适的优化算法,例如随机梯度下降(SGD)或者Adam优化器等。 6. 训练网络:使用训练数据集对网络模型进行训练。循环遍历训练数据集,将输入数据输入网络模型,得到输出,并与标签进行比较计算损失,然后使用反向传播将损失传递给网络模型,优化模型参数。 7. 测试网络:使用测试数据集对训练好的网络模型进行性能评估。输入测试数据集到网络模型中,得到输出,并与标签进行比较,评估模型的准确率或其他性能指标。 8. 保存和加载模型:可以将训练好的模型保存到文件中,以便后续使用。也可以从文件中加载已经训练好的模型。 以上是使用PyTorch搭建深度网络的基本步骤。在实际应用中,还可以根据具体情况对网络模型进行调参、使用数据增强技术提高模型性能等。通过灵活运用PyTorch的强大功能,可以快速搭建深度网络,并进行训练和评估。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值