CycleGAN代码解析(1)

代码来源:PyTorch-GAN/implementations/cyclegan at master · eriklindernoren/PyTorch-GAN · GitHub

在此分析文件cyclegan.py

一.cyclegan.py

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

这段代码使用了Python的`argparse`模块来定义和解析命令行参数,主要用于训练机器学习模型。以下是每个参数的中文解释:

1. `--epoch`:从哪个epoch(训练轮次)开始训练。这个参数在从检查点恢复训练时非常有用。默认值是`0`。
2. `--n_epochs`:训练模型的总epoch数。默认值是`200`。
3. `--dataset_name`:使用的数据集的名称。默认值是`"monet2photo"`。
4. `--batch_size`:数据批次的大小。默认值是`1`。
5. `--lr`:Adam优化器的学习率。默认值是`0.0002`。
6. `--b1`:Adam优化器一阶动量的衰减率。默认值是`0.5`。
7. `--b2`:Adam优化器二阶动量的衰减率。默认值是`0.999`。
8. `--decay_epoch`:从哪个epoch开始衰减学习率。默认值是`100`。
9. `--n_cpu`:用于生成批次数据的CPU线程数。默认值是`8`。
10. `--img_height`:输入图像的高度。默认值是`256`。
11. `--img_width`:输入图像的宽度。默认值是`256`。
12. `--channels`:图像的通道数(例如,RGB图像的通道数为`3`)。默认值是`3`。
13. `--sample_interval`:保存生成器输出的间隔(以批次为单位)。默认值是`100`。
14. `--checkpoint_interval`:保存模型检查点的间隔(以epoch为单位)。默认值是`-1`(表示不保存检查点)。
15. `--n_residual_blocks`:生成器模型中的残差块数量。默认值是`9`。
16. `--lambda_cyc`:循环一致性损失的权重。默认值是`10.0`。
17. `--lambda_id`:身份损失的权重。默认值是`5.0`。

执行这段代码时,它会解析命令行参数,并将解析后的参数存储在`opt`变量中。最后,这些参数会被打印出来。

例如,可以通过命令行运行这个脚本并传入不同的参数:
python script.py --epoch 10 --n_epochs 300 --dataset_name "vangogh2photo"

这会将`epoch`设置为`10`,`n_epochs`设置为`300`,`dataset_name`设置为`"vangogh2photo"`,而其他参数将使用默认值。 


# Create sample and checkpoint directories
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize generator and discriminator
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

 这段代码是用于训练生成对抗网络(GAN)的初始化步骤,具体是针对CycleGAN模型。下面是逐步解释每个代码块的功能:

1. **创建样本和检查点目录**:

os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

这段代码创建了两个目录,一个用于保存生成的图像样本,另一个用于保存训练好的模型权重。目录名称由数据集的名称决定。

2. **定义损失函数**:

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

这里定义了三个损失函数:
    - `criterion_GAN`: 用于判别器的对抗性损失,使用均方误差损失(MSELoss)。
    - `criterion_cycle`: 用于循环一致性损失,使用L1损失(即平均绝对误差)。
    - `criterion_identity`: 用于身份损失,也使用L1损失。

3. **检查CUDA是否可用**:

cuda = torch.cuda.is_available()

这段代码检查是否可以使用GPU进行训练。

4. **定义输入图像的形状**:

input_shape = (opt.channels, opt.img_height, opt.img_width)

这里定义了输入图像的形状,包括通道数、高度和宽度。

5. **初始化生成器和判别器**:

G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

这里初始化了两个生成器(G_AB和G_BA)和两个判别器(D_A和D_B)。生成器使用了残差网络(ResNet),而判别器是一个标准的二分类网络。

6. **将模型和损失函数移动到GPU(如果可用)**:

 if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

如果GPU可用,将模型和损失函数移动到GPU上,以加速训练过程。

7. **加载预训练模型或初始化模型权重**:

if opt.epoch != 0:
   # Load pretrained models
        G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
   # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

如果从某个特定的epoch开始训练,则加载预训练的模型权重;否则,初始化模型的权重。

这个初始化步骤确保了模型准备好开始训练,并根据是否从检查点恢复或从头开始训练采取相应的措施。


# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

这段代码继续初始化CycleGAN训练所需的其他组件,包括优化器、学习率调度器、数据类型定义以及样本缓冲区。下面是对每个部分的详细解释:

1. **优化器**:

optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

这里定义了三个Adam优化器:
    - `optimizer_G`:用于优化生成器G_AB和G_BA的参数。`itertools.chain`用于将两个生成器的参数链在一起,以便同时更新。
    - `optimizer_D_A`:用于优化判别器D_A的参数。
    - `optimizer_D_B`:用于优化判别器D_B的参数。

    这些优化器的学习率和动量衰减率通过命令行参数`opt.lr`、`opt.b1`和`opt.b2`进行配置。

2. **学习率更新调度器**:

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

这些调度器用于在训练过程中更新学习率。`LambdaLR`使用了一个自定义的`LambdaLR`类,该类定义了一个用于调整学习率的函数`step`。这个函数根据总训练epoch数、当前epoch数和开始衰减学习率的epoch数来动态调整学习率。

3. **定义张量数据类型**:

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

这里根据是否使用CUDA定义了张量的数据类型。如果使用CUDA,则使用`torch.cuda.FloatTensor`,否则使用`torch.Tensor`。

4. **生成的假样本缓冲区**:

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

这些缓冲区用于存储生成的假样本。重播缓冲区(ReplayBuffer)可以帮助平滑训练过程,防止模型振荡。具体来说,生成器生成的假样本会存储在缓冲区中,并在训练过程中从缓冲区随机抽取样本用于更新判别器,从而提高训练的稳定性。

整体来说,这段代码完成了CycleGAN训练的必要初始化步骤,确保模型、优化器、学习率调度器和缓冲区都已正确配置,为后续的训练过程做好准备。


# Image transformations
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training data loader
dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)

这段代码设置了图像预处理步骤,并创建了训练和测试数据加载器。以下是详细解释:

1. **图像变换(transformations)**:

transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

这些图像变换用于预处理数据集中的图像,以便更好地进行训练:
    - `transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC)`:将图像按双三次插值法调整到原高度的1.12倍。
    - `transforms.RandomCrop((opt.img_height, opt.img_width))`:随机裁剪图像到指定的高度和宽度。
    - `transforms.RandomHorizontalFlip()`:随机水平翻转图像。
    - `transforms.ToTensor()`:将图像转换为PyTorch张量,并将像素值从[0, 255]缩放到[0, 1]。
    - `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`:对图像进行标准化处理,将每个通道的像素值从[0, 1]缩放到[-1, 1]。

2. **训练数据加载器**:

dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )

这段代码创建了一个用于训练的数据加载器:
    - `ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True)`:实例化一个自定义的数据集类`ImageDataset`,使用指定的图像变换,并设置`unaligned=True`表示使用未对齐的图像对。
    - `batch_size=opt.batch_size`:设置每个批次的大小。
    - `shuffle=True`:在每个epoch开始时打乱数据。
    - `num_workers=opt.n_cpu`:用于加载数据的工作进程数。

3. **测试数据加载器**:

val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

这段代码创建了一个用于测试的数据加载器:
    - `ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test")`:同样实例化`ImageDataset`类,但将`mode`设置为`"test"`,表示加载测试数据。
    - `batch_size=5`:设置测试批次的大小为5。
    - `shuffle=True`:打乱测试数据。
    - `num_workers=1`:设置加载数据的工作进程数为1。

通过这些代码,数据预处理和数据加载过程得到了设置,准备好进行后续的训练和验证步骤。


def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)


# ----------
#  Training
# ----------

prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))

这段代码实现了CycleGAN模型的训练过程,包含生成器和判别器的训练步骤,以及图像保存和模型检查点的功能。下面是详细的解释:

### 定义生成样本并保存图像的函数 ###

def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)

这个函数在测试集上生成样本图像并保存。它首先从验证数据加载器中取出一批图像,然后让生成器生成对应的假图像。最后,将真实和生成的图像排成网格并保存。

### 训练过程 ###

prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

        # Train Generators
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator A
        optimizer_D_A.zero_grad()
        loss_real = criterion_GAN(D_A(real_A), valid)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # Train Discriminator B
        optimizer_D_B.zero_grad()
        loss_real = criterion_GAN(D_B(real_B), valid)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # Log Progress
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))

#### 主要步骤解释 ####:

1. **初始化**:
    - 设置当前时间`prev_time`用于计算剩余时间。
    - 开始epoch循环,从指定的`opt.epoch`到`opt.n_epochs`。

2. **批次训练**:
    - 从数据加载器中读取一个批次的图像,包括`real_A`和`real_B`。
    - 创建判别器的真实和假的标签。
  
3. **训练生成器**:
    - 将生成器置于训练模式。
    - 计算身份损失(Identity Loss):生成器在给定输入图像的情况下应输出与输入图像尽可能相似的图像。【例如:给G_AB图像B,它应该输出一个尽可能和B一样的图像】
    - 计算对抗损失(GAN Loss):生成器生成的图像应该尽可能使判别器认为它们是真实的。
    - 计算循环一致损失(Cycle Loss):通过生成器生成的图像应该能够恢复到原始输入图像。【例如:realA通过G_AB生成fakeB,fakeB通过B_BA还原出realA】
    - 总的生成器损失是对抗损失、循环一致损失和身份损失的加权和。通过反向传播更新生成器参数。

4. **训练判别器**:
    - 分别训练两个判别器D_A和D_B。
    - 对每个判别器,计算真实图像的损失和生成的假图像的损失。
    - 更新判别器参数。

5. **日志记录和保存图像**:
    - 计算和显示损失值以及剩余时间。
    - 在每个指定的间隔保存生成的样本图像。

6. **更新学习率**:
    - 每个epoch结束后更新生成器和判别器的学习率。

7. **保存模型检查点**:
    - 根据指定的检查点间隔保存模型的状态字典。

这个循环一直持续到所有的epoch完成。这个训练过程旨在逐步优化CycleGAN模型,使其能够在两个不同域之间进行高质量的图像转换。

  • 9
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值