d2l风格迁移--包含tensor与pil图片互换操作

对13章风格迁移任务进行讲解,并对其中的部分操作如pil与tensor呼唤进行具体介绍,方便后续调用!

目录

1.读取原图片:

2.转换函数

3.抽取图像特征与net构件

4.定义损失

5.训练


1.读取原图片:

d2l.set_figsize()
content_img = d2l.Image.open('img/rainier.jpg')
d2l.plt.imshow(content_img);

style_img = d2l.Image.open('img/autumn-oak.jpg')
d2l.plt.imshow(style_img);

2.转换函数

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)  # 表示移动到rgb_std的device上
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)  # 此时img为hwc
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))  # 此时per后为chw

  PIL图像用可以用img.size返回其(weigh,height),与tensor正好相反!!进行逆std与mean运算时应为(h,w,c),tensor中为(bs,c,h,w),进行ToPILImage操作时应为(c,h,w)

  应用,注:使用d2l.plt.imshow()传入一个pil图像可以直接显示:

# 加载图像并进行预处理
img_tensor = transforms.ToTensor()(Image.open('img/rainier.jpg'))
# 将张量转换为 PIL 图像对象
img_pil = transforms.ToPILImage()(img_tensor)
# 在 Jupyter Notebook 中显示图像
d2l.plt.imshow(img_pil)
img_pil.size

3.抽取图像特征与net构件

pretrained_net = torchvision.models.vgg19(pretrained=True)
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
net = nn.Sequential(*[pretrained_net.features[i] for i in
                    range(max(content_layers + style_layers) + 1)])

  保留了0-28层,后面的不要了,上一节children嵌套list讲过了。

  逐层计算,并保留内容层和风格层的输出fmap:

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y
def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

4.定义损失

  内容损失直接用生成图片内容与真正图片内容进行均方误差。
  样式一样怎么理解?通过RGB直方图来理解,每个像素值没必要一样,匹配的是通道之间相关性。这里使用格拉姆矩阵,在gram函数中实现。传入X为(b,c,h,w),首先拉成(c,hw)尺寸,其中每一行可看作x1,x2...xc为长度为hw的向量,表示在该通道上的风格特征,进行矩阵相乘格拉姆矩阵中,i行和j列的元素Xij表示上述i行和j行的内积,表示通道i\j之间的风格相关性。
  为了避免风格损失受其中某较大误差值的影响,return的格拉姆矩阵再除矩阵中元素的个数chw(设bs=1)

def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离⽬标:
    # 这是⼀个规定的值,⽽不是⼀个变量。
    return torch.square(Y_hat - Y.detach()).mean()

def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
    torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

  最后tv_loss是降噪操作,尽可能使临近的像素值相似,避免特别明亮或过暗的像素噪点。

  赋予权重并写总loss计算函数:

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、⻛格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

  赋予权重,再相加。
  sum中传入3个列表则将列表里面的所有元素都相加起来,返回一个值。

5.训练

  明确一点,训练的使X,使图片,不是参数。
  所以写一个SynthesizedImage的类,将合成图像视为参数,这样就可以利用自动计算梯度更新合成图像,注意这里合成图像一开始用rand随机生成。

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))
    
    def forward(self):
        return self.weight

  定义inits函数,创建合成图像,在此可以将内容或风格图像当作初始权重,这样就不用随机rand生成了,使用的使data.copy_(X.data)的命令。
  然后,再将原始风格图像的gram提前计算出。

def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

  训练函数,真正backward的只有l,其他的三个分loss画图print用的:

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
            float(sum(styles_l)), float(tv_l)])
    return X

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值