Introduction to BasicSR

PyTorch basics and the syntax used in BasicSR

1. The Sampler class and four sampling methods

Further reading: Understanding the relationship between PyTorch's DataLoader, Dataset, and Sampler
PyTorch source code reading (3): the Sampler class and four sampling methods

The code below is a custom sampler:
ratio controls the factor by which the dataset is enlarged
num_replicas is the number of processes, usually world_size
rank is the rank of the current process

The goal is to split the dataset indices into num_replicas groups, one for each process.
As for ratio, it enlarges the amount of data seen in each epoch, saving time on restarting the dataloader after each epoch.

import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler
    Supports enlarging the dataset for iteration-based training, saving
    time on restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

A quick test:

import numpy as np
if __name__ == "__main__":
    data = np.arange(20).tolist()
    en_sample = EnlargedSampler(data, 2, 0)
    en_sample.set_epoch(1)
    for i in en_sample:
        print(i)
    print('\n------------------\n')
    en_sample = EnlargedSampler(data, 2, 1)
    en_sample.set_epoch(1)  # use the same epoch: the indices generated for rank 0 and rank 1 are complementary

    # alternatively, skip set_epoch and keep the default epoch of 0
    for i in en_sample:
        print(i)

Result: each rank prints 10 indices; the index sets produced for rank 0 and rank 1 are complementary and together cover all 20 data indices.
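
In real training the sampler is plugged into a DataLoader. Below is a minimal sketch of that (the toy TensorDataset, batch size, and ratio=2 are made-up values for illustration, not BasicSR's actual training setup):

# Sketch: hooking EnlargedSampler into a DataLoader (toy data, illustrative only)
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(20).float())
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=2)  # one process, 2x enlarged epoch
loader = DataLoader(dataset, batch_size=4, sampler=sampler)  # shuffle must stay False when a sampler is given

for epoch in range(2):
    sampler.set_epoch(epoch)  # changes the deterministic shuffle each epoch
    for (batch,) in loader:
        print(batch.tolist())  # 40 samples per epoch; each of the 20 items appears exactly twice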

2. Using Python dict's get() method

dict.get(key, default) returns the stored value when the key exists and the given default (None if not supplied) otherwise, without raising KeyError. This is handy for reading optional entries from a configuration dict such as opt.
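
A minimal sketch of get() (the opt dict here is a made-up example, not a real BasicSR config):

# dict.get(key, default) never raises KeyError for a missing key
opt = {'model_type': 'SRModel', 'scale': 2}

print(opt.get('scale', 1))      # 2    -> key exists, the stored value is returned
print(opt.get('num_gpu', 0))    # 0    -> key missing, the supplied default is returned
print(opt.get('resume_state'))  # None -> key missing and no default given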

3. prefetch_dataloader.py


While the current batch is being consumed, the next batch is preloaded in advance. The next() functions are the key part.

import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)  # wait for the prefetch copy to finish
        batch = self.batch  # take the batch prefetched last time
        self.preload()      # kick off preloading of the next batch
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
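
A minimal usage sketch of CUDAPrefetcher (requires a CUDA device; the toy dataset and the 'lq'/'gt' keys are assumptions for illustration, not BasicSR's real training loop):

# Drive CUDAPrefetcher with a toy paired dataset (illustrative only)
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairedDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {'lq': torch.randn(3, 32, 32), 'gt': torch.randn(3, 64, 64)}


if __name__ == '__main__':
    loader = DataLoader(ToyPairedDataset(), batch_size=2)
    prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})
    for epoch in range(2):
        prefetcher.reset()             # restart the underlying dataloader
        batch = prefetcher.next()      # first batch, already on the GPU
        while batch is not None:
            lq, gt = batch['lq'], batch['gt']
            # ... forward / backward / optimizer step would go here ...
            batch = prefetcher.next()  # overlaps the H2D copy of the next batch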

4. PyTorch parallel and distributed training

4.1 Selecting which CUDA device(s) to use

When the server has multiple GPUs, we should specify which one to use. If we do not, tensor.cuda() saves the tensor to the first GPU by default, equivalent to tensor.cuda(0); if that GPU is already occupied, this can easily cause an out-of-memory error. We can set the device in either of the following two ways.

  1. At the very beginning of the script
    # set this at the very top of the file
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"  # restrict which GPUs are visible

  2. On the command line when launching
     CUDA_VISIBLE_DEVICES=0,1 python train.py  # use GPUs 0 and 1
    

4.2 Using DataParallel

Typical usage
   model = UNetSeeInDark()
   model._initialize_weights()

   gpus = [0, 1, 2, 3]
   model = nn.DataParallel(model, device_ids=gpus)
   device = torch.device('cuda:0')
   model = model.to(device)
   # to run without DataParallel, simply comment out the nn.DataParallel line
   # to change which GPUs are used, edit gpus and the index in torch.device('cuda:0')
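
A self-contained toy showing the same pattern (assumes at least two visible GPUs; nn.DataParallel scatters the batch across device_ids and gathers the outputs back on cuda:0):

import torch
import torch.nn as nn

toy = nn.Conv2d(3, 3, 3, padding=1)
toy = nn.DataParallel(toy, device_ids=[0, 1])  # replicate across GPUs 0 and 1
toy = toy.to(torch.device('cuda:0'))

x = torch.randn(8, 3, 64, 64).to('cuda:0')     # the whole batch lives on the primary GPU
y = toy(x)                                     # each GPU processes a slice of the batch
print(y.shape, y.device)                       # output is gathered back on cuda:0
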
Saving and loading

To save:

# because the model is wrapped by DataParallel, extract the underlying module first
save_model_path = os.path.join(save_model_dir, f'checkpoint_{epoch:05d}.pth')
# torch.save(model.state_dict(), save_model_path)
torch.save(model.module.state_dict(), save_model_path)

Then load the model:

model_copy.load_state_dict(torch.load(m_path, map_location=device))

If the model was saved without extracting model.module,
you may need to strip the 'module.' prefix when loading:

checkpoint = torch.load(m_path)
model_copy.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()})

4.3 DistributedDataParallel

First, DataParallel is single-process and multi-threaded, and it only works on a single machine with multiple GPUs. DistributedDataParallel is multi-process and works for both single-machine multi-GPU and multi-machine multi-GPU setups. Even on a single machine with multiple GPUs, DistributedDataParallel is faster than DataParallel.
I have not yet studied this in depth; see the references below (a minimal DDP sketch follows the list):
In-depth understanding of distributed training in PyTorch
PyTorch distributed training
A tutorial on multi-GPU parallel computing in PyTorch
A minimal demo of PyTorch parallel training
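
A minimal single-machine DDP sketch (generic, not BasicSR's training code; launch it with torchrun, e.g. torchrun --nproc_per_node=2 train_ddp.py, where the script name is made up):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    dist.init_process_group(backend='nccl')      # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(16, 1).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])  # gradients are all-reduced across processes

    dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
    sampler = DistributedSampler(dataset)        # each process gets a disjoint shard
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(2):
        sampler.set_epoch(epoch)                 # reshuffle the shards every epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            loss = nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    main()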

5. Getting started with wandb

See https://docs.wandb.ai/quickstart directly for the most detailed introduction and getting-started guide.

5.1 Sign up (https://wandb.ai/site)


5.2 Installation and login

pip install wandb
wandb.login()  # then copy in your API key

5.3 demo

import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "CIFAR-100",
        "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset

    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})

# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

After running:

wandb: Currently logged in as: wangty537. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.10
wandb: Run data is saved locally in D:\code\denoise\noise-synthesis-main\wandb\run-20230921_103737-j9ezjcqo
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run wobbly-jazz-1
wandb:  View project at https://wandb.ai/wangty537/my-awesome-project
wandb:  View run at https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:  acc ▁▆▇██▇▇█
wandb: loss █▄█▁▅▁▄▁
wandb: 
wandb: Run summary:
wandb:  acc 0.88762
wandb: loss 0.12236
wandb: 
wandb:  View run wobbly-jazz-1 at: https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20230921_103737-j9ezjcqo\logs

You can then view the run at https://wandb.ai/home.

https://docs.wandb.ai/quickstart also covers more advanced usage.

6. Model and training

6.1 Creating the model

The model is created through a registry mechanism:

# create model
model = build_model(opt)
def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
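
The registry is essentially a name-to-class dictionary filled in by a decorator, so that build_model can look a class up by the string in opt['model_type']. A minimal sketch of the idea (illustrative only; BasicSR's real implementation lives in basicsr/utils/registry.py):

class Registry:
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self, obj=None):
        # usable both as @registry.register() and as registry.register(cls)
        def deco(cls):
            assert cls.__name__ not in self._obj_map, f'{cls.__name__} is already registered'
            self._obj_map[cls.__name__] = cls
            return cls
        return deco if obj is None else deco(obj)

    def get(self, name):
        return self._obj_map[name]


MODEL_REGISTRY_DEMO = Registry('model')


@MODEL_REGISTRY_DEMO.register()
class DemoModel:
    def __init__(self, opt):
        self.opt = opt


# build_model-style lookup: the string in opt selects the registered class
model = MODEL_REGISTRY_DEMO.get('DemoModel')({'model_type': 'DemoModel'})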

6.2 Settings in opt

model_type: SRModel
scale: 2
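
In BasicSR these options come from a YAML file passed to the training script (e.g. via the -opt flag). A rough sketch of how such a fragment turns into the opt dict that build_model consumes (the inline string stands in for the real options file):

import yaml

# the YAML fragment above, parsed into a plain dict
opt = yaml.safe_load("""
model_type: SRModel
scale: 2
""")
print(opt['model_type'])  # 'SRModel' -> used by MODEL_REGISTRY.get() in build_model
print(opt['scale'])       # 2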

6.3 The SRModel class

BaseModel is the base class.

@MODEL_REGISTRY.register()
class SRModel(BaseModel):
    xxx