如何实现在colab上面修改代码的功能

 

在实验室白嫖谷歌的colab跑代码的时候发现需要修改从github上搬过来的内容,找了一系列的方案最终决定使用创建一个新的.py文件之后再将源文件的代码导出,复制进去,具体的流程从代码中看。

//从github上下载下来源代码
!git clone https://github.com/yuanhangzhang98/ml_quantum_compiling.git
//显示内部的详细文件结构
!ls
//自己写入一个新的.py文件
%%writefile code.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm import trange
import torch

from model import Model
from agent import Agent
from system import System
from randomStateDataset import RandomStateDataset

if __name__ == '__main__':
    
    num_epoch = 300
    batch_size = 1000
    cur_length = 5
    full_dataset_length = 11
    max_length = 50
    update_interval = 100
    num_samples = batch_size * update_interval
    loss_tolerance = 0.01
    accuracy_tolerance = 0.001
    result_dir = 'results/'
    ckpt_dir = 'ckpts/'
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    policy_net = Model(embedding_size=5000, hidden_size=1000).to(device)
    target_net = Model(embedding_size=5000, hidden_size=1000).to(device)
    
    # policy_net.load_state_dict(torch.load(ckpt_dir+'model_{}_{}.ckpt'.format(num_epoch, cur_length), map_location=device))
    
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    
    f = open(result_dir + 'loss.txt', 'w')
    
    env = System(device)
    agent = Agent(policy_net, target_net, env, accuracy_tolerance)
    dataset = RandomStateDataset(env, cur_length, full_dataset_length, max_length, num_samples, accuracy_tolerance)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0)

    while cur_length < max_length:   
        is_updated = 0
        for n_epoch in trange(num_epoch):
            dataset.reinitialize()
            ave_loss = 0
            for sample in dataloader:
                loss = agent.update_model(sample)
                ave_loss += loss
            ave_loss /= update_interval
            print('loss:', ave_loss, 'cur_len:', cur_length)
            f.write('{}\t{}\n'.format(cur_length, ave_loss))
            if n_epoch % 10 == 0:
                if ave_loss < loss_tolerance:
                    target_net.load_state_dict(policy_net.state_dict())
                    is_updated = 1
        if is_updated:
            cur_length += 1
            dataset.cur_length += 1
            loss_tolerance = 0.01
        else:
            loss_tolerance += 0.001
        num_epoch += 10
        torch.save(policy_net.state_dict(), ckpt_dir+'model_{}_{}.ckpt'.format(num_epoch, cur_length)) 
    f.close()
//执行新生成的文件
!python code.py

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值