深度学习技巧:在深度学习中,模型输入为什么有多种输入shape,可以输入【20,1024,1】,也可以输入【20,2048,1】这个是如何实现的?另外提供一种加入位置编码提高模型准确率的方法

一、如何让模型可有多种输入?

情况说明

本人是在WNO案例中,做超分辨率任务重涉及到的,这个模型可以输入【20,1024,1】,也可以输入【20,2048,1】。对应的是batch-size=20,一维数据1024那么是如何实现的呢?

解释:

原因:该模型之所以能够处理不同尺寸的输入,归功于位置编码的动态生成以及卷积操作对输入长度的不敏感性(就是forward和get_grid函数)。无论是 (20, 1024, 1) 还是 (20, 2048, 1), 模型都能通过相同的处理步骤进行前向传播和预测。

对代码具体来说:x输入shape是【20,1024,1】,经过forward函数后输出shape依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。

    def forward(self, x):#技巧:x输入是【20,1024,1】,经过forward函数后输出依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。
        grid = self.get_grid(x.shape, x.device)#相当于给【20,1024,1】加上位置编码,变成【20,1024,2】
        x = torch.cat((x, grid), dim=-1)#上面数据是一维1024个点,一个batch20个样本。cat连接后这里执行完变成【20,1024,2】
        x = self.fc0(x)   #【20,1024,64】           # Shape: Batch * x * Channel】  #fc0:Linear(in_features=2, out_features=64, bias=True)
        x = x.permute(0, 2, 1)       # Shape: Batch * Channel * x   #【20,64,1024】
        if self.padding != 0:
            x = F.pad(x, [0,self.padding]) 
        
        for index, (convl, wl) in enumerate( zip(self.conv, self.w) ):
            x = convl(x) + wl(x) 
            if index != self.layers - 1:   # Final layer has no activation    
                x = F.mish(x)        # Shape: Batch * Channel * x 
                
        if self.padding != 0:
            x = x[..., :-self.padding] 
        x = x.permute(0, 2, 1)       # Shape: Batch * x * Channel
        x = F.gelu( self.fc1(x) )    # Shape: Batch * x * Channel
        x = self.fc2(x)              # Shape: Batch * x * Channel
        return x#输出依旧是【20,1024,1】

    def get_grid(self, shape, device):#get_grid 函数根据输入的长度动态生成位置编码:
        # The grid of the solution
        batchsize, size_x = shape[0], shape[1]#size_x = 1024
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)#相当于生成位置编码,将【0-1】】生成了1024份
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)

二、如何加入位置编码:

具体是下边这个函数生成了一个与一维数据1024一样长的位置数,这里用的是np.linspace函数,范围是0-1,然后分为1024个数直接作为位置编码然后与x通过cat函数连接成【1024,2】,相当于有1024个像素点,每个点有2个值表示,也就是2通道。

    def get_grid(self, shape, device):#get_grid 函数根据输入的长度动态生成位置编码:
        # The grid of the solution
        batchsize, size_x = shape[0], shape[1]#size_x = 1024
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)#相当于生成位置编码,将【0-1】】生成了1024份
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)

然后再forward函数中cat连接

    def forward(self, x):#技巧:x输入是【20,1024,1】,经过forward函数后输出依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。
        grid = self.get_grid(x.shape, x.device)#相当于给【20,1024,1】加上位置编码,变成【20,1024,2】
        x = torch.cat((x, grid), dim=-1)#上面数据是一维1024个点,一个batch20个样本。cat连接后这里执行完变成【20,1024,2】

加入位置编码可以增加准确率。

案例代码是:

GitHub - csccm-iitd/WNO 中的Test_wno_super_1d_Burgers.py案例

  • 11
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值