Light Field Image Super-Resolution with Transformers 代码解析

一、generate_Data_for_Training

1、初始化

%% Initialization
clear all;   % remove every variable from the workspace
clc;         % clear the command window

2、参数设置

角分辨率设置为5,缩放因子(factor)为2/4,每个 SAI patch的空间分辨率为patchsize = factor*32,步长为stride = patchsize/2。

%% Parameters setting
angRes = 5;                 % Angular Resolution, options, e.g., 3, 5, 7, 9. Default: 5
factor = 2;                 % SR factor (2 or 4)
patchsize = factor*32;  	% Spatial resolution of each SAI patch
stride = patchsize/2;       % stride between two patches. Default: 32
downRatio = 1/factor;       % downsampling ratio fed to imresize (HR patch -> LR patch)
src_data_path = './datasets/';
src_datasets = dir(src_data_path);
src_datasets(1:2) = [];     % drop the '.' and '..' entries that dir always returns first
num_datasets = length(src_datasets);  % number of dataset folders

其中:downRatio = 1/factor:即下采样比例,作为 imresize 的缩放因子,把高分辨率 patch 缩小 factor 倍得到低分辨率输入;

src_datasets(1:2) = []:删除 dir 返回结果的前两个条目(当前目录 '.' 和上级目录 '..'),只保留真正的数据集文件夹。

3、训练数据生成

加载LF图像、提取中心5×5视图、生成32*32patch

%% Training data generation
for index_dataset = 1 : num_datasets
    idx_save = 0;
    name_dataset = src_datasets(index_dataset).name;
    src_sub_dataset = [src_data_path, name_dataset, '/training/'];
    folders = dir(src_sub_dataset);
    folders(1:2) = [];              % drop '.' and '..'
    num_scene = length(folders); 
    
    for index_scene = 1 : num_scene 
        % Load LF image
        idx_scene_save = 0;
        name_scene = folders(index_scene).name;
        name_scene(end-3:end) = [];     % strip the 4-char extension (presumably '.mat' -- confirm)
        fprintf('Generating training data of Scene_%s in Dataset %s......\t\t', name_scene, name_dataset);
        data_path = [src_sub_dataset, name_scene];
        data = load(data_path);
        LF = data.LF; 
        [U, V, ~, ~, ~] = size(LF);
         
        % Extract central angRes*angRes views
        % e.g. U = 9, angRes = 5 -> indices 3:7, i.e. the centre 5x5 block of views
        LF = LF(0.5*(U-angRes+2):0.5*(U+angRes), 0.5*(V-angRes+2):0.5*(V+angRes), :, :, 1:3); 
        [U, V, H, W, ~] = size(LF);
                
        % Generate patches of size 32*32
        for h = 1 : stride : H - patchsize + 1
            for w = 1 : stride : W - patchsize + 1
                idx_save = idx_save + 1;
                idx_scene_save = idx_scene_save + 1;
                % HR / LR macro-images: the U x V views tiled side by side in one 2-D array
                Hr_SAI_y = single(zeros(U * patchsize, V * patchsize));
                Lr_SAI_y = single(zeros(U * patchsize * downRatio, V * patchsize * downRatio));             

                for u = 1 : U
                    for v = 1 : V     
                        x = (u-1) * patchsize + 1;
                        y = (v-1) * patchsize + 1;
                        
                        % Convert to YCbCr and keep only the Y (luma) channel
                        patch_Hr_rgb = double(squeeze(LF(u, v, h : h+patchsize-1, w : w+patchsize-1, :)));
                        patch_Hr_ycbcr = rgb2ycbcr(patch_Hr_rgb);
                        patch_Hr_y = squeeze(patch_Hr_ycbcr(:,:,1)); 
                                                
                        patchsize_Lr = patchsize / factor;
                        Hr_SAI_y(x:x+patchsize-1, y:y+patchsize-1) = single(patch_Hr_y);
                        % bicubic downsampling (imresize default) produces the LR input
                        patch_Sr_y = imresize(patch_Hr_y, downRatio);
                        Lr_SAI_y((u-1)*patchsize_Lr+1 : u*patchsize_Lr, (v-1)*patchsize_Lr+1:v*patchsize_Lr) = single(patch_Sr_y);
         
                    end
                end

                SavePath = ['./data_for_train/SR_', num2str(angRes), 'x' , num2str(angRes), '_' ,num2str(factor), 'x/', name_dataset,'/' ];
                if exist(SavePath, 'dir')==0
                    mkdir(SavePath);
                end

                % one .h5 file per patch, numbered consecutively within the dataset
                SavePath_H5 = [SavePath, num2str(idx_save,'%06d'),'.h5'];
                
                h5create(SavePath_H5, '/Lr_SAI_y', size(Lr_SAI_y), 'Datatype', 'single');
                h5write(SavePath_H5, '/Lr_SAI_y', single(Lr_SAI_y), [1,1], size(Lr_SAI_y));
                
                h5create(SavePath_H5, '/Hr_SAI_y', size(Hr_SAI_y), 'Datatype', 'single');
                h5write(SavePath_H5, '/Hr_SAI_y', single(Hr_SAI_y), [1,1], size(Hr_SAI_y));
                
            end
        end
        fprintf([num2str(idx_scene_save), ' training samples have been generated\n']);
    end
end

二、generate_Data_for_Test

1、初始化

%% Initialization
clear all;   % remove every variable from the workspace
clc;         % clear the command window

2、参数设置

%% Parameters setting
angRes = 5;                 % Angular Resolution, options, e.g., 3, 5, 7, 9. Default: 5
factor = 2;                 % SR factor (2 or 4)
downRatio = 1/factor;       % downsampling ratio fed to imresize (HR -> LR)
src_data_path = './datasets/';
src_datasets = dir(src_data_path);
src_datasets(1:2) = [];     % drop the '.' and '..' entries that dir always returns first
num_datasets = length(src_datasets); 

3、测试数据生成

加载LF图像、提取中心5×5视图、逐视图转换为 YCbCr 并只保留 Y(亮度)通道,整景(不切 patch)保存为 h5。

%% Test data generation
for index_dataset = 1 : num_datasets 
    idx_save = 0;
    name_dataset = src_datasets(index_dataset).name;
    src_sub_dataset = [src_data_path, name_dataset, '/test/'];
    scenes = dir(src_sub_dataset);
    scenes(1:2) = [];               % drop '.' and '..'
    num_scene = length(scenes); 
    
    for index_scene = 1 : num_scene 
        % Load LF image
        idx_scene_save = 0;
        name_scene = scenes(index_scene).name;
        name_scene(end-3:end) = [];     % strip the 4-char extension (presumably '.mat' -- confirm)
        fprintf('Generating test data of Scene_%s in Dataset %s......\t\t', name_scene, src_datasets(index_dataset).name);
        data_path = [src_sub_dataset, name_scene];
        data = load(data_path);
        LF = data.LF;
        [U, V, H, W, ~] = size(LF);
        % crop H and W down to multiples of 4 so both x2 and x4 SR factors divide evenly
        while mod(H, 4) ~= 0
            H = H - 1;
        end
        while mod(W, 4) ~= 0
            W = W - 1;
        end
        
        % Extract central angRes*angRes views
        LF = LF(0.5*(U-angRes+2):0.5*(U+angRes), 0.5*(V-angRes+2):0.5*(V+angRes), 1:H, 1:W, 1:3); % Extract central angRes*angRes views
        [U, V, H, W, ~] = size(LF);
    
        % Convert to YCbCr -- whole views this time, no patch slicing for the test split
        idx_save = idx_save + 1;
        idx_scene_save = idx_scene_save + 1;
        Hr_SAI_y = single(zeros(U * H, V * W));
        Lr_SAI_y = single(zeros(U * H * downRatio, V * W * downRatio));           
   
        for u = 1 : U
            for v = 1 : V
                x = (u-1)*H+1;
                y = (v-1)*W+1;
                
                temp_Hr_rgb = double(squeeze(LF(u, v, :, :, :)));
                temp_Hr_ycbcr = rgb2ycbcr(temp_Hr_rgb);
                Hr_SAI_y(x:u*H, y:v*W) = single(temp_Hr_ycbcr(:,:,1));
                
                % bicubic downsampling of the Y channel gives the LR network input
                temp_Hr_y = squeeze(temp_Hr_ycbcr(:,:,1));
                temp_Lr_y = imresize(temp_Hr_y, downRatio);
                Lr_SAI_y((u-1)*H*downRatio+1 : u*H*downRatio, (v-1)*W*downRatio+1:v*W*downRatio) = single(temp_Lr_y);                  
            end
        end 
        
        SavePath = ['./data_for_test/SR_', num2str(angRes), 'x' , num2str(angRes), '_' ,num2str(factor), 'x/', name_dataset,'/' ];
        if exist(SavePath, 'dir')==0
            mkdir(SavePath);
        end

        % one .h5 per scene, named after the scene itself
        SavePath_H5 = [SavePath, name_scene,'.h5'];
        
        h5create(SavePath_H5, '/Hr_SAI_y', size(Hr_SAI_y), 'Datatype', 'single');
        h5write(SavePath_H5, '/Hr_SAI_y', single(Hr_SAI_y), [1,1], size(Hr_SAI_y));
        
        h5create(SavePath_H5, '/Lr_SAI_y', size(Lr_SAI_y), 'Datatype', 'single');
        h5write(SavePath_H5, '/Lr_SAI_y', single(Lr_SAI_y), [1,1], size(Lr_SAI_y));

        fprintf([num2str(idx_scene_save), ' test samples have been generated\n']);
    end
end

三、train.py
1、导入模块

from torch.utils.data import DataLoader
import importlib
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from utils.utils import *
from utils.utils_datasets import TrainSetDataLoader
from collections import OrderedDict

2、定义主函数:

def main(args):

为保存创建目录:实验目录、检查点目录、日志目录

''' Create Dir for Save '''
    experiment_dir, checkpoints_dir, log_dir = create_dir(args)

记录器:

''' Logger '''
    logger = Logger(log_dir, args)

设置CPU或GPU

  ''' CPU or Cuda '''
    torch.cuda.set_device(args.local_rank)
    # device = torch.device("cuda", args.local_rank)
    # NOTE(review): torch.device("cpu", index) only accepts index 0/None; with a
    # non-zero local_rank this raises. The reference repo uses the cuda device
    # here -- confirm this cpu override is intentional.
    device = torch.device("cpu", args.local_rank)

数据训练加载

''' DATA TRAINING LOADING '''
    logger.log_string('\nLoad Training Dataset ...')
    train_Dataset = TrainSetDataLoader(args)  # reads the pre-generated .h5 patches
    logger.log_string("The number of training data is: %d" % len(train_Dataset))
    # shuffle=True: patches are stored scene-by-scene, shuffling decorrelates batches
    train_loader = torch.utils.data.DataLoader(dataset=train_Dataset, num_workers=args.num_workers,
                                               batch_size=args.batch_size, shuffle=True,)

模型加载

''' MODEL LOADING '''
    logger.log_string('\nModel Initial ...')
    # dynamic import: module 'model.<name>' must expose get_model / get_loss / weights_init
    MODEL_PATH = 'model.' + args.model_name
    MODEL = importlib.import_module(MODEL_PATH)
    net = MODEL.get_model(args)

加载预训练模型

''' load pre-trained pth '''
    if args.use_pre_pth == False:
        # train from scratch using the model's own initialisation hook
        net.apply(MODEL.weights_init)
        start_epoch = 0
        logger.log_string('Do not use pretrain model!')
    else:
        try:
            ckpt_path = args.path_pre_pth
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            start_epoch = checkpoint['epoch']
            try:
                # first attempt: assume net is wrapped (DataParallel/DDP), whose
                # parameter names carry a 'module.' prefix
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = 'module.' + k  # add `module.`
                    new_state_dict[name] = v
                # load params
                net.load_state_dict(new_state_dict)
                logger.log_string('Use pretrain model!')
            except:
                # fallback: unwrapped net, keep the original key names
                # NOTE(review): the bare except also hides genuine key/shape mismatches
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    new_state_dict[k] = v
                # load params
                net.load_state_dict(new_state_dict)
                logger.log_string('Use pretrain model!')
        except:
            # checkpoint missing or unreadable: silently fall back to scratch training
            net.apply(MODEL.weights_init)
            start_epoch = 0
            logger.log_string('No existing model, starting training from scratch...')
            pass
        pass
    net = net.to(device)
    cudnn.benchmark = True

打印参数

''' Print Parameters '''
    logger.log_string('PARAMETER ...')
    logger.log_string(args)

损失加载

'''LOSS LOADING '''
    criterion = MODEL.get_loss(args).to(device)
优化器(Adam)
optimizer = torch.optim.Adam(
        # only parameters that require gradients are optimised
        [paras for paras in net.parameters() if paras.requires_grad == True],
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )
    # decay LR by gamma every n_steps epochs (scheduler.step() is called once per epoch)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.n_steps, gamma=args.gamma)

训练

''' TRAINING '''
    logger.log_string('\nStart training...')
    for idx_epoch in range(start_epoch, args.epoch):
        logger.log_string('\nEpoch %d /%s:' % (idx_epoch + 1, args.epoch))
        loss_epoch_train, psnr_epoch_train, ssim_epoch_train = train(train_loader, device, net, criterion, optimizer)
        logger.log_string('The %dth Train, loss is: %.5f, psnr is %.5f, ssim is %.5f' %
                          (idx_epoch + 1, loss_epoch_train, psnr_epoch_train, ssim_epoch_train))

        # save model (only on the main process when launched distributed)
        if args.local_rank == 0:
            save_ckpt_path = str(checkpoints_dir) + '/%s_%dx%d_%dx_epoch_%02d_model.pth' % (
            args.model_name, args.angRes, args.angRes, args.scale_factor, idx_epoch + 1)
            # strip the DataParallel wrapper (if any) so checkpoint keys stay portable
            state = {
                'epoch': idx_epoch + 1,
                'state_dict': net.module.state_dict() if hasattr(net, 'module') else net.state_dict(),
            }
            torch.save(state, save_ckpt_path)
            logger.log_string('Saving the epoch_%02d model at %s' % (idx_epoch + 1, save_ckpt_path))

        ''' scheduler '''
        scheduler.step()
        pass
    pass

3、定义train函数

def train(train_loader,device,net,criterion,optimizer):

训练一个epoch(1个epoch表示过了1遍训练集中的所有样本。)

'''training one epoch'''
    psnr_iter_train = []
    loss_iter_train = []
    ssim_iter_train = []
    # NOTE(review): this function reads the module-level `args` (imported under
    # __main__) instead of taking it as a parameter -- confirm `from option
    # import args` has run before train() is called.
    args.temperature = 1.0
    for idx_iter, (data, label) in tqdm(enumerate(train_loader), total=len(train_loader), ncols=70):
        data = data.to(device)      # low resolution
        label = label.to(device)    # high resolution
        out = net(data)
        loss = criterion(out, label)

        # standard optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

        loss_iter_train.append(loss.data.cpu())
        psnr, ssim = cal_metrics(args, label, out)  # per-batch PSNR/SSIM
        psnr_iter_train.append(psnr)
        ssim_iter_train.append(ssim)
        pass

    # epoch-level metrics are plain means over the per-iteration values
    loss_epoch_train = float(np.array(loss_iter_train).mean())
    psnr_epoch_train = float(np.array(psnr_iter_train).mean())
    ssim_epoch_train = float(np.array(ssim_iter_train).mean())

    return loss_epoch_train, psnr_epoch_train, ssim_epoch_train

4、控制代码

if __name__ == '__main__':
    from option import args

    main(args)

四、test.py

1、导入模块

from torch.utils.data import DataLoader
import importlib
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from utils.utils import *
from utils.utils_datasets import MultiTestSetDataLoader
from collections import OrderedDict

2、定义主函数:

def main(args):

为保存创建目录

    ''' Create Dir for Save'''
    experiment_dir, checkpoints_dir, log_dir = create_dir(args)

记录器:

    ''' Logger '''
    logger = Logger(log_dir, args)

设置CPU或GPU

    ''' CPU or Cuda '''
    torch.cuda.set_device(args.local_rank)
    # device = torch.device("cuda", args.local_rank)
    device = torch.device("cpu", args.local_rank)

数据训练加载

    ''' DATA TEST LOADING '''
    logger.log_string('\nLoad Test Dataset ...')
    test_Names, test_Loaders, length_of_tests = MultiTestSetDataLoader(args)
    logger.log_string("The number of test data is: %d" % length_of_tests)

模型加载

    ''' MODEL LOADING '''
    logger.log_string('\nModel Initial ...')
    MODEL_PATH = 'model.' + args.model_name
    MODEL = importlib.import_module(MODEL_PATH)
    net = MODEL.get_model(args)

 加载预训练模型

    ''' load pre-trained pth '''
    ckpt_path = args.path_pre_pth
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    try:
        # first attempt: prefix keys with 'module.' for a DataParallel-wrapped net
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = 'module.' + k  # add `module.`
            new_state_dict[name] = v
        # load params
        net.load_state_dict(new_state_dict)
        logger.log_string('Use pretrain model!')
    except:
        # fallback: unwrapped net, keep the original key names
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            new_state_dict[k] = v
        # load params
        net.load_state_dict(new_state_dict)
        logger.log_string('Use pretrain model!')

    net = net.to(device)
    cudnn.benchmark = True

测试每个数据集

    ''' TEST on every dataset'''
    logger.log_string('\nStart test...')
    with torch.no_grad():
        psnr_testset = []
        ssim_testset = []
        # test_Names and test_Loaders are parallel lists: one loader per dataset
        for index, test_name in enumerate(test_Names):
            test_loader = test_Loaders[index]

            psnr_epoch_test, ssim_epoch_test = test(test_loader, device, net)
            psnr_testset.append(psnr_epoch_test)
            ssim_testset.append(ssim_epoch_test)
            logger.log_string('Test on %s, psnr/ssim is %.2f/%.3f' % (test_name, psnr_epoch_test, ssim_epoch_test))
            pass
        pass
    pass

3、定义 test函数

def test(test_loader, device, net):
    """Evaluate the network on one dataset; returns (mean PSNR, mean SSIM)."""
    psnr_iter_test = []
    ssim_iter_test = []
    for idx_iter, (Lr_SAI_y, Hr_SAI_y) in tqdm(enumerate(test_loader), total=len(test_loader), ncols=70):
        Lr_SAI_y = Lr_SAI_y.squeeze().to(device)  # numU, numV, h*angRes, w*angRes
        Hr_SAI_y = Hr_SAI_y.squeeze()

        uh, vw = Lr_SAI_y.shape
        h0, w0 = int(uh//args.angRes), int(vw//args.angRes)  # per-view LR spatial size

        # tile the full LF into overlapping patches so an arbitrary-size scene
        # fits through the fixed-size network
        subLFin = LFdivide(Lr_SAI_y, args.angRes, args.patch_size_for_test, args.stride_for_test)
        numU, numV, H, W = subLFin.size()
        subLFout = torch.zeros(numU, numV, args.angRes * args.patch_size_for_test * args.scale_factor,
                               args.angRes * args.patch_size_for_test * args.scale_factor)

        # super-resolve every patch independently
        for u in range(numU):
            for v in range(numV):
                tmp = subLFin[u:u+1, v:v+1, :, :]
                with torch.no_grad():
                    net.eval()
                    torch.cuda.empty_cache()
                    out = net(tmp.to(device))
                    subLFout[u:u+1, v:v+1, :, :] = out.squeeze()

        # stitch the SR patches back into a full 4D LF, then into one SAI mosaic
        Sr_4D_y = LFintegrate(subLFout, args.angRes, args.patch_size_for_test * args.scale_factor,
                              args.stride_for_test * args.scale_factor, h0 * args.scale_factor,
                              w0 * args.scale_factor)
        Sr_SAI_y = Sr_4D_y.permute(0, 2, 1, 3).reshape((h0 * args.angRes * args.scale_factor,
                                                        w0 * args.angRes * args.scale_factor))

        psnr, ssim = cal_metrics(args, Hr_SAI_y, Sr_SAI_y)
        psnr_iter_test.append(psnr)
        ssim_iter_test.append(ssim)
        pass

    psnr_epoch_test = float(np.array(psnr_iter_test).mean())
    ssim_epoch_test = float(np.array(ssim_iter_test).mean())

    return psnr_epoch_test, ssim_epoch_test

5、控制代码


if __name__ == '__main__':
    from option import args

    main(args)

五、LTF.py    module 

1、导入模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math

2、定义get_module类:

class get_model(nn.Module):

初始化

实例属性:(通道数、缩放因子、层数=4、位置编码、MHSA。。。)、初始卷积、Alternate AngTrans & SpaTrans ???、上采样

    def __init__(self, args):
        """LFT network: initial per-view 3D convs, alternating angular/spatial
        transformer blocks, and pixel-shuffle up-sampling of the Y channel."""
        super(get_model, self).__init__()
        channels = args.channels
        self.channels = channels
        self.angRes = args.angRes
        self.factor = args.scale_factor          # SR scale
        layer_num = 4                            # number of AltFilter stages

        self.pos_encoding = PositionEncoding(temperature=10000)
        self.MHSA_params = {}
        self.MHSA_params['num_heads'] = 8
        self.MHSA_params['dropout'] = 0.

        ##################### Initial Convolution #####################
        # kernel (1, 3, 3): convolve spatially within each view, never across views
        self.conv_init0 = nn.Sequential(
            nn.Conv3d(1, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), dilation=1, bias=False),
        )
        self.conv_init = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), dilation=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), dilation=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(channels, channels, kernel_size=(1, 3, 3), padding=(0, 1, 1), dilation=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        ################ Alternate AngTrans & SpaTrans ################
        self.altblock = self.make_layer(layer_num=layer_num)

        ####################### UP Sampling ###########################
        # 1x1 conv expands channels for PixelShuffle(factor); final 3x3 conv
        # projects down to the single Y-channel residual
        self.upsampling = nn.Sequential(
            nn.Conv2d(channels, channels*self.factor ** 2, kernel_size=1, padding=0, dilation=1, bias=False),
            nn.PixelShuffle(self.factor),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=False),
        )

 make_layer函数

    def make_layer(self, layer_num):
        """Return layer_num stacked AltFilter blocks as one nn.Sequential."""
        blocks = [AltFilter(self.angRes, self.channels, self.MHSA_params)
                  for _ in range(layer_num)]
        return nn.Sequential(*blocks)

forward函数:

双三次插值、reshape、初始卷积、位置编码、Alternate AngTrans&SpaTrans、上采样。

    def forward(self, lr):
        """Super-resolve a macro SAI image.

        lr: [B, 1, angRes*h, angRes*w] -> [B, 1, angRes*h*S, angRes*w*S];
        the network predicts a residual on top of bicubic upscaling.
        """
        # Bicubic
        lr_upscale = interpolate(lr, self.angRes, scale_factor=self.factor, mode='bicubic')
        # [B(atch), 1, A(ngRes)*h(eight)*S(cale), A(ngRes)*w(idth)*S(cale)]

        # reshape for LFT
        lr = rearrange(lr, 'b c (a1 h) (a2 w) -> b c (a1 a2) h w', a1=self.angRes, a2=self.angRes)
        # [B, C(hannels), A^2, h, w]
        # inject the current spatial size into every submodule; Token2SAI and
        # friends read self.h / self.w instead of taking them as arguments
        for m in self.modules():
            m.h = lr.size(-2)
            m.w = lr.size(-1)

        # Initial Convolution (with a residual over the first projection)
        buffer = self.conv_init0(lr)
        buffer = self.conv_init(buffer) + buffer  # [B, C, A^2, h, w]

        # Position Encoding: spatial PE over (h, w), angular PE over the A^2 axis;
        # also injected into every submodule as attributes
        spa_position = self.pos_encoding(buffer, dim=[3, 4], token_dim=self.channels)
        ang_position = self.pos_encoding(buffer, dim=[2], token_dim=self.channels)
        for m in self.modules():
            m.spa_position = spa_position
            m.ang_position = ang_position

        # Alternate AngTrans & SpaTrans (global residual around the stack)
        buffer = self.altblock(buffer) + buffer

        # Up-Sampling: fold views back into one mosaic, then pixel-shuffle
        buffer = rearrange(buffer, 'b c (a1 a2) h w -> b c (a1 h) (a2 w)', a1=self.angRes, a2=self.angRes)
        buffer = self.upsampling(buffer)
        out = buffer + lr_upscale

        return out

 #上下采样函数interpolate()   Pytorch上下采样函数--interpolate()_Activewaste的博客-CSDN博客_interpolate()

#rearrange()     einops库中rearrange,reduce和repeat的介绍_鬼道2022的博客-CSDN博客_einops.rearrange 

 3、位置编码类:

class PositionEncoding(nn.Module):

初始化实例属性

    def __init__(self, temperature):
        # temperature sets the base of the sinusoidal frequency ladder (10000 in LFT)
        super(PositionEncoding, self).__init__()
        self.temperature = temperature

 forward函数

    def forward(self, x, dim: list, token_dim):
        """Sinusoidal position encoding over the axes listed in `dim`.

        x: 5-D tensor [B, C, A^2, h, w]; returns an encoding of shape
        [1, token_dim, a, h, w] (singleton for every axis not in `dim`),
        averaged over len(dim) so multi-axis encodings stay bounded.
        """
        self.token_dim = token_dim
        assert len(x.size()) == 5, 'the object of position encoding requires 5-dim tensor! '
        # frequency ladder: temperature^(2*floor(i/2)/token_dim), as in the
        # standard transformer sinusoidal encoding
        grid_dim = torch.linspace(0, self.token_dim - 1, self.token_dim, dtype=torch.float32)
        grid_dim = 2 * (grid_dim // 2) / self.token_dim
        grid_dim = self.temperature ** grid_dim
        position = None
        for index in range(len(dim)):
            # broadcast shape: singleton everywhere except the encoded axis
            pos_size = [1, 1, 1, 1, 1, self.token_dim]
            length = x.size(dim[index])
            pos_size[dim[index]] = length

            pos_dim = (torch.linspace(0, length - 1, length, dtype=torch.float32).view(-1, 1) / grid_dim).to(x.device)
            # even feature indices get sin, odd get cos
            pos_dim = torch.cat([pos_dim[:, 0::2].sin(), pos_dim[:, 1::2].cos()], dim=1)
            pos_dim = pos_dim.view(pos_size)

            # sum the per-axis encodings
            if position is None:
                position = pos_dim
            else:
                position = position + pos_dim
            pass

        # move the feature axis into the channel slot: [1, token_dim, a, h, w]
        position = rearrange(position, 'b 1 a h w dim -> b dim a h w')

        return position / len(dim)

torch.linspace()    torch.linspace()用法_快乐地笑的博客-CSDN博客 

 4、定义SpanTrans类:

class SpaTrans(nn.Module):

 初始化实例属性

    def __init__(self, channels, angRes, MHSA_params):
        """Spatial transformer branch: attends across pixel tokens of each view."""
        super(SpaTrans, self).__init__()
        self.angRes = angRes
        self.kernel_field = 3       # local neighbourhood folded into each token
        self.kernel_search = 5      # NOTE(review): unused in this excerpt -- in the
                                    # reference LFT it sizes the local attention mask
        self.spa_dim = channels * 2
        # embed the unfolded 3x3 patch (channels * 9 values) into a spatial token
        self.MLP = nn.Linear(channels * self.kernel_field ** 2, self.spa_dim, bias=False)

        self.norm = nn.LayerNorm(self.spa_dim)
        self.attention = nn.MultiheadAttention(self.spa_dim,
                                               MHSA_params['num_heads'],
                                               MHSA_params['dropout'],
                                               bias=False)
        nn.init.kaiming_uniform_(self.attention.in_proj_weight, a=math.sqrt(5))
        self.attention.out_proj.bias = None   # drop the output-projection bias

        # pre-norm MLP block: expand 2x, ReLU, project back, with dropout
        self.feed_forward = nn.Sequential(
            nn.LayerNorm(self.spa_dim),
            nn.Linear(self.spa_dim, self.spa_dim*2, bias=False),
            nn.ReLU(True),
            nn.Dropout(MHSA_params['dropout']),
            nn.Linear(self.spa_dim*2, self.spa_dim, bias=False),
            nn.Dropout(MHSA_params['dropout'])
        )
        # 1x1x1 conv projects tokens back down to the model channel width
        self.linear = nn.Sequential(
            nn.Conv3d(self.spa_dim, channels, kernel_size=(1, 1, 1), padding=(0, 0, 0), dilation=1, bias=False),
        )

 nn.Linear()   nn.Linear()函数详解及代码使用_墨晓白的博客-CSDN博客_nn.linear

 nn.LayerNorm()   nn.LayerNorm的实现及原理_harry_tea的博客-CSDN博客_layer norm 

定义gen_mask函数,并使用 @staticmethod

    @staticmethod
    def gen_mask(h:int, w:int, k:int):
        atten_mask = torch.zeros([h, w, h, w])
        k_left = k//2
        k_right = k - k_left
        for i in range(h):
            for j in range(w):
                temp = torch.zeros(h, w)
                temp[max(0, i-k_left):min(h,i+k_right), max(0, j-k_left):min(h,j+k_right)] = 1
                atten_mask[i, j, :, :] = temp

        atten_mask = rearrange(atten_mask, 'a b c d -> (a b) (c d)')
        atten_mask = atten_mask.float().masked_fill(atten_mask == 0, float('-inf')).\
            masked_fill(atten_mask == 1, float(0.0))

        return atten_mask

定义SAI2Token函数

    def SAI2Token(self, buffer):
        """Turn each view's pixels into tokens carrying a 3x3 local patch.

        buffer: [B, C, A^2, h, w] -> tokens of shape [h*w, B*A^2, spa_dim].
        """
        buffer = rearrange(buffer, 'b c a h w -> (b a) c h w')
        # local feature embedding: unfold 3x3 neighbourhoods, then project with the MLP
        spa_token = F.unfold(buffer, kernel_size=self.kernel_field, padding=self.kernel_field//2).permute(2, 0, 1)
        spa_token = self.MLP(spa_token)
        return spa_token

定义Token2SAI函数

    def Token2SAI(self, buffer_token_spa):
        """Inverse of SAI2Token: tokens back to [B, C, A^2, h, w].

        Relies on self.h / self.w, which get_model.forward injects into every
        submodule before running the transformer stack.
        """
        buffer = rearrange(buffer_token_spa, '(h w) (b a) c -> b c a h w', h=self.h, w=self.w, a=self.angRes**2)
        buffer = self.linear(buffer)   # project spa_dim back to model channels
        return buffer

 定义foward函数

    def forward(self, buffer):
        """Pre-norm self-attention + feed-forward, each with a residual.

        NOTE(review): the local names say 'ang_*' and the positional encoding
        used is self.ang_position, even though this is SpaTrans -- in the
        reference LFT the spatial branch uses spa_position plus a local
        attention mask. Verify against the original implementation.
        """
        ang_token = self.SAI2Token(buffer)
        ang_PE = self.SAI2Token(self.ang_position)
        ang_token_norm = self.norm(ang_token + ang_PE)

        # Q/K are normalised (with PE); V and the residual use the raw tokens
        ang_token = self.attention(query=ang_token_norm,
                                   key=ang_token_norm,
                                   value=ang_token,
                                   need_weights=False)[0] + ang_token

        ang_token = self.feed_forward(ang_token) + ang_token
        buffer = self.Token2SAI(ang_token)

        return buffer

5、定义AltFilter类

class AltFilter(nn.Module):
    """One alternation stage: angular transformer followed by a spatial one."""

    def __init__(self, angRes, channels, MHSA_params):
        super(AltFilter, self).__init__()
        self.angRes = angRes
        self.spa_trans = SpaTrans(channels, angRes, MHSA_params)
        self.ang_trans = AngTrans(channels, angRes, MHSA_params)

    def forward(self, buffer):
        # angular attention first, then spatial attention on its output
        return self.spa_trans(self.ang_trans(buffer))

定义interpolate函数

def interpolate(x, angRes, scale_factor, mode):
    """Upsample every sub-aperture image of a macro SAI mosaic independently.

    x: [B, 1, angRes*h, angRes*w]; returns [B, 1, angRes*h*S, angRes*w*S],
    where each of the angRes^2 views is interpolated on its own so view
    boundaries never bleed into each other.
    """
    B, _, H, W = x.size()
    h, w = H // angRes, W // angRes

    # split the (A*h, A*w) mosaic into A^2 separate single-view images
    views = x.view(B, 1, angRes, h, angRes, w)
    views = views.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * angRes ** 2, 1, h, w)

    # per-view interpolation (align_corners=False matches the training data pipeline)
    views = F.interpolate(views, scale_factor=scale_factor, mode=mode, align_corners=False)

    # reassemble the upscaled views back into one mosaic
    views = views.view(B, angRes, angRes, 1, h * scale_factor, w * scale_factor)
    out = views.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, 1, H * scale_factor, W * scale_factor)
    # [B, 1, A*h*S, A*w*S]
    return out

6、定义get_loss函数 

class get_loss(nn.Module):
    """L1 (mean absolute error) reconstruction loss between SR output and HR target."""

    def __init__(self, args):
        super(get_loss, self).__init__()
        # args is accepted for interface parity with other models; L1 needs no options
        self.criterion_Loss = torch.nn.L1Loss()

    def forward(self, SR, HR):
        return self.criterion_Loss(SR, HR)

7、weight_init函数

def weights_init(m):
    # Deliberate no-op: parameters keep PyTorch's default initialisation.
    # The hook must still exist because train.py calls net.apply(MODEL.weights_init).
    pass

 

 

 

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值