一、generate_Data_for_Training
1、初始化
%% Initialization
clear all; % start from a clean workspace so values left over from earlier runs cannot affect this script
clc; % clear the Command Window output
2、参数设置
角分辨率(angRes)设置为5;缩放因子(factor)可取2或4(下面的代码取2);每个SAI patch的高分辨率空间尺寸为 patchsize = factor*32(factor=2 时为64),相邻patch之间的步长为 stride = patchsize/2。
%% Parameters setting
angRes = 5;                 % Angular Resolution, options, e.g., 3, 5, 7, 9. Default: 5
factor = 2;                 % SR factor
patchsize = factor*32;      % Spatial resolution of each HR SAI patch
stride = patchsize/2;       % Stride between two neighbouring patches (half a patch)
downRatio = 1/factor;       % Scale passed to imresize to synthesize the LR patch from the HR patch
src_data_path = './datasets/';
src_datasets = dir(src_data_path);
% Keep only real dataset folders. dir() does not guarantee that '.' and '..'
% are the first two entries, so they are removed by name instead of by
% position; plain files are dropped too because each entry is later used as
% a folder ([src_data_path, name, '/training/']).
src_datasets = src_datasets([src_datasets.isdir]);
src_datasets = src_datasets(~ismember({src_datasets.name}, {'.', '..'}));
num_datasets = length(src_datasets);
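其中:downRatio = 1/factor:下采样比例。后面会把它传给 imresize(patch_Hr_y, downRatio),将高分辨率patch缩小为原来的 1/factor,从而得到与之配对的低分辨率patch;
src_datasets(1:2) = []:删除 dir 返回结果中的前两个元素,目的是去掉 '.'(当前目录)和 '..'(上级目录)这两个条目,只保留真正的数据集文件夹(注意:dir 并不保证这两项一定排在最前面,按名称过滤会更稳妥)。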
3、训练数据生成
加载LF图像、提取中心5×5视图、按步长stride切分patch(高分辨率patch为 patchsize×patchsize,factor=2 时即64×64;下采样后的低分辨率patch为32×32)
%% Training data generation
% For each dataset folder and each scene file in its 'training' sub-folder:
%   1. load the light field LF (indexed below as LF(u, v, h, w, channel)),
%   2. keep the central angRes x angRes views and the first 3 channels,
%   3. slide a patchsize x patchsize window (step = stride) over the spatial
%      axes and, for every window position, write one HDF5 sample holding
%        /Hr_SAI_y : the luminance (Y) patch of every view, tiled into one
%                    (U*patchsize) x (V*patchsize) array, and
%        /Lr_SAI_y : the same patches shrunk by downRatio with imresize.
patchsize_Lr = patchsize / factor;   % LR patch side length (loop-invariant)
for index_dataset = 1 : num_datasets
    idx_save = 0;   % per-dataset sample counter, used as the .h5 file name
    name_dataset = src_datasets(index_dataset).name;
    src_sub_dataset = [src_data_path, name_dataset, '/training/'];
    folders = dir(src_sub_dataset);
    % Remove '.' and '..' by name. dir() does not guarantee they come first,
    % so deleting entries 1:2 could silently discard a real scene.
    folders = folders(~ismember({folders.name}, {'.', '..'}));
    num_scene = length(folders);

    % The output folder depends only on the dataset, so it is prepared once
    % here instead of querying the file system again for every patch.
    SavePath = ['./data_for_train/SR_', num2str(angRes), 'x' , num2str(angRes), '_' ,num2str(factor), 'x/', name_dataset,'/' ];
    if exist(SavePath, 'dir')==0
        mkdir(SavePath);
    end

    for index_scene = 1 : num_scene
        % Load LF image
        idx_scene_save = 0;   % number of samples cut from this scene (log only)
        name_scene = folders(index_scene).name;
        name_scene(end-3:end) = [];   % strip the 4-character extension (presumably '.mat')
        fprintf('Generating training data of Scene_%s in Dataset %s......\t\t', name_scene, name_dataset);
        data_path = [src_sub_dataset, name_scene];
        data = load(data_path);
        LF = data.LF;
        [U, V, ~, ~, ~] = size(LF);

        % Extract central angRes*angRes views
        LF = LF(0.5*(U-angRes+2):0.5*(U+angRes), 0.5*(V-angRes+2):0.5*(V+angRes), :, :, 1:3);
        [U, V, H, W, ~] = size(LF);

        % Generate HR patches of size patchsize*patchsize
        % (their LR counterparts are patchsize_Lr*patchsize_Lr)
        for h = 1 : stride : H - patchsize + 1
            for w = 1 : stride : W - patchsize + 1
                idx_save = idx_save + 1;
                idx_scene_save = idx_scene_save + 1;
                Hr_SAI_y = zeros(U * patchsize, V * patchsize, 'single');
                Lr_SAI_y = zeros(U * patchsize_Lr, V * patchsize_Lr, 'single');

                for u = 1 : U
                    for v = 1 : V
                        % Top-left corner of view (u, v) inside the HR tile
                        x = (u-1) * patchsize + 1;
                        y = (v-1) * patchsize + 1;

                        % Convert to YCbCr and keep only the Y channel
                        patch_Hr_rgb = double(squeeze(LF(u, v, h : h+patchsize-1, w : w+patchsize-1, :)));
                        patch_Hr_ycbcr = rgb2ycbcr(patch_Hr_rgb);
                        patch_Hr_y = squeeze(patch_Hr_ycbcr(:,:,1));
                        Hr_SAI_y(x:x+patchsize-1, y:y+patchsize-1) = single(patch_Hr_y);

                        % Downsample to obtain the paired LR patch
                        patch_Sr_y = imresize(patch_Hr_y, downRatio);
                        Lr_SAI_y((u-1)*patchsize_Lr+1 : u*patchsize_Lr, (v-1)*patchsize_Lr+1:v*patchsize_Lr) = single(patch_Sr_y);
                    end
                end

                % One HDF5 file per patch position, named by the running index
                SavePath_H5 = [SavePath, num2str(idx_save,'%06d'),'.h5'];
                h5create(SavePath_H5, '/Lr_SAI_y', size(Lr_SAI_y), 'Datatype', 'single');
                h5write(SavePath_H5, '/Lr_SAI_y', single(Lr_SAI_y), [1,1], size(Lr_SAI_y));
                h5create(SavePath_H5, '/Hr_SAI_y', size(Hr_SAI_y), 'Datatype', 'single');
                h5write(SavePath_H5, '/Hr_SAI_y', single(Hr_SAI_y), [1,1], size(Hr_SAI_y));
            end
        end
        fprintf('%d training samples have been generated\n', idx_scene_save);
    end
end
二、generate_Data_for_Test
1、初始化
%% Initialization
clear all; % start from a clean workspace so values left over from earlier runs cannot affect this script
clc; % clear the Command Window output
2、参数设置
%% Parameters setting
angRes = 5;                 % Angular Resolution, options, e.g., 3, 5, 7, 9. Default: 5
factor = 2;                 % SR factor
downRatio = 1/factor;       % Scale passed to imresize to synthesize the LR views from the HR views
src_data_path = './datasets/';
src_datasets = dir(src_data_path);
% Keep only real dataset folders. dir() does not guarantee that '.' and '..'
% are the first two entries, so they are removed by name instead of by
% position; plain files are dropped too because each entry is later used as
% a folder ([src_data_path, name, '/test/']).
src_datasets = src_datasets([src_datasets.isdir]);
src_datasets = src_datasets(~ismember({src_datasets.name}, {'.', '..'}));
num_datasets = length(src_datasets);
3、测试数据生成
加载LF图像、把空间尺寸裁剪为4的倍数、提取中心5×5视图、转换到YCbCr颜色空间并只保留亮度(Y)通道,再用 imresize 下采样得到对应的低分辨率图像
%% Test data generation
% For each dataset folder and each scene file in its 'test' sub-folder:
% load the light field, crop its spatial size to a multiple of 4, keep the
% central angRes x angRes views, and write ONE HDF5 file per scene holding
%   /Hr_SAI_y : the luminance (Y) channel of every view tiled into a
%               (U*H) x (V*W) array, and
%   /Lr_SAI_y : the same views shrunk by downRatio with imresize.
for index_dataset = 1 : num_datasets
    idx_save = 0;   % number of scenes processed in this dataset
    name_dataset = src_datasets(index_dataset).name;
    src_sub_dataset = [src_data_path, name_dataset, '/test/'];
    scenes = dir(src_sub_dataset);
    % Remove '.' and '..' by name. dir() does not guarantee they come first,
    % so deleting entries 1:2 could silently discard a real scene.
    scenes = scenes(~ismember({scenes.name}, {'.', '..'}));
    num_scene = length(scenes);

    % The output folder depends only on the dataset: prepare it once.
    SavePath = ['./data_for_test/SR_', num2str(angRes), 'x' , num2str(angRes), '_' ,num2str(factor), 'x/', name_dataset,'/' ];
    if exist(SavePath, 'dir')==0
        mkdir(SavePath);
    end

    for index_scene = 1 : num_scene
        % Load LF image
        idx_scene_save = 0;   % samples written for this scene (always ends at 1; log only)
        name_scene = scenes(index_scene).name;
        name_scene(end-3:end) = [];   % strip the 4-character extension (presumably '.mat')
        fprintf('Generating test data of Scene_%s in Dataset %s......\t\t', name_scene, name_dataset);
        data_path = [src_sub_dataset, name_scene];
        data = load(data_path);
        LF = data.LF;
        [U, V, H, W, ~] = size(LF);

        % Crop H and W down to the nearest multiple of 4 so the LR size
        % H/factor x W/factor is an integer (holds for factor = 2 or 4).
        H = H - mod(H, 4);
        W = W - mod(W, 4);

        % Extract central angRes*angRes views
        LF = LF(0.5*(U-angRes+2):0.5*(U+angRes), 0.5*(V-angRes+2):0.5*(V+angRes), 1:H, 1:W, 1:3);
        [U, V, H, W, ~] = size(LF);

        % Convert to YCbCr
        idx_save = idx_save + 1;
        idx_scene_save = idx_scene_save + 1;
        H_Lr = H / factor;   % LR view height
        W_Lr = W / factor;   % LR view width
        Hr_SAI_y = zeros(U * H, V * W, 'single');
        Lr_SAI_y = zeros(U * H_Lr, V * W_Lr, 'single');

        for u = 1 : U
            for v = 1 : V
                % Top-left corner of view (u, v) inside the HR tile
                x = (u-1)*H+1;
                y = (v-1)*W+1;

                temp_Hr_rgb = double(squeeze(LF(u, v, :, :, :)));
                temp_Hr_ycbcr = rgb2ycbcr(temp_Hr_rgb);
                temp_Hr_y = squeeze(temp_Hr_ycbcr(:,:,1));   % keep only the Y channel
                Hr_SAI_y(x:u*H, y:v*W) = single(temp_Hr_y);

                % Downsample to obtain the paired LR view
                temp_Lr_y = imresize(temp_Hr_y, downRatio);
                Lr_SAI_y((u-1)*H_Lr+1 : u*H_Lr, (v-1)*W_Lr+1 : v*W_Lr) = single(temp_Lr_y);
            end
        end

        % One HDF5 file per scene, named after the scene
        SavePath_H5 = [SavePath, name_scene,'.h5'];
        h5create(SavePath_H5, '/Hr_SAI_y', size(Hr_SAI_y), 'Datatype', 'single');
        h5write(SavePath_H5, '/Hr_SAI_y', single(Hr_SAI_y), [1,1], size(Hr_SAI_y));
        h5create(SavePath_H5, '/Lr_SAI_y', size(Lr_SAI_y), 'Datatype', 'single');
        h5write(SavePath_H5, '/Lr_SAI_y', single(Lr_SAI_y), [1,1], size(Lr_SAI_y));
        fprintf('%d test samples have been generated\n', idx_scene_save);
    end
end
三、train.py
1、导入模块
from torch.utils.data import DataLoader
import importlib
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from utils.utils import *
from utils.utils_datasets import TrainSetDataLoader
from collections import OrderedDict
2、定义主函数:
def main(args):
为保存创建目录:实验目录、检查点目录、日志目录
''' Create Dir for Save '''
experiment_dir, checkpoints_dir, log_dir = create_dir(args)
记录器:
''' Logger '''
logger = Logger(log_dir, args)
设置CPU或GPU(此处实际选用的是CPU:使用cuda设备的那一行被注释掉了;但代码仍然调用了 torch.cuda.set_device,在没有GPU的机器上这一行可能会报错)
''' CPU or Cuda '''
torch.cuda.set_device(args.local_rank)
# device = torch.device("cuda", args.local_rank)
device = torch.device("cpu", args.local_rank)
训练数据加载
''' DATA TRAINING LOADING '''
logger.log_string('\nLoad Training Dataset ...')
train_Dataset = TrainSetDataLoader(args)
logger.log_string("The number of training data is: %d" % len(train_Dataset))
train_loader = torch.utils.data.DataLoader(dataset=train_Dataset, num_workers=args.num_workers,
batch_size=args.batch_size, shuffle=True,)
模型加载
''' MODEL LOADING '''
logger.log_string('\nModel Initial ...')
MODEL_PATH = 'model.' + args.model_name
MODEL = importlib.import_module(MODEL_PATH)
net = MODEL.get_model(args)
加载预训练模型
''' load pre-trained pth '''
if args.use_pre_pth == False:
net.apply(MODEL.weights_init)
start_epoch = 0
logger.log_string('Do not use pretrain model!')
else:
try:
ckpt_path = args.path_pre_pth
checkpoint = torch.load(ckpt_path, map_location='cpu')
start_epoch = checkpoint['epoch']
try:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
name = 'module.' + k # add `module.`
new_state_dict[name] = v
# load params
net.load_state_dict(new_state_dict)
logger.log_string('Use pretrain model!')
except:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
new_state_dict[k] = v
# load params
net.load_state_dict(new_state_dict)
logger.log_string('Use pretrain model!')
except:
net.apply(MODEL.weights_init)
start_epoch = 0
logger.log_string('No existing model, starting training from scratch...')
pass
pass
net = net.to(device)
cudnn.benchmark = True
打印参数
''' Print Parameters '''
logger.log_string('PARAMETER ...')
logger.log_string(args)
损失加载
'''LOSS LOADING '''
criterion = MODEL.get_loss(args).to(device)
优化器(Adam)
optimizer = torch.optim.Adam(
[paras for paras in net.parameters() if paras.requires_grad == True],
lr=args.lr,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=args.decay_rate
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.n_steps, gamma=args.gamma)
训练
''' TRAINING '''
logger.log_string('\nStart training...')
for idx_epoch in range(start_epoch, args.epoch):
logger.log_string('\nEpoch %d /%s:' % (idx_epoch + 1, args.epoch))
loss_epoch_train, psnr_epoch_train, ssim_epoch_train = train(train_loader, device, net, criterion, optimizer)
logger.log_string('The %dth Train, loss is: %.5f, psnr is %.5f, ssim is %.5f' %
(idx_epoch + 1, loss_epoch_train, psnr_epoch_train, ssim_epoch_train))
# save model
if args.local_rank == 0:
save_ckpt_path = str(checkpoints_dir) + '/%s_%dx%d_%dx_epoch_%02d_model.pth' % (
args.model_name, args.angRes, args.angRes, args.scale_factor, idx_epoch + 1)
state = {
'epoch': idx_epoch + 1,
'state_dict': net.module.state_dict() if hasattr(net, 'module') else net.state_dict(),
}
torch.save(state, save_ckpt_path)
logger.log_string('Saving the epoch_%02d model at %s' % (idx_epoch + 1, save_ckpt_path))
''' scheduler '''
scheduler.step()
pass
pass
3、定义train函数
def train(train_loader,device,net,criterion,optimizer):
训练一个epoch(1个epoch表示过了1遍训练集中的所有样本。)
'''training one epoch'''
psnr_iter_train = []
loss_iter_train = []
ssim_iter_train = []
args.temperature = 1.0
for idx_iter, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader), ncols=70):
data = data.to(device) # low resolution
label = label.to(device) # high resolution
out = net(data)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.cuda.empty_cache()
loss_iter_train.append(loss.data.cpu())
psnr, ssim = cal_metrics(args, label, out)
psnr_iter_train.append(psnr)
ssim_iter_train.append(ssim)
pass
loss_epoch_train = float(np.array(loss_iter_train).mean())
psnr_epoch_train = float(np.array(psnr_iter_train).mean())
ssim_epoch_train = float(np.array(ssim_iter_train).mean())
return loss_epoch_train, psnr_epoch_train, ssim_epoch_train
4、控制代码
if __name__ == "__main__":
    # Script entry point: the parsed options are imported only when this file
    # is executed directly, then handed to main().
    from option import args

    main(args)
四、test.py
1、导入模块
from torch.utils.data import DataLoader
import importlib
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from utils.utils import *
from utils.utils_datasets import MultiTestSetDataLoader
from collections import OrderedDict
2、定义主函数:
def main(args):
为保存创建目录
''' Create Dir for Save'''
experiment_dir, checkpoints_dir, log_dir = create_dir(args)
记录器:
''' Logger '''
logger = Logger(log_dir, args)
设置CPU或GPU(此处实际选用的是CPU:使用cuda设备的那一行被注释掉了;但代码仍然调用了 torch.cuda.set_device,在没有GPU的机器上这一行可能会报错)
''' CPU or Cuda '''
torch.cuda.set_device(args.local_rank)
# device = torch.device("cuda", args.local_rank)
device = torch.device("cpu", args.local_rank)
测试数据加载
''' DATA TEST LOADING '''
logger.log_string('\nLoad Test Dataset ...')
test_Names, test_Loaders, length_of_tests = MultiTestSetDataLoader(args)
logger.log_string("The number of test data is: %d" % length_of_tests)
模型加载
''' MODEL LOADING '''
logger.log_string('\nModel Initial ...')
MODEL_PATH = 'model.' + args.model_name
MODEL = importlib.import_module(MODEL_PATH)
net = MODEL.get_model(args)
加载预训练模型
''' load pre-trained pth '''
ckpt_path = args.path_pre_pth
checkpoint = torch.load(ckpt_path, map_location='cpu')
start_epoch = checkpoint['epoch']
try:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
name = 'module.' + k # add `module.`
new_state_dict[name] = v
# load params
net.load_state_dict(new_state_dict)
logger.log_string('Use pretrain model!')
except:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
new_state_dict[k] = v
# load params
net.load_state_dict(new_state_dict)
logger.log_string('Use pretrain model!')
net = net.to(device)
cudnn.benchmark = True
测试每个数据集
''' TEST on every dataset'''
logger.log_string('\nStart test...')
with torch.no_grad():
psnr_testset = []
ssim_testset = []
for index, test_name in enumerate(test_Names):
test_loader = test_Loaders[index]
psnr_epoch_test, ssim_epoch_test = test(test_loader, device, net)
psnr_testset.append(psnr_epoch_test)
ssim_testset.append(ssim_epoch_test)
logger.log_string('Test on %s, psnr/ssim is %.2f/%.3f' % (test_name, psnr_epoch_test, ssim_epoch_test))
pass
pass
pass
3、定义 test函数
def test(test_loader, device, net):
    """Evaluate ``net`` on one test set and return its mean PSNR and SSIM.

    Every low-resolution SAI array is cut into sub light fields with
    ``LFdivide``, each sub light field is passed through the network on its
    own, and the outputs are stitched back together with ``LFintegrate``
    before the metrics are computed against the high-resolution reference.

    The function reads the module-level ``args`` (``angRes``,
    ``scale_factor``, ``patch_size_for_test``, ``stride_for_test``); in this
    file that name is only bound under ``if __name__ == '__main__'``.

    Args:
        test_loader: iterable yielding ``(Lr_SAI_y, Hr_SAI_y)`` pairs; the
            low-resolution tensor must be 2-D after ``squeeze()``.
        device: device the low-resolution data and network inputs are moved to.
        net: the super-resolution network.

    Returns:
        Tuple ``(psnr, ssim)`` of floats averaged over all samples.
    """
    ang_res = args.angRes
    scale = args.scale_factor
    patch_lr = args.patch_size_for_test
    stride_lr = args.stride_for_test

    psnr_iter_test = []
    ssim_iter_test = []

    # Evaluation mode and gradient tracking do not change between patches,
    # so both are set once here instead of once per patch.
    net.eval()
    with torch.no_grad():
        for Lr_SAI_y, Hr_SAI_y in tqdm(test_loader, total=len(test_loader), ncols=70):
            Lr_SAI_y = Lr_SAI_y.squeeze().to(device)  # numU, numV, h*angRes, w*angRes
            Hr_SAI_y = Hr_SAI_y.squeeze()
            uh, vw = Lr_SAI_y.shape
            h0, w0 = int(uh // ang_res), int(vw // ang_res)

            subLFin = LFdivide(Lr_SAI_y, ang_res, patch_lr, stride_lr)
            numU, numV, _, _ = subLFin.size()
            # Result buffer lives on the CPU; each network output is copied in.
            subLFout = torch.zeros(numU, numV,
                                   ang_res * patch_lr * scale,
                                   ang_res * patch_lr * scale)

            for u in range(numU):
                for v in range(numV):
                    tmp = subLFin[u:u+1, v:v+1, :, :]
                    # Kept per patch on purpose: releases cached GPU blocks
                    # between forward passes.
                    torch.cuda.empty_cache()
                    out = net(tmp.to(device))
                    subLFout[u:u+1, v:v+1, :, :] = out.squeeze()

            Sr_4D_y = LFintegrate(subLFout, ang_res, patch_lr * scale,
                                  stride_lr * scale, h0 * scale, w0 * scale)
            Sr_SAI_y = Sr_4D_y.permute(0, 2, 1, 3).reshape((h0 * ang_res * scale,
                                                            w0 * ang_res * scale))

            psnr, ssim = cal_metrics(args, Hr_SAI_y, Sr_SAI_y)
            psnr_iter_test.append(psnr)
            ssim_iter_test.append(ssim)

    psnr_epoch_test = float(np.array(psnr_iter_test).mean())
    ssim_epoch_test = float(np.array(ssim_iter_test).mean())
    return psnr_epoch_test, ssim_epoch_test
4、控制代码
if __name__ == "__main__":
    # Script entry point: the parsed options are imported only when this file
    # is executed directly, then handed to main().
    from option import args

    main(args)
五、LFT.py 模块
1、导入模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
2、定义get_model类:
class get_model(nn.Module):
初始化
实例属性(通道数、缩放因子、层数=4、位置编码、MHSA参数等)、初始卷积、交替堆叠的角度Transformer(AngTrans)与空间Transformer(SpaTrans)(由 make_layer 堆叠4个 AltFilter 构成)、上采样
def __init__(self, args):
    """Assemble the network from the options in ``args``.

    Reads ``args.channels`` (feature width), ``args.angRes`` (angular
    resolution) and ``args.scale_factor`` (up-sampling factor), then builds,
    in this order: the position encoder, the initial 3-D convolutions, a
    stack of four ``AltFilter`` blocks, and the pixel-shuffle up-sampler.
    """
    super().__init__()
    n_feat = args.channels
    self.channels = n_feat
    self.angRes = args.angRes
    self.factor = args.scale_factor
    self.pos_encoding = PositionEncoding(temperature=10000)
    self.MHSA_params = {'num_heads': 8, 'dropout': 0.}

    def spatial_conv3d(in_channels):
        # Bias-free 3-D convolution with a 1x3x3 kernel and matching padding.
        return nn.Conv3d(in_channels, n_feat, kernel_size=(1, 3, 3),
                         padding=(0, 1, 1), dilation=1, bias=False)

    ##################### Initial Convolution #####################
    self.conv_init0 = nn.Sequential(spatial_conv3d(1))
    refine_layers = []
    for _ in range(3):
        refine_layers.append(spatial_conv3d(n_feat))
        refine_layers.append(nn.LeakyReLU(0.2, inplace=True))
    self.conv_init = nn.Sequential(*refine_layers)

    ################ Alternate AngTrans & SpaTrans ################
    self.altblock = self.make_layer(layer_num=4)

    ####################### UP Sampling ###########################
    upscale = self.factor
    self.upsampling = nn.Sequential(
        # Expand channels so PixelShuffle can trade them for resolution.
        nn.Conv2d(n_feat, n_feat * upscale ** 2, kernel_size=1, padding=0, dilation=1, bias=False),
        nn.PixelShuffle(upscale),
        nn.LeakyReLU(0.2),
        nn.Conv2d(n_feat, 1, kernel_size=3, stride=1, padding=1, bias=False),
    )
make_layer函数
def make_layer(self, layer_num):
    """Return ``layer_num`` ``AltFilter`` blocks chained in an ``nn.Sequential``."""
    blocks = [
        AltFilter(self.angRes, self.channels, self.MHSA_params)
        for _ in range(layer_num)
    ]
    return nn.Sequential(*blocks)
forward函数:
双三次插值、reshape、初始卷积、位置编码、Alternate AngTrans&SpaTrans、上采样。
def forward(self, lr):
    """Super-resolve a low-resolution SAI array.

    ``lr`` is laid out as ``[B, C, angRes*h, angRes*w]`` (this is the layout
    the first ``rearrange`` pattern unpacks). The network predicts a residual
    that is added to the bicubically up-scaled input.
    """
    ang = self.angRes

    # Bicubic baseline, [B, 1, A*h*S, A*w*S]
    bicubic = interpolate(lr, ang, scale_factor=self.factor, mode='bicubic')

    # Reshape for LFT: [B, C, A*h, A*w] -> [B, C, A^2, h, w]
    views = rearrange(lr, 'b c (a1 h) (a2 w) -> b c (a1 a2) h w', a1=ang, a2=ang)

    # Initial convolution with a local skip connection, [B, channels, A^2, h, w]
    shallow = self.conv_init0(views)
    feat = self.conv_init(shallow) + shallow

    # Position encodings over the spatial axes (3, 4) and the angular axis (2)
    spa_position = self.pos_encoding(feat, dim=[3, 4], token_dim=self.channels)
    ang_position = self.pos_encoding(feat, dim=[2], token_dim=self.channels)

    # The sub-modules read the spatial size and the encodings from attributes,
    # so they are attached to every module before the transformer stack runs.
    height, width = views.size(-2), views.size(-1)
    for module in self.modules():
        module.h = height
        module.w = width
        module.spa_position = spa_position
        module.ang_position = ang_position

    # Alternate AngTrans & SpaTrans, with a global skip connection
    feat = self.altblock(feat) + feat

    # Up-sampling: back to the SAI-array layout, then pixel shuffle
    feat = rearrange(feat, 'b c (a1 a2) h w -> b c (a1 h) (a2 w)', a1=ang, a2=ang)
    return self.upsampling(feat) + bicubic
#上下采样函数interpolate() Pytorch上下采样函数--interpolate()_Activewaste的博客-CSDN博客_interpolate()
#rearrange() einops库中rearrange,reduce和repeat的介绍_鬼道2022的博客-CSDN博客_einops.rearrange
3、位置编码类:
class PositionEncoding(nn.Module):
初始化实例属性
def __init__(self, temperature):
    """Store ``temperature``, the base of the frequency scale used in ``forward``."""
    super().__init__()
    self.temperature = temperature
forward函数
def forward(self, x, dim: list, token_dim):
    """Build a sinusoidal position code for the axes ``dim`` of ``x``.

    Args:
        x: 5-D tensor; only its sizes along ``dim`` and its device are used.
        dim: axes of ``x`` to encode (e.g. ``[3, 4]`` or ``[2]``).
        token_dim: number of channels of the code.

    Returns:
        Tensor with ``token_dim`` on axis 1, the size of ``x`` along every
        encoded axis and size 1 elsewhere: the per-axis codes summed and
        divided by ``len(dim)``.
    """
    self.token_dim = token_dim
    assert x.dim() == 5, 'the object of position encoding requires 5-dim tensor! '

    # Per-channel divisor temperature ** (2 * floor(i / 2) / token_dim)
    channel_idx = torch.linspace(0, self.token_dim - 1, self.token_dim, dtype=torch.float32)
    divisor = self.temperature ** (2 * (channel_idx // 2) / self.token_dim)

    position = None
    for axis in dim:
        length = x.size(axis)
        # Broadcastable shape: 1 everywhere except the encoded axis and the
        # trailing channel axis.
        target_shape = [1, 1, 1, 1, 1, self.token_dim]
        target_shape[axis] = length

        steps = torch.linspace(0, length - 1, length, dtype=torch.float32).view(-1, 1)
        angles = (steps / divisor).to(x.device)
        # Sine of the even-indexed columns followed by cosine of the odd ones
        code = torch.cat([angles[:, 0::2].sin(), angles[:, 1::2].cos()], dim=1)
        code = code.view(target_shape)

        position = code if position is None else position + code

    # Move the channel axis to position 1 and drop the unit axis there
    position = rearrange(position, 'b 1 a h w dim -> b dim a h w')
    return position / len(dim)
torch.linspace() torch.linspace()用法_快乐地笑的博客-CSDN博客
4、定义SpaTrans类:
class SpaTrans(nn.Module):
初始化实例属性
def __init__(self, channels, angRes, MHSA_params):
    """Create the layers of the spatial transformer.

    Args:
        channels: number of feature channels entering and leaving the block.
        angRes: angular resolution of the light field.
        MHSA_params: dict with ``'num_heads'`` and ``'dropout'`` for the
            multi-head attention and the feed-forward dropouts.
    """
    super().__init__()
    self.angRes = angRes
    self.kernel_field = 3    # unfold window used to embed each token
    self.kernel_search = 5
    self.spa_dim = channels * 2

    token_dim = self.spa_dim
    num_heads = MHSA_params['num_heads']
    drop_rate = MHSA_params['dropout']

    # Token embedding: flattened kernel_field x kernel_field patch -> token
    self.MLP = nn.Linear(channels * self.kernel_field ** 2, token_dim, bias=False)
    self.norm = nn.LayerNorm(token_dim)

    self.attention = nn.MultiheadAttention(token_dim, num_heads, drop_rate, bias=False)
    nn.init.kaiming_uniform_(self.attention.in_proj_weight, a=math.sqrt(5))
    # Explicitly drop the output-projection bias (presumably needed on
    # PyTorch versions where bias=False does not already remove it).
    self.attention.out_proj.bias = None

    self.feed_forward = nn.Sequential(
        nn.LayerNorm(token_dim),
        nn.Linear(token_dim, token_dim * 2, bias=False),
        nn.ReLU(True),
        nn.Dropout(drop_rate),
        nn.Linear(token_dim * 2, token_dim, bias=False),
        nn.Dropout(drop_rate),
    )

    # 1x1x1 convolution mapping tokens back to the feature width
    self.linear = nn.Sequential(
        nn.Conv3d(token_dim, channels, kernel_size=(1, 1, 1), padding=(0, 0, 0), dilation=1, bias=False),
    )
nn.Linear() nn.Linear()函数详解及代码使用_墨晓白的博客-CSDN博客_nn.linear
nn.LayerNorm() nn.LayerNorm的实现及原理_harry_tea的博客-CSDN博客_layer norm
定义gen_mask函数,并使用 @staticmethod
@staticmethod
def gen_mask(h: int, w: int, k: int):
    """Build an additive attention mask restricting each pixel to a k x k window.

    Query pixel ``(i, j)`` may attend to key pixel ``(r, c)`` when
    ``i - k//2 <= r < i + (k - k//2)`` and ``j - k//2 <= c < j + (k - k//2)``;
    windows are clipped at the image border.

    Args:
        h: number of rows of the token grid.
        w: number of columns of the token grid.
        k: side length of the attention window.

    Returns:
        Float tensor of shape ``(h*w, h*w)`` (row-major flattening of the
        grid) holding ``0.0`` for allowed pairs and ``-inf`` otherwise.
    """
    k_left = k // 2
    k_right = k - k_left

    rows = torch.arange(h)
    cols = torch.arange(w)
    # Offset key - query along each axis; in-window means -k_left <= d < k_right.
    row_delta = rows[None, :] - rows[:, None]          # [h(query), h(key)]
    col_delta = cols[None, :] - cols[:, None]          # [w(query), w(key)]
    row_ok = (row_delta >= -k_left) & (row_delta < k_right)
    # The column window is bounded by the grid width w. The earlier loop
    # version clipped it with h, which left queries in columns beyond h with
    # no visible key at all whenever w > h.
    col_ok = (col_delta >= -k_left) & (col_delta < k_right)

    # Combine to [h, w, h, w] = (query row, query col, key row, key col)
    allowed = row_ok[:, None, :, None] & col_ok[None, :, None, :]
    allowed = allowed.reshape(h * w, h * w)

    atten_mask = torch.zeros(h * w, h * w)
    atten_mask.masked_fill_(~allowed, float('-inf'))
    return atten_mask
定义SAI2Token函数
def SAI2Token(self, buffer):
    """Turn a ``[b, c, a, h, w]`` feature volume into spatial tokens.

    Each view is unfolded into ``kernel_field x kernel_field`` neighbourhoods
    (one per pixel, zero-padded at the border) and every flattened
    neighbourhood is projected by ``self.MLP``. The token axis comes first
    and the ``(b a)`` axis second.
    """
    per_view = rearrange(buffer, 'b c a h w -> (b a) c h w')
    # local feature embedding
    neighbourhoods = F.unfold(per_view,
                              kernel_size=self.kernel_field,
                              padding=self.kernel_field // 2)
    tokens = neighbourhoods.permute(2, 0, 1)
    return self.MLP(tokens)
定义Token2SAI函数
def Token2SAI(self, buffer_token_spa):
    """Fold ``[(h w), (b a), c]`` tokens back into a ``[b, c, a, h, w]`` volume.

    Uses ``self.h`` / ``self.w`` (set on the module from outside) for the
    spatial size, then applies ``self.linear`` to the result.
    """
    num_views = self.angRes ** 2
    volume = rearrange(buffer_token_spa, '(h w) (b a) c -> b c a h w',
                       h=self.h, w=self.w, a=num_views)
    return self.linear(volume)
定义forward函数
def forward(self, buffer):
    """Apply windowed self-attention over the pixels of each view.

    The feature volume is converted to spatial tokens, the spatial position
    encoding (``self.spa_position``, attached to the module from outside) is
    added to form queries and keys, and attention is limited to a
    ``kernel_search x kernel_search`` window through the mask from
    ``gen_mask``. A feed-forward layer follows; both steps use residual
    connections.

    The earlier body added the angular encoding (``self.ang_position``) and
    passed no mask, so ``gen_mask``, ``kernel_search`` and ``spa_position``
    were never used, and the encoding only broadcast against the tokens for
    a batch size of 1.

    Args:
        buffer: feature volume accepted by ``SAI2Token``.

    Returns:
        Feature volume produced by ``Token2SAI``.
    """
    # Local-window mask over the h*w tokens of a view, on the input's device
    atten_mask = self.gen_mask(self.h, self.w, self.kernel_search).to(buffer.device)

    spa_token = self.SAI2Token(buffer)
    spa_PE = self.SAI2Token(self.spa_position)
    # Position information enters queries and keys only; values stay raw.
    spa_token_norm = self.norm(spa_token + spa_PE)

    spa_token = self.attention(query=spa_token_norm,
                               key=spa_token_norm,
                               value=spa_token,
                               need_weights=False,
                               attn_mask=atten_mask)[0] + spa_token
    spa_token = self.feed_forward(spa_token) + spa_token

    return self.Token2SAI(spa_token)
5、定义AltFilter类
class AltFilter(nn.Module):
    """One alternating block: an angular transformer followed by a spatial one.

    ``AngTrans`` is referenced here but not defined in the visible part of
    the file.
    """

    def __init__(self, angRes, channels, MHSA_params):
        super().__init__()
        self.angRes = angRes
        self.spa_trans = SpaTrans(channels, angRes, MHSA_params)
        self.ang_trans = AngTrans(channels, angRes, MHSA_params)

    def forward(self, buffer):
        """Run the angular transformer first, then the spatial transformer."""
        return self.spa_trans(self.ang_trans(buffer))
定义interpolate函数
def interpolate(x, angRes, scale_factor, mode):
    """Up-scale every view of an SAI array independently.

    The input is split into its ``angRes x angRes`` views, each view is
    resized with ``F.interpolate`` (``align_corners=False``), and the views
    are tiled back into the SAI-array layout, so no interpolation kernel
    reaches across a view border.

    Args:
        x: tensor of shape ``[B, C, angRes*h, angRes*w]``. Any channel count
            is accepted; the single-channel case behaves as before.
        angRes: angular resolution (views per side).
        scale_factor: spatial scale passed to ``F.interpolate``.
        mode: interpolation mode that accepts ``align_corners``
            (e.g. ``'bilinear'`` or ``'bicubic'``).

    Returns:
        Tensor of shape ``[B, C, angRes*h', angRes*w']`` where ``h' x w'`` is
        the size ``F.interpolate`` produces for one view (``h*S x w*S`` for an
        integer scale ``S``).
    """
    batch, channels, full_h, full_w = x.size()
    view_h = full_h // angRes
    view_w = full_w // angRes

    # [B, C, A*h, A*w] -> [B*A*A, C, h, w]: one image per view
    views = x.view(batch, channels, angRes, view_h, angRes, view_w)
    views = views.permute(0, 2, 4, 1, 3, 5).reshape(batch * angRes ** 2, channels, view_h, view_w)

    views = F.interpolate(views, scale_factor=scale_factor, mode=mode, align_corners=False)
    # Take the output size from the result so non-integer scales work too.
    up_h, up_w = views.shape[-2:]

    # [B*A*A, C, h', w'] -> [B, C, A*h', A*w']
    x_upscale = views.view(batch, angRes, angRes, channels, up_h, up_w)
    x_upscale = x_upscale.permute(0, 3, 1, 4, 2, 5).reshape(batch, channels, angRes * up_h, angRes * up_w)
    return x_upscale
6、定义get_loss类
class get_loss(nn.Module):
    """Training criterion: mean absolute error (L1) between SR and HR images."""

    def __init__(self, args):
        # ``args`` is accepted for a uniform constructor signature but is not used.
        super().__init__()
        self.criterion_Loss = nn.L1Loss()

    def forward(self, SR, HR):
        """Return the L1 loss between the prediction ``SR`` and the target ``HR``."""
        return self.criterion_Loss(SR, HR)
7、weights_init函数
def weights_init(m):
    """Placeholder initializer: leaves module ``m`` untouched.

    It exists so callers can pass it to ``net.apply``; every layer keeps the
    initialization it was constructed with.
    """
    return None