(二)风格迁移——代码

本文介绍了风格迁移的实现效果,并提供了实操步骤。首先展示了风格迁移前后对比,然后详细说明了如何获取模型,包括从网盘下载并放入工程文件。接着,指导读者创建必要的文件夹,编写models.py, settings.py 和 train.py 文件,最后运行 train.py 文件以完成风格迁移操作。" 108966899,9473028,JavaWeb POST请求与Servlet处理及Bug修复记录,"['Java', 'Spring', 'Web开发', 'Maven', '数据库']
摘要由CSDN通过智能技术生成

想看原理的可以看这篇文章:https://blog.csdn.net/weixin_41108515/article/details/103650964

本文主要参考:https://blog.csdn.net/weixin_41108515/article/details/103651784

一、效果

在正式开始前,先看看效果。
风格图:(style文件中的painting.jpg)在这里插入图片描述
我想要进行风格迁移的图:(content中的qd.jpg)
在这里插入图片描述
效果图:(output文件夹里)
在这里插入图片描述

二、实操

在看完成果之后,我们开始进入实操部分。

1.获取模型
网盘:链接:https://pan.baidu.com/s/1Z3TvZvGUyTsMGQoaQv4kdQ 密码:xgwc
在这里插入图片描述
下载完成后,将其拖入工程文件中

2.写代码
2.1 我们要在工程文件下建名为「style」「content」「output」的文件夹,这些文件夹与代码文件的位置如图
在这里插入图片描述
2.2 正式写代码

新建models.py

import tensorflow as tf
import numpy as np
import settings
import scipy.io
import scipy.misc


class Model(object):
    def __init__(self, content_path, style_path):
        self.content = self.loadimg(content_path)  # 加载内容图片
        self.style = self.loadimg(style_path)  # 加载风格图片
        self.random_img = self.get_random_img()  # 生成噪音内容图片
        self.net = self.vggnet()  # 建立vgg网络

    def vggnet(self):
        # 读取预训练的vgg模型
        vgg = scipy.io.loadmat(settings.VGG_MODEL_PATH)
        vgg_layers = vgg['layers'][0]
        net = {
   }
        # 使用预训练的模型参数构建vgg网络的卷积层和池化层
        # 全连接层不需要
        # 注意,除了input之外,这里参数都为constant,即常量
        # 和平时不同,我们并不训练vgg的参数,它们保持不变
        # 需要进行训练的是input,它即是我们最终生成的图像
        net['input'] = tf.Variable(np.zeros([1, settings.IMAGE_HEIGHT, settings.IMAGE_WIDTH, 3]), dtype=tf.float32)
        # 参数对应的层数可以参考vgg模型图
        net['conv1_1'] = self.conv_relu(net['input'], self.get_wb(vgg_layers, 0))
        net['conv1_2'] = self.conv_relu(net['conv1_1'], self.get_wb(vgg_layers, 2))
        net['pool1'] = self.pool(net['conv1_2'])
        net['conv2_1'] = self.conv_relu(net['pool1'], self.get_wb(vgg_layers, 5))
        net['conv2_2'] = self.conv_relu(net['conv2_1'], self.get_wb(vgg_layers, 7))
        net['pool2'] = self.pool(net['conv2_2'])
        net['conv3_1'] = self.conv_relu(net['pool2'], self.get_wb(vgg_layers, 10))
        net['conv3_2'] = self.conv_relu(net['conv3_1'], self.get_wb(vgg_layers, 12
  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值