python开源库生成式对抗网络_跟我学算法-对抗生成网络

classGAN(object):

# 初始化变量def __init__(self, data, gen, num_steps, batch_size, log_every):

self.data=data

self.gen=gen

self.num_steps=num_steps

self.batch_size=batch_size

self.log_every=log_every

self.mip_hidden_size= 4self.learning_rate= 0.03self._create_model()

# 建立模型def_create_model(self):

# 建立预判别模型

with tf.variable_scope('D_pre'):

self.pre_input= tf.placeholder(tf.float32, shape=(self.batch_size, 1))

self.pre_labels= tf.placeholder(tf.float32, shape=(self.batch_size, 1))

# 获得预测结果

D_pre=discriminator(self.pre_input, self.mip_hidden_size)#预测值与真实之间的差异

self.pre_loss = tf.reduce_mean(tf.square(D_pre -self.pre_labels))

# 训练缩小预测值与真实值的差异

self.pre_opt=optimizer(self.pre_loss, None, self.learning_rate)

# 建立造假模型

with tf.variable_scope('Gen'):#伪造数据的生成

self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))

self.G=generator(self.z, self.mip_hidden_size)

# 建立判别模型

with tf.variable_scope('Disc') as scope:

# 对真实值做预测, D1为真实值的概率

self.x= tf.placeholder(tf.float32, shape=(self.batch_size, 1))

self.D1=discriminator(self.x, self.mip_hidden_size)#变量重用

scope.reuse_variables()

# 对造假值做预测, D2为预测到造假值的概率

self.D2=discriminator(self.G, self.mip_hidden_size)

# 第一个对抗函数

self.loss_d= tf.reduce_mean(-tf.log(self.D1) - tf.log(1-self.D2))

# 第二个对抗函数

self.loss_g= tf.reduce_mean(-tf.log(self.D2))#打包参数

self.d_pre_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D_pre')

self.d_params= tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Disc')

self.g_params= tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gen')#获得训练以后的参数

self.opt_d =optimizer(self.loss_d, self.d_params, self.learning_rate)

self.opt_g=optimizer(self.loss_g, self.g_params, self.learning_rate)

# 进行训练deftrain(self):

with tf.Session() as session:#变量初始化

tf.global_variables_initializer().run()

# 进行预处理的训练

num_pretrain_steps= 1000

for step inrange(num_pretrain_steps):

# 生成随机值

d= (np.random.random(self.batch_size) - 0.5) * 10.0

# 生成随机值的标签

labels= norm.pdf(d, loc=self.data.mu, scale=self.data.sigma)

pretrain_loss, _=session.run([self.pre_loss, self.pre_opt], {

self.pre_input : np.reshape(d, (self.batch_size,1)),

self.pre_labels : np.reshape(labels, (self.batch_size,1))

})#获得参数

self.weightsD =session.run(self.d_pre_params)#将d_pre_params 参数拷贝给 d_params

for i, v inenumerate(self.d_params):

session.run(v.assign(self.weightsD[i]))

# 进行两个对抗函数的参数训练for step inrange(self.num_steps):

# 第一个对抗函数的训练

x=self.data.sample(self.batch_size)

z=self.gen.sample(self.batch_size)

loss_d, _=session.run([self.loss_d, self.opt_d],{

self.x: np.reshape(x, (self.batch_size,1)),

self.z : np.reshape(z, (self.batch_size,1))

})

# 第二个对抗函数的训练

z=self.gen.sample(self.batch_size)

loss_g, _=session.run([self.loss_g, self.opt_g], {

self.z : np.reshape(z, (self.batch_size,1))

})

# 输出结果if step % self.log_every ==0:print('{}:{}\t{}'.format(step, loss_d, loss_g))

# 迭代一百次或者在最后一次进行画图if step % 100 == 0 or step == self.num_steps - 1:

self._plot_distributions(session)def _samples(self, session, num_points=10000, num_bins=100):

xs= np.linspace(-self.gen.range, self.gen.range, num_points)

bins= np.linspace(-self.gen.range, self.gen.range, num_bins)#data distribution # 实际数据

d =self.data.sample(num_points)

pd, _= np.histogram(d, bins=bins, density=True)#generated samples # 造假数据

zs = np.linspace(-self.gen.range, self.gen.range, num_points)

g= np.zeros((num_points, 1))for i in range(num_points //self.batch_size):

g[self.batch_size* i:self.batch_size * (i + 1)] =session.run(self.G, {

self.z: np.reshape(

zs[self.batch_size* i:self.batch_size * (i + 1)],

(self.batch_size,1)

)

})

# 返回造假数据

pg, _= np.histogram(g, bins=bins, density=True)returnpd, pg

# 画图def_plot_distributions(self, session):

pd, pg=self._samples(session)

p_x= np.linspace(-self.gen.range, self.gen.range, len(pd))

f, ax= plt.subplots(1)

ax.set_ylim(0,1)

plt.plot(p_x, pd, label='real data')

plt.plot(p_x, pg, label='generated data')

plt.title('1D Generative Adversarial Network')

plt.xlabel('Data values')

plt.ylabel('Probability density')

plt.legend()

plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值