Pytorch Personal Tips

GPU
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

targets = ids[:, (i+1):(i+1)+seq_length].to(device)
argparse
if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument('--content', type=str, default='png/content.png')
   parser.add_argument('--style', type=str, default='png/style.png')
   parser.add_argument('--max_size', type=int, default=400)
   parser.add_argument('--total_step', type=int, default=2000)
   parser.add_argument('--log_step', type=int, default=10)
   parser.add_argument('--sample_step', type=int, default=500)
   parser.add_argument('--style_weight', type=float, default=100)
   parser.add_argument('--lr', type=float, default=0.003)
   config = parser.parse_args()
   print(config)
   main(config)

命令行选项、参数和子命令解析器
可以用 python xxx.py --help查看具体参数信息

Check Point
def checkpoint(epoch):
   model_path = 'model_epoch_{}.pt'.format(epoch)
   state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch':epoch, 'lr':lr}
   torch.save(state, model_path)
   print('Checkpoint saved to {}'.format(model_path))

https://www.cnblogs.com/zgqcn/p/14015720.html

拼接

PyTorch函数解释:cat、stack、transpose、permute、squeeze、unsqueeze

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值