一、如何让模型可有多种输入?
情况说明
本人是在WNO案例中,做超分辨率任务重涉及到的,这个模型可以输入【20,1024,1】,也可以输入【20,2048,1】。对应的是batch-size=20,一维数据1024那么是如何实现的呢?
解释:
原因:该模型之所以能够处理不同尺寸的输入,归功于位置编码的动态生成以及卷积操作对输入长度的不敏感性(就是forward和get_grid函数)。无论是 (20, 1024, 1) 还是 (20, 2048, 1), 模型都能通过相同的处理步骤进行前向传播和预测。
对代码具体来说:x输入shape是【20,1024,1】,经过forward函数后输出shape依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。
def forward(self, x):#技巧:x输入是【20,1024,1】,经过forward函数后输出依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。
grid = self.get_grid(x.shape, x.device)#相当于给【20,1024,1】加上位置编码,变成【20,1024,2】
x = torch.cat((x, grid), dim=-1)#上面数据是一维1024个点,一个batch20个样本。cat连接后这里执行完变成【20,1024,2】
x = self.fc0(x) #【20,1024,64】 # Shape: Batch * x * Channel】 #fc0:Linear(in_features=2, out_features=64, bias=True)
x = x.permute(0, 2, 1) # Shape: Batch * Channel * x #【20,64,1024】
if self.padding != 0:
x = F.pad(x, [0,self.padding])
for index, (convl, wl) in enumerate( zip(self.conv, self.w) ):
x = convl(x) + wl(x)
if index != self.layers - 1: # Final layer has no activation
x = F.mish(x) # Shape: Batch * Channel * x
if self.padding != 0:
x = x[..., :-self.padding]
x = x.permute(0, 2, 1) # Shape: Batch * x * Channel
x = F.gelu( self.fc1(x) ) # Shape: Batch * x * Channel
x = self.fc2(x) # Shape: Batch * x * Channel
return x#输出依旧是【20,1024,1】
def get_grid(self, shape, device):#get_grid 函数根据输入的长度动态生成位置编码:
# The grid of the solution
batchsize, size_x = shape[0], shape[1]#size_x = 1024
gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)#相当于生成位置编码,将【0-1】】生成了1024份
gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
return gridx.to(device)
二、如何加入位置编码:
具体是下边这个函数生成了一个与一维数据1024一样长的位置数,这里用的是np.linspace函数,范围是0-1,然后分为1024个数直接作为位置编码然后与x通过cat函数连接成【1024,2】,相当于有1024个像素点,每个点有2个值表示,也就是2通道。
def get_grid(self, shape, device):#get_grid 函数根据输入的长度动态生成位置编码:
# The grid of the solution
batchsize, size_x = shape[0], shape[1]#size_x = 1024
gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)#相当于生成位置编码,将【0-1】】生成了1024份
gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
return gridx.to(device)
然后再forward函数中cat连接
def forward(self, x):#技巧:x输入是【20,1024,1】,经过forward函数后输出依旧是【20,1024,1】。其中1024是根据参数x的长度动态变化的,所以能接受任何【20,x,1】的输入。
grid = self.get_grid(x.shape, x.device)#相当于给【20,1024,1】加上位置编码,变成【20,1024,2】
x = torch.cat((x, grid), dim=-1)#上面数据是一维1024个点,一个batch20个样本。cat连接后这里执行完变成【20,1024,2】
加入位置编码可以增加准确率。
案例代码是:
GitHub - csccm-iitd/WNO 中的Test_wno_super_1d_Burgers.py案例