WGAN-GP应用于一维时序信号

前言

近年来生成对抗网络在各大顶会上大放异彩,但是大多研究集中在图像方面,衍生出一系列DCGAN、WGAN等等模型,最近我在关于手语识别的研究中遇到数据量及数据种类过少的问题,故想到使用GAN来生成手语数据,达到以假乱真的效果。除了这个方面外,GAN在这些信号的生成的研究对于医疗方面受损信号的恢复、意图生成特定的属性的信号方面也具有重要的意义。

学习路线(选读)

在研究中我走了许多弯路、耗费了大量的时间,现在将这一过程记录下来,为初学者提供一个参考路线,也提醒以后我在研究中尽量少走弯路。

  1. 决定使用GAN后可以先看这个领域的近几年的综述文献。重点关注前人总结的哪些模型适用于那些应用场景,然后针对主流模型,寻找相关论文、博客、GitHub等进行了解
  2. 需要了解的部分首先是算法原理,如果时间宽裕,最好掌握详尽的数学推导,这大大有助于代码理解和修改工作。这也是为什么要先研究主流模型的原因:相关资料较多。
  3. 去Github找相关源码进行论文复现工作。期间会遇到很多很多bug,不要怕,根据你掌握的原理进行修改,这一阶段也是你的工程能力大大提升的阶段。
  4. 大致掌握代码后,改用自己的数据集、针对自己的数据尝试修改模型,当前目标是跑通网络。
  5. 跑通后可能你的数据结果不尽人意,此时可以浏览各种调参技巧的相关文章,看看自己的模型算法哪里需要改进,深度学习仍然是一项一经验为主导的任务,前人总结的调参技巧对于我们入门来说至关重要!!!
  6. 在调参,修改bug的过程中,可以将自己遇到的bug及解决方法记录下来,总结经验,这有助于工程能力进一步提高。也可以在该项目完成后进行总结,但是效果肯定没有前者好。

GAN的发展历程

首先对GAN的雏形进行简要介绍。
众所周知:GAN是一种无监督学习方法。有两个网络分别是生成器(G)和判别器(D),二者是一个不断博弈达到平衡的过程,对于不同的任务目标,不一定要分为训练集和测试集。本次我的任务就是单纯的生成手语数据,故在该项目中我没有进行分割。
原始的GAN的优化目标函数为:

min⁡Gmax⁡DV(D,G)=Ex∼p data (x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))] \min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text { data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))] GminDmaxV(D,G)=Exp data (x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

为了改进GAN的一些不稳定、模式崩溃等缺点,提出的WGAN能有效解决上述问题。
WGAN主要改进如下:
1、使用Wasserstein距离代替JS散度,有效的解决了当生成数据和原始数据不重叠的问题。
2、判别器D最后一层去掉sigmoid,这一点是理所应当的,我要生成的手语信号肯定不是【0,1】之间的数
3、G和D的loss不取log
4、添加了惩罚项,每次更新 D 的参数之后,将其绝对值截断到不超过一个固定常数 c ,即gradient clipping

关于这一改进的优势,有很多较好的博文给出了解释,在这里仅作简要介绍,想要了解的同学可以参考:郑华滨老师在知乎上的分析https://zhuanlan.zhihu.com/p/25071913

然而上述的WGAN实际应用中也是由效果不好的时候、所以我们今天的主角WGAN-GP就登场了!!
WGAN-GP主要针对上述第四点的惩罚项进行改进,改进之后的目标函数是:

L=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]+λEx^∼Pα^[(∥∇x^D(x^)∥2−1)2] L=\underset{\tilde{\boldsymbol{x}} \sim \mathbb{P}_{g}}{\mathbb{E}}[D(\tilde{\boldsymbol{x}})]-\underset{\boldsymbol{x} \sim \mathbb{P}_{r}}{\mathbb{E}}[D(\boldsymbol{x})]+\lambda \underset{\hat{\boldsymbol{x}} \sim \mathbb{P}_{\hat{\alpha}}}{\mathbb{E}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_{2}-1\right)^{2}\right] L=x~PgE[D(x~)]xPrE[D(x)]+λx^Pα^E[(x^D(x^)21)2]

数据集介绍

我们生成的手语信号是一维时序信号,具有较强的时间相关性。长度是160,单种手语样本数是300,共80种手语。

默认参数介绍

DIM = 64            #模型深度,该参数仅在调试模型时使用,调试更方便
BATCH_SIZE = 64     #每次输入batch的大小
CRITIC_ITERS = 5    #生成器迭代5次,判别器迭代1次,这是WGAN论文中提出的技巧
LAMBDA = 10         #惩罚项系数λ的默认值
epoch = 10000       #迭代次数
OUTPUT_DIM = 160    #输出数据的长度

一些注意事项

1、生成器和判别器是基本对称相反的
2、生成其中使用了一维反卷积(nn.ConvTransoise1d)、BatchNorm1d和LeakyReLU激活函数
判别器种使用了一维卷积(nn.Conv1d)及LeakyReLU函数,没有BatchNorm层(据说不可以加)
4、优化器使用的是Adam
3、很多人反映Loss会出现nan的情况,这种情况可以调节优化器、反向传播等等

首次写个人博客,如有不正确之处,欢迎指正。也希望大家多多提问,我后续注意补充。

### 使用 WGAN-GP 生成一维数据的方法 #### 方法概述 为了生成高质量的一维滚动轴承振动数据样本,采用 Wasserstein GAN with Gradient Penalty (WGAN-GP) 是一种有效的方式。该模型通过引入梯度惩罚项来稳定训练过程并提升生成效果[^1]。 #### 数据准备 以西储大学(CWRU)数据集为例,在实际应用中可以根据需求替换成其他类型的数据源。对于特定故障类型的模拟,可以通过调整不平衡比例参数来自定义不同条件下的仿真环境设置。 #### 网络结构设计 所使用的 WGAN-GP 架构可以灵活更改为 DCGANWGAN、LSGAN 或 SNGAN 等变体形式,具体取决于应用场景的需求以及预期性能指标的要求。 #### 训练过程 `train_gan` 下面是一个简化版的 Python 实现用于说明如何构建和训练一个简单的 WGAN-GP 来生成一维时间序列: ```python import torch from torch import nn, optim from torchvision.utils import save_image class Generator(nn.Module): def __init__(self, input_dim=100, output_dim=28*28): super(Generator, self).__init__() self.model = nn.Sequential( *[ nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Linear(512, output_dim), nn.Tanh() ] ) def forward(self, z): img = self.model(z) return img.view(img.size(0), -1) class Discriminator(nn.Module): def __init__(self, input_dim=28*28): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(input_dim, 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), ) def forward(self, img): validity = self.model(img) return validity.squeeze() def compute_gradient_penalty(disc, real_samples, fake_samples): """Calculates the gradient penalty loss for WGAN GP""" Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor alpha = Tensor(np.random.random((real_samples.size(0), 1))) interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) d_interpolates = disc(interpolates) gradients = autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(d_interpolates).to(real_samples.device), create_graph=True, retain_graph=True, only_inputs=True)[0] gradients = gradients.reshape(gradients.shape[0], -1) gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() return gradient_penalty # 初始化生成器与判别器 generator = Generator().cuda() discriminator = Discriminator().cuda() optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.9)) optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.9)) for epoch in range(num_epochs): for i, data in enumerate(train_loader): # Train Discriminator optimizer_D.zero_grad() real_imgs = Variable(data['data'].type(FloatTensor)) # Generate a batch of images z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim)))) gen_imgs = generator(z) # Loss measures generator's ability to fool the discriminator gp = compute_gradient_penalty(discriminator, real_imgs.data, gen_imgs.data) d_loss = (-torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(gen_imgs))) + lambda_gp * gp d_loss.backward(retain_graph=True) optimizer_D.step() # Clip weights of discriminator for p in discriminator.parameters(): p.data.clamp_(-clip_value, clip_value) # Every n_critic times update Generator once if i % opt.n_critic == 0: optimizer_G.zero_grad() # Generate new set of noise samples and generate images from them z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim)))) gen_imgs = generator(z) g_loss = -torch.mean(discriminator(gen_imgs)) g_loss.backward() optimizer_G.step() ``` 此代码片段展示了基本框架下 WGAN-GP 的实现方式,其中包含了生成器和判别器的设计思路及其损失函数计算逻辑,并加入了梯度惩罚机制以确保 Lipschitz 连续性约束得到满足。 #### 测试阶段 `generate_gan` 当拥有了预训练好的权重文件之后,则可以直接加载这些参数来进行新样本的合成操作而无需重新经历整个耗时较长的学习周期。以下是利用已保存下来的模型状态字典恢复网络配置并执行预测任务的具体做法: ```python checkpoint_path = 'path_to_your_pretrained_weights.pth' pretrained_dict = torch.load(checkpoint_path)['model_state_dict'] generator.load_state_dict(pretrained_dict) with torch.no_grad(): fixed_noise = torch.randn(batch_size, nz, device=device) generated_data = generator(fixed_noise).cpu().numpy() ``` 上述脚本读取之前存储下来的最佳 checkpoint 文件路径,从中提取出对应的 model state dict 并将其赋给当前实例化的生成器对象;随后借助随机噪声向量作为输入调用前馈传播完成最终输出结果获取工作。
评论 17
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值