GPU
# Device configuration: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Usage example: move a slice of a batch onto the chosen device. `ids`, `i` and
# `seq_length` come from a surrounding training loop that is not shown here;
# the slice takes a `seq_length`-wide window of columns starting at i+1.
#     targets = ids[:, (i+1):(i+1)+seq_length].to(device)
argparse
def _build_parser():
    """Return the command-line parser for this script.

    Kept separate from the ``__main__`` guard so the options and their
    defaults can be inspected or reused, e.g. ``_build_parser().parse_args([])``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='png/content.png')
    parser.add_argument('--style', type=str, default='png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=2000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    # argparse applies `type` only to *string* defaults, so the default must be
    # a float literal for style_weight to be a float when the flag is omitted.
    parser.add_argument('--style_weight', type=float, default=100.0)
    parser.add_argument('--lr', type=float, default=0.003)
    return parser


if __name__ == "__main__":
    config = _build_parser().parse_args()
    print(config)
    # NOTE(review): `main` is expected to be defined elsewhere in the script
    # (not visible in this snippet).
    main(config)
命令行选项、参数和子命令解析器
可以用 `python xxx.py --help` 查看具体参数信息。

Checkpoint
def checkpoint(epoch):
    """Save the model and optimizer state for ``epoch`` to ``model_epoch_<epoch>.pt``.

    Reads the module-level names ``model``, ``optimizer`` and ``lr``, which the
    surrounding training script is expected to define.

    Args:
        epoch: Epoch number; used in the file name and stored in the checkpoint.

    Returns:
        The path of the checkpoint file that was written.
    """
    model_path = f'model_epoch_{epoch}.pt'
    # Store state_dicts rather than the module objects themselves; restoring
    # therefore means rebuilding the model/optimizer and calling
    # load_state_dict on the loaded entries.
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'lr': lr,
    }
    torch.save(state, model_path)
    print(f'Checkpoint saved to {model_path}')
    return model_path
https://www.cnblogs.com/zgqcn/p/14015720.html
拼接
PyTorch Personal Tips
最新推荐文章于 2024-05-31 23:11:38 发布