今天我们来介绍一个神奇的网络,生成对抗网络GAN,这个模型纯属当做娱乐,供大家消遣娱乐,在这里我只展示一下GAN模型有趣的一个小功能,先来给大家介绍一下GAN模型吧。
GAN 的基本原理
GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习模型,由两个主要的部分组成:生成器和判别器。生成器试图从一个简单的随机噪声分布中生成数据实例,而判别器则尝试区分这些生成的样本和真实的样本 1。
生成器和判别器的对抗训练
在 GAN 的训练过程中,生成器和判别器进行着一种零和游戏的对抗训练。生成器试图欺骗判别器,使其认为生成的样本是真实的,而判别器则试图正确地区分出哪些样本是真实的,哪些是生成的 1。
权重共享和批标准化
在训练过程中,为了解决模式塌陷问题,GAN 使用了权重共享和批标准化技术。权重共享意味着每个样本使用相同的权重进行生成,而批标准化则确保了每个批次的样本具有相同的均值和方差 1。
GAN 的主要应用场景
GAN 由于其生成高质量数据的能力
但是我们今天使用到的是GAN网络中的其中的一个小的分支CycleGAN
一、CycleGAN 简介
CycleGAN(Cycle Generative Adversarial Network)是一种特殊类型的生成对抗网络,旨在解决无配对数据的图像到图像转换问题。
二、工作原理
- 两个生成器
- 一个将源域图像转换为目标域图像。
- 另一个执行相反的转换。
- 两个判别器
- 分别判断生成的目标域图像和源域图像的真实性。
- 循环一致性损失
- 确保转换后的图像能够再转换回原始图像,保持一定的相似性。
三、特点
- 无需配对数据
- 传统的图像转换方法通常需要源域和目标域一一对应的图像对,而 CycleGAN 打破了这一限制。
- 多领域转换
- 能够实现多种不同领域之间的图像转换,如风格迁移、季节转换等。
- 灵活性高
- 可以根据不同的任务和数据进行调整和优化。
四、应用领域
- 艺术创作
- 实现不同艺术风格之间的转换。
- 图像增强
- 改善图像的质量和效果。
- 虚拟现实和增强现实
- 生成逼真的虚拟场景和增强现实效果。
五、结论
CycleGAN 为图像转换任务提供了一种创新且有效的方法,具有广泛的应用前景和研究价值。
CycleGAN是循环生成式对抗网络的缩写,它可以将一个领域的图像转换为另一个领域的图像。
本篇博客也是利用了这个特点。
一个把马变成斑马的网络,CycleGAN对从IamgeNet数据集中提取的(不相关的)马和斑马的数据集进行了训练。该网络学习获取一匹或多匹马的图像,并将他们全部变成斑马,图像的其余部分尽可能不被修改。代码直接给大家了,我就不进行代码的讲解了。
from PIL import Image
from torchvision import transforms
import torch
import torch
import torch.nn as nn
class ResNetBlock(nn.Module): # <1>
def __init__(self, dim):
super(ResNetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim)
def build_conv_block(self, dim):
conv_block = []
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
nn.ReLU(True)]
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x) # <2>
return out
class ResNetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>
assert(n_blocks >= 0)
super(ResNetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
nn.InstanceNorm2d(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=True),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResNetBlock(ngf * mult)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=True),
nn.InstanceNorm2d(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input): # <3>
return self.model(input)
netG = ResNetGenerator()
model_path = "C:\\deep learning\\pytorch学习\\horse2zebra_0.4.0.pth"
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
netG.eval()
preprocess = transforms.Compose([transforms.Resize(256),
transforms.ToTensor()])
img = Image.open("C:\\deep learning\\pytorch学习\\horse.jpg")
img.show()
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
batch_out = netG(batch_t)
out_t = (batch_out.data.squeeze() + 1.0) / 2
out_img = transforms.ToPILImage()(out_t)
out_img.save("C:\\deep learning\\pytorch学习\\horse2.jpg")
out_img.show()
上面的ResNetGenerator类将其方便后续的实例化,把路径改为自己的就行,需要权重文件以及horse的图片的评论区留言,看到就发网盘链接
代码调试完成后,运行代码后的效果如下图所示
是不是特别神奇,有趣!!!