Python utils 模块:TextLoader() 实例源码 - 编程字典

该段代码定义了一个名为`train`的函数,用于训练一个文本加载器加载的数据。函数首先检查`init_from`参数指定的路径是否存在,并加载旧模型的配置和词汇。接着,它比较新模型的参数与旧模型的参数是否一致,确保兼容性。然后,创建并初始化模型,在TensorFlow会话中恢复或初始化权重。训练过程中,每轮迭代都会更新学习率,打印训练损失和每批处理的时间。最后,按照设定的保存间隔保存模型。
摘要由CSDN通过智能技术生成

def train(args):
    """Train an RNN language model on the data under ``args.data_dir``.

    Builds a ``TextLoader`` for batching, optionally restores weights from a
    previously saved model directory (``args.init_from``) after verifying
    architectural compatibility, then runs ``args.num_epochs`` epochs with an
    exponentially decayed learning rate, checkpointing every
    ``args.save_every`` batches and at the final batch.

    Args:
        args: parsed command-line namespace; read fields include data_dir,
            batch_size, seq_length, init_from, save_dir, num_epochs,
            learning_rate, decay_rate, save_every, model, rnn_size, num_layers.
            ``args.vocab_size`` is set as a side effect.

    Raises:
        AssertionError: if ``args.init_from`` is set but the directory,
            config, vocab, or checkpoint is missing/incompatible.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # Check compatibility if training is continued from a previously saved model.
    if args.init_from is not None:
        # Check that all necessary files exist.
        # NOTE: ``assert`` is stripped under ``python -O``; kept for
        # behavior compatibility with callers expecting AssertionError.
        assert os.path.isdir(args.init_from), "%s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), \
            "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "words_vocab.pkl")), \
            "words_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # Open the old config and check that the models are compatible.
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        # These hyperparameters define the graph shape and must match exactly.
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        # Open the saved vocab/dict and check that vocabs/dicts are compatible.
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = cPickle.load(f)
        assert saved_words == data_loader.words, \
            "Data and loaded model disagree on word set!"
        assert saved_vocab == data_loader.vocab, \
            "Data and loaded model disagree on dictionary mappings!"

    # Persist the config and vocabulary next to the checkpoints so the
    # model can later be restored (and re-validated) from save_dir alone.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # Restore previously trained weights, if resuming.
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for e in range(args.num_epochs):
            # Exponential learning-rate decay, applied once per epoch.
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                # Feed the previous final state back in, carrying the RNN
                # state across batches within an epoch (truncated BPTT).
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * data_loader.num_batches + b,
                              args.num_epochs * data_loader.num_batches,
                              e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1):
                    # Save at the configured interval, and always save the last result.
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path,
                               global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值