import tensorflow as tf
# Command-line configuration flags. The lower bounds make absl's flag parser
# reject values that cannot be meaningful instead of passing them through
# silently to the training code.
tf.compat.v1.app.flags.DEFINE_integer(
    'batch_size', 100, help='The size of each batch.',
    lower_bound=1)  # a batch needs at least one element
tf.compat.v1.app.flags.DEFINE_integer(
    'epoch', 10000, help='The epoch of training.',
    lower_bound=0)  # presumably a count of training epochs; negatives make no sense
tf.compat.v1.app.flags.DEFINE_string(
    'saved_models', './saved_models',
    help='The path of saved models during training.')
tf.compat.v1.app.flags.DEFINE_float(
    'lr', 0.0001, help='The learning rate.',
    lower_bound=0.0)  # a negative learning rate is rejected
# Flag-value holder. Attribute access (e.g. FLAGS.batch_size) is only valid
# after the command line has been parsed, which tf.compat.v1.app.run does
# before invoking main.
FLAGS = tf.compat.v1.app.flags.FLAGS
def main(args):
    """Print the configuration values read from the command-line flags.

    Args:
        args: Leftover positional argv handed over by ``tf.compat.v1.app.run``
            (which is how this script invokes ``main``). It is not used; the
            parameter exists only to satisfy that calling convention.
    """
    del args  # Unused; required by the app-runner calling convention.
    settings = (
        ('The batch size: ', FLAGS.batch_size),
        ('The epoch of training: ', FLAGS.epoch),
        ('The path of saved models: ', FLAGS.saved_models),
        ('The learning rate: ', FLAGS.lr),
    )
    for label, value in settings:
        # Two-argument print keeps the original output exactly: print() adds
        # a space separator after the label (which itself ends in a space).
        print(label, value)
if __name__ == '__main__':
    # Hand control to the TensorFlow app runner, which parses the command-line
    # flags into FLAGS and then calls main with the remaining argv.
    tf.compat.v1.app.run(main)
Tensorflow配置参数
最新推荐文章于 2022-09-27 09:04:51 发布