# Explanatory comments are written inline throughout the code; read them carefully.
# Main program: sequence_gan.py
import numpy as np
import tensorflow as tf
import random
from dataloader import Gen_Data_loader, Dis_dataloader
from generator import Generator
from discriminator import Discriminator
from rollout import ROLLOUT
from target_lstm import TARGET_LSTM
from dataloader import StrToBytes
import pickle
import pdb
#########################################################################################
# Generator Hyper-parameters
######################################################################################
EMB_DIM = 32 # embedding dimension
HIDDEN_DIM = 32 # hidden state dimension of lstm cell
SEQ_LENGTH = 20 # sequence length
START_TOKEN = 0 # main() asserts that this is 0 before building the models
PRE_EPOCH_NUM = 10 # supervised (maximum likelihood estimation) pre-training epochs; the "120" in the original note is presumably the full-run value -- TODO confirm
SEED = 88 # seed handed to random.seed and np.random.seed in main()
BATCH_SIZE = 64 # batch size used for the data loaders, generator and oracle model
#########################################################################################
# Discriminator Hyper-parameters
#########################################################################################
dis_embedding_dim = 64 # passed to Discriminator as embedding_size
# Passed to Discriminator as filter_sizes / num_filters; both lists have 12 entries,
# presumably paired element-wise (one filter count per filter size) -- TODO confirm
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
dis_dropout_keep_prob = 0.75 # fed to discriminator.dropout_keep_prob on every training step
dis_l2_reg_lambda = 0.2 # passed to Discriminator as l2_reg_lambda
dis_batch_size = 64 # NOTE(review): not referenced in the visible code; Dis_dataloader is built with BATCH_SIZE
#########################################################################################
# Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 10 # number of adversarial training rounds; "200" was noted next to it, presumably the full-run value -- TODO confirm
positive_file = 'save/real_data.txt' # real data sampled from the oracle model is saved to this file
negative_file = 'save/generator_sample.txt' # fake data produced by the generator is saved to this file
eval_file = 'save/eval_file.txt' # generator samples written here are scored by the oracle model (evaluation file)
generated_num = 1000 # number of samples to generate per call; "10000" was noted next to it, presumably the full-run value -- TODO confirm
def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    """Sample sequences from ``trainable_model`` and save them to ``output_file``.

    ``trainable_model.generate(sess)`` is invoked ``int(generated_num / batch_size)``
    times, so the number of generate calls is rounded down to whole batches.
    Every sequence returned by those calls is written as one line of
    space-separated tokens.

    Args:
        sess: Session object forwarded unchanged to ``trainable_model.generate``.
        trainable_model: Object exposing ``generate(sess)``, which returns an
            iterable of token sequences.
        batch_size: Divisor used to turn ``generated_num`` into a call count.
        generated_num: Requested number of samples.
        output_file: Path of the text file to (over)write.
    """
    # Collect every sample first so the output file is only truncated once all
    # generation calls have succeeded.
    generated_samples = []
    # Floor division avoids a float round-trip; int() keeps float arguments working.
    for _ in range(int(generated_num // batch_size)):
        generated_samples.extend(trainable_model.generate(sess))

    with open(output_file, 'w') as fout:
        fout.writelines(
            ' '.join(str(token) for token in sample) + '\n'
            for sample in generated_samples
        )
def target_loss(sess, target_lstm, data_loader):
    """Return the mean of ``target_lstm.pretrain_loss`` over all batches of ``data_loader``.

    This is the oracle negative log-likelihood: the data is scored by the
    oracle model ``target_lstm``. For more details, see Section 4 of
    https://arxiv.org/abs/1609.05473

    Args:
        sess: Object exposing ``run(fetch, feed_dict)``; called once per batch.
        target_lstm: Oracle model exposing ``pretrain_loss`` (the fetch) and
            ``x`` (the feed key for the batch).
        data_loader: Loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``.

    Returns:
        ``np.mean`` of the per-batch losses. With ``num_batch == 0`` this is
        ``nan``, as NumPy defines the mean of an empty sequence.
    """
    # Rewind so the evaluation always starts from the loader's first batch.
    data_loader.reset_pointer()
    nll = [
        sess.run(target_lstm.pretrain_loss, {target_lstm.x: data_loader.next_batch()})
        for _ in range(data_loader.num_batch)
    ]
    return np.mean(nll)
def pre_train_epoch(sess, trainable_model, data_loader):
    """Pre-train the generator with MLE for one epoch and return the mean loss.

    Args:
        sess: Session object forwarded unchanged to ``trainable_model.pretrain_step``.
        trainable_model: Model exposing ``pretrain_step(sess, batch)``, which
            returns a two-element result whose second element is the loss.
        data_loader: Loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``.

    Returns:
        ``np.mean`` of the per-batch losses. With ``num_batch == 0`` this is
        ``nan``, as NumPy defines the mean of an empty sequence.
    """
    # Rewind so every epoch walks the loader from its first batch.
    data_loader.reset_pointer()
    supervised_g_losses = []
    for _ in range(data_loader.num_batch):
        # Only the loss is kept; the first element of the step result is discarded.
        _update, g_loss = trainable_model.pretrain_step(sess, data_loader.next_batch())
        supervised_g_losses.append(g_loss)
    return np.mean(supervised_g_losses)
def main():
#使得随机数据可预测,即只要seed的值一样,后续生成的随机数都一样。
random.seed(SEED)
np.random.seed(SEED)
#刚开始没有数据,因此从状态0开始,生成到状态19,产生20个数字为一个样本
assert START_TOKEN == 0
#先创建对象
gen_data_loader = Gen_Data_loader(BATCH_SIZE)
likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
vocab_size = 5000
dis_data_loader = Dis_dataloader(BATCH_SIZE)
generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
# target_params = pickle.load(open(StrToBytes('save/target_params.pkl')))
target_params = pickle.load(open('save/target_params.pkl', 'rb'), encoding='iso-8859-1')
target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model
discriminator = Discriminator(sequence_length=20, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
#--------------------------------mycode-------------------------------#
merged = tf.summary.merge_all() # 将图形、训练过程等数据合并在一起
writer = tf.summary.FileWriter('logs', sess.graph) # 将训练日志写入到logs文件夹下
# First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
#pdb.set_trace()
gen_data_loader.create_batches(positive_file)
log = open('save/experiment-log.txt', 'w')
#--------- pre-train generator预训练生成器--------------------#
print("Start pre-training...")
log.write('pre-training...\n')
for epoch in range(PRE_EPOCH_NUM):#120
loss = pre_train_epoch(sess, generator, gen_data_loader)
if epoch % 5 == 0:
generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
likelihood_data_loader.create_batches(eval_file)
test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
print ("pre-train epoch ", epoch, "test_loss ", test_loss)
buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
log.write(buffer)
print("Start pre-training discriminator...")
# Train 3 epoch on the generated data and do this for 50 times
for _ in range(10):
generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
#pdb.set_trace()
dis_data_loader.load_train_data(positive_file, negative_file)
for _ in range(3):
dis_data_loader.reset_pointer()
for it in range(dis_data_loader.num_batch):
x_batch, y_batch = dis_data_loader.next_batch()
feed = {
discriminator.input_x: x_batch,
discriminator.input_y: y_batch,
discriminator.dropout_keep_prob: dis_dropout_keep_prob
}
_ = sess.run(discriminator.train_op, feed)
#pdb.set_trace()
rollout = ROLLOUT(generator, 0.8)
print ("#########################################################################")
print ("Start Adversarial Training...")
log.write('adversarial training...\n')
for total_batch in range(TOTAL_BATCH):
# Train the generator for one step
for it in range(1):
samples = generator.generate(sess)
rewards = rollout.get_reward(sess, samples, 16, discriminator)
feed = {generator.x: samples, generator.rewards: rewards}
_ = sess.run(generator.g_updates, feed_dict=feed)
# Test
if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
likelihood_data_loader.create_batches(eval_file)
test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
print ("total_batch: ", total_batch, "test_loss: ", test_loss)
log.write(buffer)
# Update roll-out parameters
rollout.update_params()
# Train the discriminator
for _ in range(5):
generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
dis_data_loader.load_train_data(positive_file, negative_file)
for _ in range(3):
dis_data_loader.reset_pointer(
SeqGAN代码解析
最新推荐文章于 2020-12-05 14:09:08 发布
本文深入解析 SeqGAN(序列生成对抗网络)的代码实现,涵盖生成器和判别器的设计,以及训练过程的关键步骤。通过理解 SeqGAN 的工作原理,读者能够掌握如何运用 GAN 完成序列数据的生成任务。
摘要由CSDN通过智能技术生成