gan怎么输入一维数据_用GAN生成二维样本的小例子

这篇博客讨论了GAN(生成对抗网络)中判别器输入一维数据的问题,通过分析Dev Nag的50行代码示例,指出其将整个批次作为单一输入的不寻常设计。作者解释了如何通过修改代码,让判别器接受单个二维样本以更好地学习分布,并提供了一个二维GAN的示例,包括采样和损失函数的调整。此外,博客还介绍了条件GAN(C-GAN)的概念,展示了如何利用条件信息改善生成样本的质量。
摘要由CSDN通过智能技术生成

50行GAN代码的问题

Dev Nag写的50行代码的GAN,大概是网上流传最广的,关于GAN最简单的小例子。这是一份用一维均匀样本作为特征空间(latent space)样本,经过生成网络变换后,生成高斯分布样本的代码。结构非常清晰,却有一个奇怪的问题,就是判别器(Discriminator)的输入不是2维样本,而是把整个mini-batch整体作为一个维度是batch size(代码中batch size等于cardinality)那么大的样本。也就是说判别网络要判别的不是一个一维的目标分布,而是batch size那么大维度的分布:

...

d_input_size = 100 # Minibatch size - cardinality of distributions

...

class Discriminator(nn.Module):

def __init__(self, input_size, hidden_size, output_size):

super(Discriminator, self).__init__()

self.map1 = nn.Linear(input_size, hidden_size)

self.map2 = nn.Linear(hidden_size, hidden_size)

self.map3 = nn.Linear(hidden_size, output_size)

def forward(self, x):

x = F.elu(self.map1(x))

x = F.elu(self.map2(x))

return F.sigmoid(self.map3(x))

...

D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)

...

for epoch in range(num_epochs):

for d_index in range(d_steps):

# 1. Train D on real+fake

D.zero_grad()

# 1A: Train D on real

d_real_data = Variable(d_sampler(d_input_size))

d_real_decision = D(preprocess(d_real_data))

d_real_error = criterion(d_real_decision, Variable(torch.ones(1))) # ones = true

d_real_error.backward() # compute/store gradients, but don't change params

# 1B: Train D on fake

d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))

d_fake_data = G(d_gen_input).detach() # detach to avoid training G on these labels

d_fake_decision = D(preprocess(d_fake_data.t()))

d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1))) # zeros = fake

d_fake_error.backward()

d_optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward()

for g_index in range(g_steps):

# 2. Train G on D's response (but DO NOT train D on these labels)

G.zero_grad()

gen_input = Variable(gi_sampler(minibatc

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在 MATLAB 中,你可以使用生成对抗网络(GAN)来处理二维数据GAN 是一种深度学习模型,由生成器和判别器组成,用于生成逼真的数据样本。下面是一个简单的示例,展示了如何使用 GAN生成二维数据。 首先,你需要定义生成器和判别器的网络结构。生成器负责生成伪造的数据样本,而判别器负责判断输入数据是真实样本还是伪造样本。 ```matlab% 定义生成器网络结构generator = <定义生成器网络>; % 定义判别器网络结构discriminator = <定义判别器网络>; % 定义GAN模型gan = ganNetwork(generator, discriminator); ``` 然后,你可以准备用于训练 GAN 的真实数据集。在这个例子中,假设你有一个二维数据集 `data`,其中每一行表示一个样本。 ```matlab% 准备真实数据集data = <准备真实数据集>; ``` 接下来,你可以使用 `trainNetwork` 函数来训练 GAN 模型。训练过程中,生成器和判别器会交替训练,目标是使生成生成的伪造样本越来越接近真实样本。 ```matlab% 定义训练参数numEpochs =100; miniBatchSize =64; % 训练GAN模型gan = trainNetwork(data, gan, numEpochs, miniBatchSize); ``` 最后,你可以使用训练好的生成器来生成新的二维数据样本。 ```matlab%生成新的二维数据样本numSamples =1000; generatedData = generate(gan, numSamples); ``` 这是一个简单的示例,你可以根据实际需求调整网络结构和训练参数。希望可以帮助到你!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值