Meta-Q-Learning Code

Paper: https://arxiv.org/pdf/1910.00125v1.pdf

You first need a working PyTorch environment and a MuJoCo installation.

Go to https://mujoco.org/ and click Download.

Pick the build for your platform; on Windows, for example, download

mujoco-3.1.1-windows-x86_64.zip

After unzipping, add the path of the bin directory to the PATH environment variable.

Run simulate.exe in the bin folder to open the MuJoCo viewer.

The model folder contains XML files for various models; drag one of them into the viewer to load it.

For a more detailed MuJoCo introduction, see the Bilibili video 大学生help大学生:面对mujoco的勇气也没有了吗你?_哔哩哔哩_bilibili

The uploader explains everything in detail.

1.README.md

python run_script.py --env cheetah-dir --gpu_id 0 --seed 0

The code runs on both GPU and CPU.

For the full list of hyperparameters, see the appendix of the paper.

(To run the code on a new environment, you first need to define an entry for it in ./configs/pearl_envs.json; ./configs/abl_envs.json can serve as a reference. You also need to add the environment code under rlkit/env/.)

2.algs

2.1MQL

2.1.1__init__.py

It enables absolute imports, true (float) division, and the Python 3 print function via __future__ imports.

The goal is better compatibility, so the code behaves the same across different Python versions.
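
A minimal sketch of what such an __init__.py presumably contains (reconstructed from the description above, not copied from the repo):

from __future__ import absolute_import, division, print_function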

2.1.2buffer.py

import numpy as np

class Buffer(object):
	def __init__(self, max_size=1e6):
		# initialize the replay buffer
		self.storage = []          # list holding the buffered transitions
		self.max_size = max_size   # maximum buffer capacity
		self.ptr = 0               # pointer to the slot that will be overwritten next

	def reset(self):
		# reset the buffer
		self.storage = []  # drop all stored transitions
		self.ptr = 0       # reset the pointer

	def add(self, data):
		'''
		Add a transition to the buffer.
		data ==> (state, next_state, action, reward, done, previous_action, previous_reward)
		'''
		if len(self.storage) == self.max_size:  # buffer is full
			self.storage[int(self.ptr)] = data  # overwrite the oldest entry
			self.ptr = (self.ptr + 1) % self.max_size  # advance the pointer, wrapping around to the start
		else:
			self.storage.append(data)  # not full yet, append at the end

	def size_rb(self):
		# return the number of transitions currently stored
		if len(self.storage) == self.max_size:
			return self.max_size
		else:
			return len(self.storage)

	def sample(self, batch_size):
		'''
		Randomly sample a batch of transitions from the buffer.
			Returns tuples of (state, next_state, action, reward, done,
							  previous_action, previous_reward, previous_state
							  next_actions, next_rewards, next_states
							  )
		'''
		ind = np.random.randint(0, len(self.storage), size=batch_size)  # random indices into the buffer
		x, y, u, r, d, pu, pr, px, nu, nr, nx = [], [], [], [], [], [], [], [], [], [], []

		for i in ind:
			# unpack one stored transition:
			# X  ==> state
			# Y  ==> next_state
			# U  ==> action
			# R  ==> reward
			# D  ==> done
			# PU ==> previous actions
			# PR ==> previous rewards
			# PX ==> previous states
			# NU ==> next (history) actions
			# NR ==> next (history) rewards
			# NX ==> next (history) states
			X, Y, U, R, D, PU, PR, PX, NU, NR, NX = self.storage[i]
			# append each field to its batch list
			x.append(np.array(X, copy=False))
			y.append(np.array(Y, copy=False))
			u.append(np.array(U, copy=False))
			r.append(np.array(R, copy=False))
			d.append(np.array(D, copy=False))
			pu.append(np.array(PU, copy=False))
			pr.append(np.array(PR, copy=False))
			px.append(np.array(PX, copy=False))
			nu.append(np.array(NU, copy=False))
			nr.append(np.array(NR, copy=False))
			nx.append(np.array(NX, copy=False))
		# stack the lists into NumPy arrays and return them (compact and efficient for batching)
		return np.array(x), np.array(y), np.array(u), \
			   np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1), \
			   np.array(pu), np.array(pr), np.array(px),\
			   np.array(nu), np.array(nr), np.array(nx)
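
A minimal usage sketch of this buffer (the 11-field transition layout follows what sample() unpacks above; the shapes are made up for illustration):

import numpy as np
from algs.MQL.buffer import Buffer

buf = Buffer(max_size=1e6)

obs, next_obs = np.zeros(4), np.ones(4)
act, rew, done = np.zeros(2), 0.5, 0.0
prev_a, prev_r, prev_o = np.zeros(2), np.zeros(1), np.zeros(4)
hist_a, hist_r, hist_o = np.zeros(2), np.zeros(1), np.zeros(4)

buf.add((obs, next_obs, act, rew, done, prev_a, prev_r, prev_o, hist_a, hist_r, hist_o))

x, y, u, r, d, pu, pr, px, nu, nr, nx = buf.sample(batch_size=1)
print(x.shape, r.shape)  # (1, 4) (1, 1)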

2.1.3mql.py

from __future__ import  print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from copy import deepcopy
from sklearn.linear_model import LogisticRegression as logistic

class MQL:

    def __init__(self, 
                actor,                         # actor network
                actor_target,                  # actor target network
                critic,                        # critic network
                critic_target,                 # critic target network
                lr=None,                       # learning rate for the optimizer
                gamma=0.99,                    # reward discount factor
                ptau = 0.005,                  # interpolation factor for Polyak averaging of the target networks
                policy_noise = 0.2,            # std of the noise added to the target policy (TD3 target policy smoothing: smoothing Q over similar actions keeps the policy from exploiting sharp, erroneous Q peaks, which would otherwise cause brittle behavior)
                noise_clip = 0.5,              # TD3 clips the noise in every action dimension so the target action stays within a valid range
                policy_freq = 2,               # delayed policy update frequency
                batch_size = 100,              # batch size
                optim_method = '',             # optimization method
                max_action = None,             # maximum action magnitude
                max_iter_logistic = 2000,      # max iterations for the logistic regression
                beta_clip = 1,                 # clipping range for beta
                enable_beta_obs_cxt = False,   # whether to concatenate observation and context for the logistic regression
                prox_coef = 1,                 # coefficient of the proximal term
                device = 'cpu',                # cpu or cuda
                lam_csc = 1.0,                 # regularization parameter C of the logistic regression (smaller means stronger regularization)
                type_of_training = 'csc',      # training type, e.g. 'csc' (covariate shift correction)
                use_ess_clipping = False,      # whether to adapt the proximal coefficient using the ESS (effective sample size)
                use_normalized_beta = True,    # whether to use the normalized beta
                reset_optims = False,          # whether to reset the optimizers
                ):

        '''
            actor:  actor network 
            critic: critic network 
            lr:   learning rate for the optimizer (Adam)
            gamma: reward discounting parameter
            ptau:  interpolation factor in Polyak averaging  
            policy_noise: add noise to policy 
            noise_clip: clipped noise 
            policy_freq: delayed policy updates
            enable_beta_obs_cxt:  decide whether to concat obs and ctx for logistic regression
            lam_csc: logistic regression regularization, smaller means stronger regularization
        '''
        self.actor = actor
        self.actor_target = actor_target
        self.critic = critic
        self.critic_target = critic_target
        self.gamma = gamma
        self.ptau = ptau
        self.policy_noise = policy_noise
        self.policy_freq  = policy_freq
        self.noise_clip = noise_clip
        self.max_action = max_action
        self.batch_size = batch_size
        self.max_iter_logistic = max_iter_logistic
        self.beta_clip = beta_clip
        self.enable_beta_obs_cxt = enable_beta_obs_cxt
        self.prox_coef = prox_coef
        self.prox_coef_init = prox_coef
        self.device = device
        self.lam_csc = lam_csc
        self.type_of_training = type_of_training
        self.use_ess_clipping = use_ess_clipping
        self.r_eps = np.float32(1e-7)  # small epsilon used to avoid INF or NAN in the computations
        self.use_normalized_beta = use_normalized_beta
        self.set_training_style()
        self.lr = lr
        self.reset_optims = reset_optims


        # initialize the target networks with the weights of the main networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        # keep a copy of the model parameters for the proximal-point update
        self.copy_model_params()

        if lr:
            self.actor_optimizer = optim.Adam(self.actor.parameters(), lr = lr)  # optimizers with the given learning rate
            self.critic_optimizer = optim.Adam(self.critic.parameters(), lr = lr)

        else:
            self.actor_optimizer = optim.Adam(self.actor.parameters())  # optimizers with the default learning rate
            self.critic_optimizer = optim.Adam(self.critic.parameters())

        print('-----------------------------')
        print('Optim Params')
        print("Actor:\n ",  self.actor_optimizer)
        print("Critic:\n ", self.critic_optimizer )
        print('********')
        print("reset_optims: ", reset_optims)
        print("use_ess_clipping: ", use_ess_clipping)
        print("use_normalized_beta: ", use_normalized_beta)
        print("enable_beta_obs_cxt: ", enable_beta_obs_cxt)
        print('********')
        print('-----------------------------')

    def copy_model_params(self):
        '''
        Keep a copy of the actor and critic parameters for the proximal-point update.
        '''
        self.ckpt = {
                        'actor': deepcopy(self.actor),
                        'critic': deepcopy(self.critic)
                    }

    def set_tasks_list(self, tasks_idx):
        '''
        Keep a copy of the list of training task ids.
        '''
        self.train_tasks_list = set(tasks_idx.copy())

    def select_action(self, obs, previous_action, previous_reward, previous_obs):

        '''
        Return an action for the given observation and history.
        '''
        obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
        previous_action = torch.FloatTensor(previous_action.reshape(1, -1)).to(self.device)
        previous_reward = torch.FloatTensor(previous_reward.reshape(1, -1)).to(self.device)
        previous_obs = torch.FloatTensor(previous_obs.reshape(1, -1)).to(self.device)

        # bundle the history (previous actions, rewards, observations) and pass it to the actor
        # together with the current observation
        pre_act_rew = [previous_action, previous_reward, previous_obs]

        return self.actor(obs, pre_act_rew).cpu().data.numpy().flatten()

    def get_prox_penalty(self, model_t, model_target):
        '''
        Compute ||theta - theta_t||^2:
        iterate over the parameters of model_t and model_target, take the squared norm of each
        parameter difference, and sum them up to obtain the proximal-point penalty.
        '''
        param_prox = []
        for p, q in zip(model_t.parameters(), model_target.parameters()):
            # q should be detached
            param_prox.append((p - q.detach()).norm()**2)

        result = sum(param_prox)

        return result
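
    # Note: as a concrete illustration, if model_t and model_target were two copies of nn.Linear(3, 2)
    # whose parameters differ by 0.1 in every entry, the penalty would be
    # ||dW||^2 + ||db||^2 = 6*0.01 + 2*0.01 = 0.08.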

    def train_cs(self, task_id = None, snap_buffer = None, train_tasks_buffer = None, adaptation_step = False):
        '''
        Train the CSC (covariate shift correction) model, which corrects for the mismatch between the
        data distributions of different tasks. Samples from the current task and samples from the other
        tasks are treated as two classes; context features are extracted for both and a logistic
        regression classifier is trained on them. Returns the trained classifier and some bookkeeping info.
        snap_buffer: buffer holding the pre-adaptation experience of the current task
        train_tasks_buffer: buffer holding the training tasks' experience
        '''

        ######
        # fetch all the data
        ######
        if adaptation_step == True:
            # step 1: calculate how many samples per class are needed
            # during adaptation, all training tasks can be used
            task_bsize = int(snap_buffer.size_rb(task_id) / (len(self.train_tasks_list))) + 2
            neg_tasks_ids = self.train_tasks_list

        else:
            task_bsize = int(snap_buffer.size_rb(task_id) / (len(self.train_tasks_list) - 1)) + 2
            neg_tasks_ids = list(self.train_tasks_list.difference(set([task_id])))

        # sample negative examples from train_tasks_buffer
        # view --> len(neg_tasks_ids), task_bsize, D ==> len(neg_tasks_ids) * task_bsize, D
        pu, pr, px, xx = train_tasks_buffer.sample(task_ids = neg_tasks_ids, batch_size = task_bsize)
        neg_actions = torch.FloatTensor(pu).view(task_bsize * len(neg_tasks_ids), -1).to(self.device)
        neg_rewards = torch.FloatTensor(pr).view(task_bsize * len(neg_tasks_ids), -1).to(self.device)
        neg_obs = torch.FloatTensor(px).view(task_bsize * len(neg_tasks_ids), -1).to(self.device)
        neg_xx = torch.FloatTensor(xx).view(task_bsize * len(neg_tasks_ids), -1).to(self.device)

        # sample positive examples from snap_buffer
        # returns size: (task_bsize, D)
        ppu, ppr, ppx, pxx = snap_buffer.sample(task_ids = [task_id], batch_size = snap_buffer.size_rb(task_id))
        pos_actions = torch.FloatTensor(ppu).to(self.device)
        pos_rewards = torch.FloatTensor(ppr).to(self.device)
        pos_obs = torch.FloatTensor(ppx).to(self.device)
        pos_pxx = torch.FloatTensor(pxx).to(self.device)

        # combine rewards and actions with the previous observations as input to the context network
        pos_act_rew_obs  = [pos_actions, pos_rewards, pos_obs]
        neg_act_rew_obs  = [neg_actions, neg_rewards, neg_obs]

        ######
        # extract context features
        ######
        with torch.no_grad():   

            # batch_size X context_hidden 
            # self.actor.get_conext_feats outputs [batch_size, context_size]
            # torch.cat ([batch_size, obs_dim], [batch_size, context_size]) ==> [batch_size, obs_dim + context_size]
            if self.enable_beta_obs_cxt == True:  # concatenate observations and context of the positive (negative) samples
                snap_ctxt = torch.cat([pos_pxx, self.actor.get_conext_feats(pos_act_rew_obs)], dim = -1).cpu().data.numpy()
                neg_ctxt = torch.cat([neg_xx, self.actor.get_conext_feats(neg_act_rew_obs)], dim = -1).cpu().data.numpy()

            else:  # use only the context features of the positive (negative) samples
                snap_ctxt = self.actor.get_conext_feats(pos_act_rew_obs).cpu().data.numpy()
                neg_ctxt = self.actor.get_conext_feats(neg_act_rew_obs).cpu().data.numpy()


        ######
        # train the logistic regression classifier
        ######
        x = np.concatenate((snap_ctxt, neg_ctxt)) # [b1 + b2] X D
        y = np.concatenate((-np.ones(snap_ctxt.shape[0]), np.ones(neg_ctxt.shape[0])))

        # fit the logistic regression model
        # model parameters: [1, D], where D is the context_hidden dimension
        model = logistic(solver='lbfgs', max_iter = self.max_iter_logistic, C = self.lam_csc).fit(x,y)
        # prediction accuracy of the classifier on its training data
        predcition_score = model.score(x, y)
        # info: number of positive samples, number of negative samples, classifier score
        info = (snap_ctxt.shape[0], neg_ctxt.shape[0],  model.score(x, y))
        #print(info)
        return model, info

    def update_prox_w_ess_factor(self, cs_model, x, beta=None):
        '''
            Update the proximal coefficient using the effective sample size (ESS):
            ESS = ||w||^2_1 / ||w||^2_2 , w = pi / beta, where pi is the probability of class 1
            ESS = ESS / n, where n is the number of samples, to normalize
            x: is (n, D)
        '''
        n = x.shape[0]
        if beta is not None:
            # update the proximal coefficient from beta
            # beta should match cs_model.predict_proba(x)[:,0] (up to clipping)
            w = ((torch.sum(beta)**2) /(torch.sum(beta**2) + self.r_eps) )/n
            ess_factor = np.float32(w.numpy())

        else:
            # get the probability of class -1
            # and update the proximal coefficient from the model predictions
            p0 = cs_model.predict_proba(x)[:,0]
            w =  p0 / ( 1 - p0 + self.r_eps)
            w = (np.sum(w)**2) / (np.sum(w**2) + self.r_eps)
            ess_factor = np.float32(w) / n

        # since the task samples are labeled as class -1 and the replay buffer samples as class 1:
        ess_prox_factor = 1.0 - ess_factor
        # make sure the proximal coefficient is valid: if ess_prox_factor is NaN, Inf, or not larger than
        # the threshold self.r_eps, fall back to the initial value self.prox_coef_init;
        # otherwise set self.prox_coef to ess_prox_factor.
        if np.isnan(ess_prox_factor) or np.isinf(ess_prox_factor) or ess_prox_factor <= self.r_eps: # make sure that it is valid
            self.prox_coef = self.prox_coef_init

        else:
            self.prox_coef = ess_prox_factor
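
    # Note: a quick numeric check of the ESS factor above, with made-up weights:
    #   w = [1, 1, 1, 1]    ==> (sum w)^2 / sum(w^2) / n = 16 / 4 / 4 = 1.0  ==> ess_prox_factor = 0, so prox_coef falls back to prox_coef_init
    #   w = [1, .1, .1, .1] ==> 1.69 / 1.03 / 4 ~= 0.41                      ==> prox_coef ~= 0.59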

    def get_propensity(self, cs_model, curr_pre_act_rew, curr_obs):
        '''
            This function returns propensity for current sample of data 
            simply: exp(f(x))
        '''

        ######
        # extract context features
        ######
        with torch.no_grad():

            # batch_size X context_hidden 
            if self.enable_beta_obs_cxt == True:
                ctxt = torch.cat([curr_obs, self.actor.get_conext_feats(curr_pre_act_rew)], dim = -1).cpu().data.numpy()

            else:
                ctxt = self.actor.get_conext_feats(curr_pre_act_rew).cpu().data.numpy()

        # step 0: get f(x)
        f_prop = np.dot(ctxt, cs_model.coef_.T) + cs_model.intercept_

        # step 1: convert to torch
        f_prop = torch.from_numpy(f_prop).float()

        # To make it more stable, clip it
        f_prop = f_prop.clamp(min=-self.beta_clip)

        # step 2: exp(-f(X)), f_score: N * 1
        f_score = torch.exp(-f_prop)
        f_score[f_score < 0.1]  = 0 # for numerical stability

        if self.use_normalized_beta == True:

            #get logistic regression prediction of class [-1] for current task
            lr_prob = cs_model.predict_proba(ctxt)[:,0]
            # normalize using logistic_probs
            d_pmax_pmin = np.float32(np.max(lr_prob) - np.min(lr_prob))
            f_score = ( d_pmax_pmin * (f_score - torch.min(f_score)) )/( torch.max(f_score) - torch.min(f_score) + self.r_eps ) + np.float32(np.min(lr_prob))

        # update prox coeff with ess.
        if self.use_ess_clipping == True:
            self.update_prox_w_ess_factor(cs_model, ctxt, beta=f_score)


        return f_score, None

    def do_training(self,
                    replay_buffer = None,
                    iterations = None,
                    csc_model = None,
                    apply_prox = False,
                    current_batch_size = None,
                    src_task_ids = []):

        '''
            inputs:
                replay_buffer
                iterations episode_timesteps                 
        '''
        actor_loss_out = 0.0
        critic_loss_out = 0.0
        critic_prox_out = 0.0
        actor_prox_out = 0.0
        list_prox_coefs = [self.prox_coef]

        for it in range(iterations):

            ########
            # Sample replay buffer 
            ########
            if len(src_task_ids) > 0:
                x, y, u, r, d, pu, pr, px, nu, nr, nx = replay_buffer.sample_tasks(task_ids = src_task_ids, batch_size = current_batch_size)

            else:
                x, y, u, r, d, pu, pr, px, nu, nr, nx = replay_buffer.sample(current_batch_size)

            obs = torch.FloatTensor(x).to(self.device)
            next_obs = torch.FloatTensor(y).to(self.device)
            action = torch.FloatTensor(u).to(self.device)
            reward = torch.FloatTensor(r).to(self.device)
            mask = torch.FloatTensor(1 - d).to(self.device)
            previous_action = torch.FloatTensor(pu).to(self.device)
            previous_reward = torch.FloatTensor(pr).to(self.device)
            previous_obs = torch.FloatTensor(px).to(self.device)

            # list of hist_actions and hist_rewards which are one time ahead of previous_ones
            # example:
            # previous_action = [t-3, t-2, t-1]
            # hist_actions    = [t-2, t-1, t]
            hist_actions = torch.FloatTensor(nu).to(self.device)
            hist_rewards = torch.FloatTensor(nr).to(self.device)
            hist_obs     = torch.FloatTensor(nx).to(self.device)


            # combine reward and action
            act_rew = [hist_actions, hist_rewards, hist_obs] # torch.cat([action, reward], dim = -1)
            pre_act_rew = [previous_action, previous_reward, previous_obs] #torch.cat([previous_action, previous_reward], dim = -1)

            if csc_model is None:
                # propensity_scores dim is batch_size 
                # no csc_model, so just do business as usual 
                beta_score = torch.ones((current_batch_size, 1)).to(self.device)

            else:
                # propensity_scores dim is batch_size 
                beta_score, clipping_factor = self.get_propensity(csc_model, pre_act_rew, obs)
                beta_score = beta_score.to(self.device)
                list_prox_coefs.append(self.prox_coef)

            ########
            # Select action according to policy and add clipped noise 
            # mu'(s_t) = mu(s_t | \theta_t) + N (Eq.7 in https://arxiv.org/abs/1509.02971) 
            # OR
            # Eq. 15 in TD3 paper:
            # e ~ clip(N(0, \sigma), -c, c)
            ########
            noise = (torch.randn_like(action) * self.policy_noise ).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_obs, act_rew) + noise).clamp(-self.max_action, self.max_action)

            ########
            #  Update critics
            #  1. Compute the target Q value 
            #  2. Get current Q estimates
            #  3. Compute critic loss
            #  4. Optimize the critic
            ########

            # 1. y = r + \gamma * min{Q1, Q2} (s_next, next_action)
            # if done , then only use reward otherwise reward + (self.gamma * target_Q)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action, act_rew)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (mask * self.gamma * target_Q).detach()

            # 2.  Get current Q estimates
            current_Q1, current_Q2 = self.critic(obs, action, pre_act_rew)


            # 3. Compute critic loss
            # even though we picked the min Q, we still need to backprop through both Qs
            critic_loss_temp = F.mse_loss(current_Q1, target_Q, reduction='none') + F.mse_loss(current_Q2, target_Q, reduction='none')
            assert critic_loss_temp.shape == beta_score.shape, ('shape of critic_loss_temp and beta_score should be the same', critic_loss_temp.shape, beta_score.shape)

            critic_loss = (critic_loss_temp * beta_score).mean()
            critic_loss_out += critic_loss.item()

            if apply_prox:
                # calculate proximal term
                critic_prox = self.get_prox_penalty(self.critic, self.ckpt['critic'])
                critic_loss = critic_loss + self.prox_coef * critic_prox
                critic_prox_out += critic_prox.item()

            # 4. Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            ########
            # Delayed policy updates
            ########
            if it % self.policy_freq == 0:

                # Compute actor loss
                actor_loss_temp = -1 * beta_score * self.critic.Q1(obs, self.actor(obs, pre_act_rew), pre_act_rew)
                actor_loss = actor_loss_temp.mean()
                actor_loss_out += actor_loss.item()

                if apply_prox:
                    # calculate proximal term
                    actor_prox = self.get_prox_penalty(self.actor, self.ckpt['actor'])
                    actor_loss = actor_loss + self.prox_coef * actor_prox
                    actor_prox_out += actor_prox.item()

                # Optimize the actor 
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()


                # Update the frozen target models
                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                    target_param.data.copy_(self.ptau * param.data + (1 - self.ptau) * target_param.data)

                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.ptau * param.data + (1 - self.ptau) * target_param.data)

        out = {}
        if iterations == 0:
            out['critic_loss'] = 0
            out['actor_loss']  = 0
            out['prox_critic'] = 0
            out['prox_actor']  = 0
            out['beta_score']  = 0

        else:
            out['critic_loss'] = critic_loss_out/iterations
            out['actor_loss']  = self.policy_freq * actor_loss_out/iterations
            out['prox_critic'] = critic_prox_out/iterations
            out['prox_actor']  = self.policy_freq * actor_prox_out/iterations
            out['beta_score']  = beta_score.cpu().data.numpy().mean()

        #if csc_model and self.use_ess_clipping == True:
        out['avg_prox_coef'] = np.mean(list_prox_coefs)

        return out
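
    # Note: summarizing the code above, each do_training iteration minimizes roughly
    #   critic: E[ beta_score * ((Q1(s,a,ctx) - y)^2 + (Q2(s,a,ctx) - y)^2) ] + prox_coef * ||theta_critic - ckpt_critic||^2
    #   actor : E[ -beta_score * Q1(s, actor(s,ctx), ctx) ]                  + prox_coef * ||theta_actor - ckpt_actor||^2
    # with y = reward + gamma * (1 - done) * min(Q1_target, Q2_target), beta_score == 1 when no
    # csc_model is given, and the proximal terms added only when apply_prox is True.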

    def train_TD3(
                self,
                replay_buffer=None,
                iterations=None,
                tasks_buffer = None,
                train_iter = 0,
                task_id = None,
                nums_snap_trains = 5):

        '''
            inputs:
                replay_buffer
                iterations episode_timesteps
            outputs:

        '''
        actor_loss_out = 0.0
        critic_loss_out = 0.0

        ### if there is not enough data in the replay buffer, then reduce the number of iterations to 20:
        #if replay_buffer.size_rb() < iterations or replay_buffer.size_rb() <  self.batch_size * iterations:
        #    temp = int( replay_buffer.size_rb()/ (self.batch_size) % iterations ) + 1
        #    if temp < iterations:
        #        iterations = temp

        for it in range(iterations):

            ########
            # Sample replay buffer
            ########
            x, y, u, r, d, pu, pr, px, nu, nr, nx = replay_buffer.sample(self.batch_size)
            obs = torch.FloatTensor(x).to(self.device)
            next_obs = torch.FloatTensor(y).to(self.device)
            action = torch.FloatTensor(u).to(self.device)
            reward = torch.FloatTensor(r).to(self.device)
            mask = torch.FloatTensor(1 - d).to(self.device)
            previous_action = torch.FloatTensor(pu).to(self.device)
            previous_reward = torch.FloatTensor(pr).to(self.device)
            previous_obs = torch.FloatTensor(px).to(self.device)

            # list of hist_actions and hist_rewards which are one time ahead of previous_ones
            # example:
            # previous_action = [t-3, t-2, t-1]
            # hist_actions    = [t-2, t-1, t]
            hist_actions = torch.FloatTensor(nu).to(self.device)
            hist_rewards = torch.FloatTensor(nr).to(self.device)
            hist_obs     = torch.FloatTensor(nx).to(self.device)

            # combine reward and action
            act_rew = [hist_actions, hist_rewards, hist_obs] # torch.cat([action, reward], dim = -1)
            pre_act_rew = [previous_action, previous_reward, previous_obs] #torch.cat([previous_action, previous_reward], dim = -1)

            ########
            # Select action according to policy and add clipped noise
            # mu'(s_t) = mu(s_t | \theta_t) + N (Eq.7 in https://arxiv.org/abs/1509.02971)
            # OR
            # Eq. 15 in TD3 paper:
            # e ~ clip(N(0, \sigma), -c, c)
            ########
            noise = (torch.randn_like(action) * self.policy_noise ).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_obs, act_rew) + noise).clamp(-self.max_action, self.max_action)

            ########
            #  Update critics
            #  1. Compute the target Q value
            #  2. Get current Q estimates
            #  3. Compute critic loss
            #  4. Optimize the critic
            ########

            # 1. y = r + \gamma * min{Q1, Q2} (s_next, next_action)
            # if done , then only use reward otherwise reward + (self.gamma * target_Q)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action, act_rew)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (mask * self.gamma * target_Q).detach()

            # 2.  Get current Q estimates
            current_Q1, current_Q2 = self.critic(obs, action, pre_act_rew)

            # 3. Compute critic loss
            # even though we picked the min Q, we still need to backprop through both Qs
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            critic_loss_out += critic_loss.item()

            # 4. Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            ########
            # Delayed policy updates
            ########
            if it % self.policy_freq == 0:

                # Compute actor loss
                actor_loss = -self.critic.Q1(obs, self.actor(obs, pre_act_rew), pre_act_rew).mean()
                actor_loss_out += actor_loss.item()

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()


                # Update the frozen target models
                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                    target_param.data.copy_(self.ptau * param.data + (1 - self.ptau) * target_param.data)

                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.ptau * param.data + (1 - self.ptau) * target_param.data)

        out = {}
        out['critic_loss'] = critic_loss_out/iterations
        out['actor_loss'] = self.policy_freq * actor_loss_out/iterations

        # keep a copy of models' params
        self.copy_model_params()
        return out, None

    def adapt(self,
            train_replay_buffer = None,    # replay buffer of the training tasks
            train_tasks_buffer = None,     # per-task buffer of the training tasks
            eval_task_buffer = None,       # buffer of the evaluation (test) task
            task_id = None,                # task id
            snap_iter_nums = 5,            # number of iterations on the snapshot (new-task) data
            main_snap_iter_nums = 15,      # number of iterations on the meta-training data
            sampling_style = 'replay',     # sampling style, defaults to 'replay'
            sample_mult = 1                # sampling multiplier, defaults to 1
            ):
        '''
            inputs:
                replay_buffer
                iterations episode_timesteps
        '''
        #######
        # reset the optimizers at the beginning of the adaptation step
        #######
        self.actor_optimizer = optim.Adam(self.actor.parameters())
        self.critic_optimizer = optim.Adam(self.critic.parameters())

        #######
        # Adaptation step:
        # learn a model to correct the covariate shift
        #######
        out_single = None

        # train the CSC model
        csc_model, csc_info = self.train_cs(task_id = task_id,
                                            snap_buffer = eval_task_buffer,
                                            train_tasks_buffer = train_tasks_buffer,
                                            adaptation_step = True)

        # run TD3-style updates on the new task's own data
        out_single = self.do_training(replay_buffer = eval_task_buffer.get_buffer(task_id),
                                      iterations = snap_iter_nums,
                                      csc_model = None,
                                      apply_prox = False,
                                      current_batch_size = eval_task_buffer.size_rb(task_id))
        #self.copy_model_params()

        # record the CSC info and the number of snapshot iterations for task task_id
        out_single['csc_info'] = csc_info
        out_single['snap_iter'] = snap_iter_nums

        # sampling_style is based on 'replay'
        # each training task has its own buffer, so sample from every task,
        # this time with propensity weighting and the proximal term enabled
        out = self.do_training(replay_buffer = train_replay_buffer,
                                   iterations = main_snap_iter_nums,
                                   csc_model = csc_model,
                                   apply_prox = True,
                                   current_batch_size = sample_mult * self.batch_size)

        return out, out_single

    def rollback(self):
        '''
            This function rollback everything to state before test-adaptation
        '''

        ####### ####### ####### Super Important ####### ####### #######
        # It is very important to make sure that we rollback everything to
        # Step 0
        ####### ####### ####### ####### ####### ####### ####### #######
        self.actor.load_state_dict(self.actor_copy.state_dict())
        self.actor_target.load_state_dict(self.actor_target_copy.state_dict())
        self.critic.load_state_dict(self.critic_copy.state_dict())
        self.critic_target.load_state_dict(self.critic_target_copy.state_dict())
        self.actor_optimizer.load_state_dict(self.actor_optimizer_copy.state_dict())
        self.critic_optimizer.load_state_dict(self.critic_optimizer_copy.state_dict())

    def save_model_states(self):
        # save copies of the current model states
        ####### ####### ####### Super Important ####### ####### #######
        # Step 0: It is very important to make sure that we save model params before
        # do anything here
        ####### ####### ####### ####### ####### ####### ####### #######
        self.actor_copy = deepcopy(self.actor)
        self.actor_target_copy = deepcopy(self.actor_target)
        self.critic_copy = deepcopy(self.critic)
        self.critic_target_copy = deepcopy(self.critic_target)
        self.actor_optimizer_copy  = deepcopy(self.actor_optimizer)
        self.critic_optimizer_copy = deepcopy(self.critic_optimizer)

    def set_training_style(self):
        '''
            This function just selects style of training
        '''
        print('**** TD3 style is selected ****')
        self.training_func = self.train_TD3

    def train(self,
              replay_buffer = None,
              iterations = None,
              tasks_buffer = None,
              train_iter = 0,
              task_id = None,
              nums_snap_trains = 5):
        '''
         This starts type of desired training
        '''
        return self.training_func(  replay_buffer = replay_buffer,
                                    iterations = iterations,
                                    tasks_buffer = tasks_buffer,
                                    train_iter = train_iter,
                                    task_id = task_id,
                                    nums_snap_trains = nums_snap_trains
                                )
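
To see what the covariate-shift machinery in train_cs and get_propensity boils down to, here is a self-contained sketch with synthetic context features (the Gaussian data, shapes, and variable names are made up for illustration):

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
snap_ctxt = rng.normal(loc=+0.5, size=(64, 8))    # stand-in for the new task's context features (class -1)
neg_ctxt  = rng.normal(loc=-0.5, size=(256, 8))   # stand-in for the meta-training tasks' contexts (class +1)

x = np.concatenate((snap_ctxt, neg_ctxt))
y = np.concatenate((-np.ones(len(snap_ctxt)), np.ones(len(neg_ctxt))))

model = LogisticRegression(solver='lbfgs', max_iter=2000, C=1.0).fit(x, y)
print(model.score(x, y))                          # classifier accuracy, as reported in csc_info

f = x @ model.coef_.T + model.intercept_          # f(x), as in get_propensity
beta = np.exp(-np.maximum(f, -1.0))               # exp(-f) with the same one-sided clipping (beta_clip = 1)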

2.1.4multi_tasks_snapshot.py

import numpy as np
from algs.MQL.buffer import Buffer
import random

class MultiTasksSnapshot(object):
	def __init__(self, max_size=1e3):
		'''	
			all tasks will have the same max size
		'''
		self.max_size = max_size
		
	def init(self, task_ids=None):
		'''
			init buffers for all tasks
		'''
		self.task_buffers = dict([(idx, Buffer(max_size = self.max_size))
									for idx in task_ids
								])

	def reset(self, task_id):

		self.task_buffers[task_id].reset()

	def list(self):

		return list(self.task_buffers.keys())

	def add(self, task_id, data):
		'''
			data ==> (state, next_state, action, reward, done, previous_action, previous_reward)
		'''
		self.task_buffers[task_id].add(data)

	def size_rb(self, task_id):

		return self.task_buffers[task_id].size_rb()

	def get_buffer(self, task_id):

		return self.task_buffers[task_id]

	def sample(self, task_ids, batch_size):
		'''
			Returns tuples of (state, next_state, action, reward, done,
							  previous_action, previous_reward, previous_state
							  )
		'''
		if len(task_ids) == 1:
			xx, _, _, _, _, pu, pr, px, _, _, _ =  self.task_buffers[task_ids[0]].sample(batch_size)

			return pu, pr, px, xx

		mb_actions = []
		mb_rewards = []
		mb_obs = []
		mb_x = []

		for tid in task_ids:

			xx, _, _, _, _, pu, pr, px, _, _, _ = self.task_buffers[tid].sample(batch_size)
			mb_actions.append(pu) # batch_size x D1
			mb_rewards.append(pr) # batch_size x D2
			mb_obs.append(px)     # batch_size x D3
			mb_x.append(xx)

		mb_actions = np.asarray(mb_actions, dtype=np.float32) # task_ids x batch_size x D1
		mb_rewards = np.asarray(mb_rewards, dtype=np.float32) # task_ids x batch_size x D2
		mb_obs     = np.asarray(mb_obs, dtype=np.float32)     # task_ids x batch_size x D2
		mb_x       = np.asarray(mb_x, dtype=np.float32)

		return mb_actions, mb_rewards, mb_obs, mb_x

	def sample_tasks(self, task_ids, batch_size):
		'''
			Returns tuples of (state, next_state, action, reward, done,
							  previous_action, previous_reward, previous_state
							  )
		'''
		mb_xx = []
		mb_yy = []
		mb_u = []
		mb_r = []
		mb_d = []
		mb_pu = []
		mb_pr = []
		mb_px = []
		mb_nu = []
		mb_nr = []
		mb_nx = []

		# shuffle task lists
		shuffled_task_ids = random.sample(task_ids, len(task_ids))

		for tid in shuffled_task_ids:

			xx, yy, u, r, d, pu, pr, px, nu, nr, nx = self.task_buffers[tid].sample(batch_size)
			mb_xx.append(xx) # batch_size x D1
			mb_yy.append(yy) # batch_size x D2
			mb_u.append(u)   # batch_size x D3
			mb_r.append(r)
			mb_d.append(d)
			mb_pu.append(pu)
			mb_pr.append(pr)
			mb_px.append(px)
			mb_nu.append(nu)
			mb_nr.append(nr)
			mb_nx.append(nx)

		mb_xx = np.asarray(mb_xx, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_yy = np.asarray(mb_yy, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_u = np.asarray(mb_u, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_r = np.asarray(mb_r, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_d = np.asarray(mb_d, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_pu = np.asarray(mb_pu, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_pr = np.asarray(mb_pr, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_px = np.asarray(mb_px, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_nu = np.asarray(mb_nu, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_nr = np.asarray(mb_nr, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1
		mb_nx = np.asarray(mb_nx, dtype=np.float32).reshape(len(task_ids) * batch_size , -1) # task_ids x batch_size x D1

		return mb_xx, mb_yy, mb_u, mb_r, mb_d, mb_pu, mb_pr, mb_px, mb_nu, mb_nr, mb_nx
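
A small usage sketch (the module path is assumed from the heading above; the dummy shapes are made up):

import numpy as np
from algs.MQL.multi_tasks_snapshot import MultiTasksSnapshot  # path assumed from the heading

snap = MultiTasksSnapshot(max_size=1e3)
snap.init(task_ids=[0, 1, 2])              # one Buffer per task id

dummy = (np.zeros(3), np.zeros(3), np.zeros(2), 0.0, 0.0,   # obs, next_obs, action, reward, done
         np.zeros(2), np.zeros(1), np.zeros(3),             # previous action / reward / obs
         np.zeros(2), np.zeros(1), np.zeros(3))             # history action / reward / obs
snap.add(0, dummy)
print(snap.size_rb(0), snap.list())        # 1 [0, 1, 2]

pu, pr, px, xx = snap.sample(task_ids=[0], batch_size=1)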

3.configs

Defines the configuration for the different tasks, e.g. the number of tasks, number of iterations, and so on.

3.1abl_envs.json

3.2pearl_envs.json

4.misc

4.1__init__.py

Same as the previous __init__.py.

4.2env_meta.py

Creates the PEARL environments such as 'ant-dir', 'ant-goal', and so on.

4.3logger.py

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
from collections import defaultdict

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50

class KVWriter(object):
    # declares the writekvs method, used for writing key-value data
    def writekvs(self, kvs):
        raise NotImplementedError

class SeqWriter(object):
    # declares the writeseq method, used for writing sequence data
    def writeseq(self, seq):
        raise NotImplementedError

class HumanOutputFormat(KVWriter, SeqWriter):
    # Implements the KVWriter and SeqWriter interfaces and writes data in a human-readable format.
    # It takes a filename or a file object and opens the file in the constructor. writekvs writes
    # key-value pairs as a table, one pair per row; writeseq writes a sequence of items separated by spaces.
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s'%filename_or_file
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1: # add space unless this is the last one
                self.file.write(' ')
        self.file.write('\n')
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()

class JSONOutputFormat(KVWriter):
    # Implements the KVWriter interface and writes data as JSON. It takes a filename and opens the file
    # in the constructor; writekvs converts the key-value pairs to JSON and writes one line per call.
    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, 'dtype'):
                v = v.tolist()
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()

class CSVOutputFormat(KVWriter):
    # Implements the KVWriter interface and writes data as CSV. It takes a filename and opens the file
    # in the constructor; writekvs appends one row per call, extending the header when new keys appear.
    def __init__(self, filename):
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    实现了KVWriter接口,将数据以TensorBoard的数值格式写入文件。它接受一个目录名作为参数,并在构造函数中创建目录。它的writekvs方法将键值对数据转换为TensorBoard的数值格式,并写入文件。
    """
    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        def summary_val(k, v):
            kwargs = {'tag': k, 'simple_value': float(v)}
            return self.tf.Summary.Value(**kwargs)
        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        if self.writer:
            self.writer.Close()
            self.writer = None

def make_output_format(format, ev_dir, log_suffix=''):
    # Creates the output-format object for the given format string and log directory. Supported formats:
    # stdout (standard output), log (text file), json (JSON file), csv (CSV file), and tensorboard (TensorBoard files).
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    elif format == 'log':
        return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
    elif format == 'json':
        return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
    elif format == 'csv':
        return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
    elif format == 'tensorboard':
        return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
    else:
        raise ValueError('Unknown format specified: %s' % (format,))

# ================================================================
# API
# ================================================================

def logkv(key, val):
    """
    用于记录某个诊断指标的值。在每次迭代中针对每个诊断指标调用一次。如果多次调用该函数,将使用最后一次调用的值。
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    Logger.CURRENT.logkv(key, val)

def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    Logger.CURRENT.logkv_mean(key, val)

def logkvs(d):
    """
    用于记录一组键值对数据
    Log a dictionary of key-value pairs
    """
    for (k, v) in d.items():
        logkv(k, v)

def dumpkvs():
    """
    将当前迭代的所有诊断指标写入日志文件
    Write all of the diagnostics from the current iteration

    level: int. (see logger.py docs) If the global logger level is higher than
                the level argument here, don't print to stdout.
    """
    Logger.CURRENT.dumpkvs()

def getkvs():
    # return the diagnostics logged so far in the current iteration
    return Logger.CURRENT.name2val


def log(*args, level=INFO):
    """
    用于将一系列参数写入控制台和输出文件
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)

def debug(*args):
    log(*args, level=DEBUG)

def info(*args):
    log(*args, level=INFO)

def warn(*args):
    log(*args, level=WARN)

def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)

def get_dir():
    """
    用于获取日志文件的目录。如果没有设置输出目录(即没有调用start函数),则返回None
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return Logger.CURRENT.get_dir()

record_tabular = logkv
dump_tabular = dumpkvs

class ProfileKV:
    """
    一个上下文管理器类,用于记录代码块的执行时间
    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """
    def __init__(self, n):
        self.n = "wait_" + n
    def __enter__(self):
        self.t1 = time.time()
    def __exit__(self ,type, value, traceback):
        Logger.CURRENT.name2val[self.n] += time.time() - self.t1

def profile(n):
    """
    用于将函数的执行时间记录为诊断指标。
    Usage:
    @profile("my_func")
    def my_func(): code
    """
    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name


# ================================================================
# Backend
# ================================================================

class Logger(object):
    # The logger implementation: records diagnostic values (and counts), filters messages by log level,
    # allows setting the log level, exposes the log directory, and closes its output formats.
    DEFAULT = None  # A logger with no output files. (See right below class definition)
                    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        if val is None:
            self.name2val[key] = None
            return
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.level == DISABLED: return
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))

def configure(dir=None, format_strs=None):
    # Configures the logger. Takes a directory and a list of format strings specifying
    # the output directory and the output formats.
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        dir = osp.join(tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(dir, str)
    os.makedirs(dir, exist_ok=True)

    log_suffix = ''
    rank = 0
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']:
        if varname in os.environ:
            rank = int(os.environ[varname])
    if rank > 0:
        log_suffix = "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
        else:
            format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
    log('Logging to %s'%dir)

def _configure_default_logger():
    # Configures the default logger, which only writes to stdout.
    format_strs = None
    # keep the old default of only writing to stdout
    if 'OPENAI_LOG_FORMAT' not in os.environ:
        format_strs = ['stdout']
    configure(format_strs=format_strs)
    Logger.DEFAULT = Logger.CURRENT

def reset():
    # Resets the current logger back to the default logger.
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log('Reset logger')

class scoped_configure(object):
    # A context manager that temporarily reconfigures the logger within a scope. It takes a directory and
    # a list of format strings, configures the logger on entry, and restores the previous logger on exit.
    def __init__(self, dir=None, format_strs=None):
        self.dir = dir
        self.format_strs = format_strs
        self.prevlogger = None
    def __enter__(self):
        self.prevlogger = Logger.CURRENT
        configure(dir=self.dir, format_strs=self.format_strs)
    def __exit__(self, *args):
        Logger.CURRENT.close()
        Logger.CURRENT = self.prevlogger

# ================================================================
# Readers
# ================================================================

def read_json(fname):
    # Reads a JSON-lines log file and returns it as a pandas DataFrame.
    import pandas
    ds = []
    with open(fname, 'rt') as fh:
        for line in fh:
            ds.append(json.loads(line))
    return pandas.DataFrame(ds)

def read_csv(fname):
    # Reads a CSV log file and returns it as a pandas DataFrame.
    import pandas
    return pandas.read_csv(fname, index_col=None, comment='#')

def read_tb(path):
    """
    用于读取TensorBoard格式的日志文件或目录,并返回一个Pandas DataFrame对象。它支持读取多个TensorBoard文件,并将数据合并到一个DataFrame中。
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*
    """
    import pandas
    import numpy as np
    from glob import glob
    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        fnames = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
    tag2pairs = defaultdict(list)
    maxstep = 0
    for fname in fnames:
        for summary in tf.train.summary_iterator(fname):
            if summary.step > 0:
                for v in summary.summary.value:
                    pair = (summary.step, v.simple_value)
                    tag2pairs[v.tag].append(pair)
                maxstep = max(summary.step, maxstep)
    data = np.empty((maxstep, len(tag2pairs)))
    data[:] = np.nan
    tags = sorted(tag2pairs.keys())
    for (colidx,tag) in enumerate(tags):
        pairs = tag2pairs[tag]
        for (step, value) in pairs:
            data[step-1, colidx] = value
    return pandas.DataFrame(data, columns=tags)

# configure the default logger when this module is imported
_configure_default_logger()

if __name__ == "__main__":
    _demo()
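
A short usage sketch of this logger (the import path misc.logger is an assumption based on the section headings):

from misc import logger  # import path assumed from the section headings

logger.configure(dir='./log_dir', format_strs=['stdout', 'csv'])
for it in range(3):
    logger.logkv('iteration', it)
    logger.logkv_mean('reward', 10.0 * it)
    logger.dumpkvs()               # writes one row to stdout and to progress.csv
print(logger.get_dir())            # ./log_dir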

4.4runner_meta_offpolicy.py

import numpy as np
import torch
from collections import deque
# The goal is to iteratively collect experience from the environment, store it in the replay buffer,
# and later sample batches from the buffer to update the model, so that the policy learns to
# maximize the expected cumulative reward.
class Runner:
    """
      This class generates batches of experiences
    """
    def __init__(self,
                 env,
                 model,
                 replay_buffer=None,        # experience replay buffer
                 burn_in=1e4,               # number of random-action steps before learning starts
                 expl_noise=0.1,            # exploration noise added to actions
                 total_timesteps = 1e6,     # total number of environment steps
                 max_path_length = 200,     # maximum number of steps per episode
                 history_length = 1,        # how many past steps are fed to the policy as history
                 device = 'cpu'):           # device (cpu or gpu)
        '''
            nsteps: number of steps 
        '''
        self.model = model
        self.env = env
        self.burn_in = burn_in
        self.device = device
        self.episode_rewards = deque(maxlen=10)  # keeps the rewards of the 10 most recent episodes
        self.episode_lens = deque(maxlen=10)     # keeps the lengths of the 10 most recent episodes
        self.replay_buffer = replay_buffer
        self.expl_noise = expl_noise             # exploration noise added when selecting actions
        self.total_timesteps = total_timesteps
        self.max_path_length = max_path_length
        self.hist_len = history_length

    def run(self, update_iter, keep_burning = False, task_id = None, early_leave = 200):
        '''
            This function adds transitions to the replay buffer.
            early_leave is used only during the cold start to collect data from many different tasks
            rather than focusing on just a few of them.
        '''
        obs = self.env.reset()    # reset the environment and get the initial observation
        done = False              # done flag, False while the episode is still running
        episode_timesteps = 0     # number of steps taken in this episode
        episode_reward = 0        # cumulative reward of this episode
        uiter = 0                 # number of loop iterations
        reward_epinfos = []       # list of per-step rewards

        ########
        ## create queues to keep track of past rewards and actions
        ########
        rewards_hist = deque(maxlen=self.hist_len)  # past rewards
        actions_hist = deque(maxlen=self.hist_len)  # past actions
        obsvs_hist   = deque(maxlen=self.hist_len)  # past observations

        next_hrews = deque(maxlen=self.hist_len)    # rewards of the shifted ("next") history
        next_hacts = deque(maxlen=self.hist_len)    # actions of the shifted ("next") history
        next_hobvs = deque(maxlen=self.hist_len)    # observations of the shifted ("next") history

        # Given the batching scheme, a full-length sequence has to be kept in the replay buffer,
        # so the histories are initially padded with zeros.
        zero_action = np.zeros(self.env.action_space.shape[0])  # all-zero action
        zero_obs    = np.zeros(obs.shape)                       # all-zero observation
        for _ in range(self.hist_len):               # for every step of the history window
            rewards_hist.append(0)                   # pad past rewards with 0
            actions_hist.append(zero_action.copy())  # pad past actions with zeros
            obsvs_hist.append(zero_obs.copy())       # pad past observations with zeros

            # same thing for next_h*
            next_hrews.append(0)
            next_hacts.append(zero_action.copy())
            next_hobvs.append(zero_obs.copy())

        # now add the current observation to the sequence
        rand_acttion = np.random.normal(0, self.expl_noise, size=self.env.action_space.shape[0])  # a random action
        rand_acttion = rand_acttion.clip(self.env.action_space.low, self.env.action_space.high)   # clipped to the action space
        rewards_hist.append(0)                    # reward history gets a 0
        actions_hist.append(rand_acttion.copy())  # action history gets the random action
        obsvs_hist.append(obs.copy())             # observation history gets the current observation

        ######
        # Start collecting data
        #####
        while not done and uiter < np.minimum(self.max_path_length, early_leave):
            #当episode未结束且迭代次数小于最大路径长度和early_leave步数的最小值时,进入循环
            #####
            # Convert actions_hist, rewards_hist to np.array and flatten them out
            # for example: hist =7, actin_dim = 11 --> np.asarray(actions_hist(7, 11)) ==> flatten ==> (77,)
            np_pre_actions = np.asarray(actions_hist, dtype=np.float32).flatten()  #(hist, action_dim) => (hist *action_dim,)
            np_pre_rewards = np.asarray(rewards_hist, dtype=np.float32) #(hist, )
            np_pre_obsers = np.asarray(obsvs_hist, dtype=np.float32).flatten()  #(hist, action_dim) => (hist *action_dim,)

            # Select action randomly or according to policy
            if keep_burning or update_iter < self.burn_in:# 如果保持预热或更新迭代次数小于预热步数
                action = self.env.action_space.sample()#随机选择行动

            else:
                # select_action take into account previous action to take into account
                # previous action in selecting a new action
                action = self.model.select_action(np.array(obs), np.array(np_pre_actions), np.array(np_pre_rewards), np.array(np_pre_obsers))#根据模型选择一个行动

                if self.expl_noise != 0: #如果探索噪声不为0
                    action = action + np.random.normal(0, self.expl_noise, size=self.env.action_space.shape[0])#在行动上添加随机噪声
                    action = action.clip(self.env.action_space.low, self.env.action_space.high)#将行动控制到行动空间范围内

            # Perform action
            new_obs, reward, done, _ = self.env.step(action) #在环境中执行动作,并获取新的观测、奖励和done
            if episode_timesteps + 1 == self.max_path_length:#如果episode的时间步+1=最大路径长度
                done_bool = 0#设置done=0

            else:
                done_bool = float(done)#否则将done转换为浮点数

            episode_reward += reward#添加奖励到episode的奖励上
            reward_epinfos.append(reward)#在奖励信息列表中添加奖励

            ###############
            next_hrews.append(reward)#添加下一步奖励、行动、观测
            next_hacts.append(action.copy())
            next_hobvs.append(obs.copy())

            # np_next_hacts and np_next_hrews are required for TD3 alg
            np_next_hacts = np.asarray(next_hacts, dtype=np.float32).flatten()  #(hist, action_dim) => (hist *action_dim,)
            np_next_hrews = np.asarray(next_hrews, dtype=np.float32) #(hist, )
            np_next_hobvs = np.asarray(next_hobvs, dtype=np.float32).flatten() #(hist, obs_dim) => (hist *obs_dim,)

            # Store data in replay buffer
            self.replay_buffer.add((obs, new_obs, action, reward, done_bool,
                                    np_pre_actions, np_pre_rewards, np_pre_obsers,
                                    np_next_hacts, np_next_hrews, np_next_hobvs))

            # new becomes old
            rewards_hist.append(reward)
            actions_hist.append(action.copy())
            obsvs_hist.append(obs.copy())
            obs = new_obs.copy()#将新的观测复制给当前的观测
            episode_timesteps += 1
            update_iter += 1
            uiter += 1

        info = {}#初始化信息字典为空
        info['episode_timesteps'] = episode_timesteps#添加episode的时间步
        info['update_iter'] = update_iter#添加更新迭代次数
        info['episode_reward'] = episode_reward#添加奖励
        info['epinfos'] = [{"r": round(sum(reward_epinfos), 6), "l": len(reward_epinfos)}]#添加奖励总和和步数

        return info
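这里反复出现的"历史窗口 + flatten"可以用一个很小的独立例子验证:deque(maxlen=hist_len)保证窗口长度固定,flatten之后得到的一维向量就是喂给context网络的输入。下面是一个最小示意(hist_len=3、action_dim=2等数值是随便取的,只为演示形状),不属于原仓库代码:

import numpy as np
from collections import deque

hist_len, action_dim = 3, 2             # 演示用的假设值
actions_hist = deque(maxlen=hist_len)   # 和runner里一样,用固定长度的双端队列存历史

# 先用全0填满窗口,对应上面的预填充
for _ in range(hist_len):
    actions_hist.append(np.zeros(action_dim))

# 每走一步append一个新动作,最旧的自动被挤出去
actions_hist.append(np.array([0.1, -0.2]))

np_pre_actions = np.asarray(actions_hist, dtype=np.float32).flatten()
print(np_pre_actions.shape)             # (6,),即 (hist_len * action_dim,)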

4.5runner_multi_snapshot.py

比4.4的runner多了一个tasks_buffer:在把transition写入全局replay buffer的同时,再按task_id把同一条transition写入任务级别的snapshot buffer(见下面的示意)。

# This is a snapshot buffer which has short memory
            self.tasks_buffer.add(task_id, (obs, new_obs, action, reward, done_bool,
                                    np_pre_actions, np_pre_rewards, np_pre_obsers,
                                    np_next_hacts, np_next_hrews, np_next_hobvs))
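tasks_buffer的作用是给每个任务维护一段"短记忆"。下面是按我的理解写的一个最小示意(TinyTasksSnapshot这个类名和实现都是示意用的假设,并不是仓库里MultiTasksSnapshot的真实代码),核心就是"按task_id分桶,每个桶是一个定长的小缓冲":

from collections import deque

class TinyTasksSnapshot:
    """示意用的按任务分桶的缓冲,不是原仓库的MultiTasksSnapshot"""
    def __init__(self, max_size=2000):
        self.max_size = max_size
        self.buffers = {}

    def init(self, task_ids):
        # 为每个任务建一个固定容量的小缓冲
        for tid in task_ids:
            self.buffers[tid] = deque(maxlen=self.max_size)

    def add(self, task_id, data):
        # data就是上面那条11元组的transition
        self.buffers[task_id].append(data)

# 用法和正文一致:tasks_buffer.add(task_id, (obs, new_obs, action, reward, done_bool, ...))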

4.6torch_utility.py

from __future__ import  print_function, division
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
#主要是用于模型的保存和加载,以及学习率的调整
class DictToObj:
    #用于将字典转换为对象,使得可以用点操作符访问字典中的元素
    def __init__(self, **entries):
        self.__dict__.update(entries)

def get_state(m):
    #返回模型状态
    if m is None:
        return None
    state = {}
    for k, v in m.state_dict().items():
        state[k] = v.clone()#复制每个状态
    return state

def load_model_states(path):
    #加载之前训练好的模型
    checkpoint = torch.load(path, map_location='cpu')#加载模型
    m_states = checkpoint['model_states']#获取模型状态
    m_params = checkpoint['args']#获取模型参数
    if 'env_ob_rms' in checkpoint:
        env_ob_rms = checkpoint['env_ob_rms']#获取环境观测
    else:
        env_ob_rms = None

    return m_states, DictToObj(**m_params), env_ob_rms#返回状态参数和观测值

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    #线性地减小学习率,有助于模型在后期更稳定的收敛
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:#这个循环遍历优化器中的所有参数组,并将每个参数组的学习率设置为上面计算出的新学习率。
        param_group['lr'] = lr
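
这几个工具函数的用法大致如下(一个简单的使用示意,小模型、优化器和具体数值都是随手构造的;假设在仓库根目录下运行,才能import到misc包):

import torch.nn as nn
import torch.optim as optim
from misc.torch_utility import DictToObj, get_state, update_linear_schedule

model = nn.Linear(4, 2)                        # 随便一个小模型
opt = optim.Adam(model.parameters(), lr=3e-4)

state = get_state(model)                       # 克隆一份state_dict,之后可以用来回滚
print(list(state.keys()))                      # ['weight', 'bias']

# 线性衰减:第50个epoch(总共100个)时学习率减半
update_linear_schedule(opt, epoch=50, total_num_epochs=100, initial_lr=3e-4)
print(opt.param_groups[0]['lr'])               # 1.5e-04

args = DictToObj(**{'lr': 3e-4, 'seed': 0})    # 字典转对象,用点号访问
print(args.lr, args.seed)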

4.7utils.py

import os
import numpy as np
import gym
import glob
import json
from collections import deque, OrderedDict
import psutil
import re
import csv
import pandas as pd
import ntpath
import re
import random
#用于处理文件和目录、设置随机种子、获取动作空间信息、写入和读取json文件、写入CSV文件以及安全地计算平均值等操作。
def set_global_seeds(myseed):
    #设置全局随机种子,以保证实验的可重复性
    import torch
    torch.manual_seed(myseed)
    np.random.seed(myseed)
    random.seed(myseed)

def get_fname_from_path(f):
    '''
     input:
           '/Users/user/logs/check_points/mmmxm_dummy_B32_H5_D1_best.pt'
     output:
           'mmmxm_dummy_B32_H5_D1_best.pt'
    '''
    #从路径中获取文件名
    return ntpath.basename(f)

def get_action_info(action_space, obs_space = None):
    '''
        This function returns info about the type of actions.
    '''
    #获取动作空间的信息
    space_type = action_space.__class__.__name__

    if action_space.__class__.__name__ == "Discrete":
            num_actions = action_space.n

    elif action_space.__class__.__name__ == "Box":
            num_actions = action_space.shape[0]

    elif action_space.__class__.__name__ == "MultiBinary":
            num_actions = action_space.shape[0]
    
    else:
        raise NotImplementedError
    
    return num_actions, space_type

def create_dir(log_dir, ext = '*.monitor.csv', cleanup = False):
    '''
        Setup checkpoints dir
    '''
    #创建目录
    try:
        os.makedirs(log_dir)

    except OSError:
        if cleanup == True:
            files = glob.glob(os.path.join(log_dir, '*.'))

            for f in files:
                os.remove(f)

def dump_to_json(path, data):
    '''
      Write json file
    '''
    #将数据写入json文件
    with open(path, 'w') as f:
        json.dump(data, f)

def read_json(input_json):
    #加载json文件
    file_info = json.load(open(input_json, 'r'))

    return file_info

class CSVWriter:
    #用于写入CSV
    def __init__(self, fname, fieldnames):

        self.fname = fname
        self.fieldnames = fieldnames
        self.csv_file = open(fname, mode='w')
        self.writer = None

    def write(self, data_stats):
        #写入数据
        if self.writer == None:
            self.writer = csv.DictWriter(self.csv_file, fieldnames=self.fieldnames)
            self.writer.writeheader()

        self.writer.writerow(data_stats)
        self.csv_file.flush()

    def close(self):
        #关闭文件
        self.csv_file.close()

def safemean(xs):
    #在计算平均值时避免除数为0的错误
    '''
        Avoid division error when calculate the mean (in our case if
        epinfo is empty returns np.nan, not return an error)
    '''
    return np.nan if len(xs) == 0 else np.mean(xs)
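
这些工具函数的典型用法(简单示意,动作空间的维度、csv文件名等都是假设的;gym版本需要和仓库兼容):

from gym.spaces import Box
from misc.utils import get_action_info, CSVWriter, safemean, set_global_seeds

set_global_seeds(0)                                  # 固定torch/numpy/random的随机种子

num_actions, space_type = get_action_info(Box(low=-1.0, high=1.0, shape=(6,)))
print(num_actions, space_type)                       # 6 Box

writer = CSVWriter('demo_eval.csv', ['nupdates', 'eval_reward'])
writer.write({'nupdates': 1, 'eval_reward': 123.4})  # 第一次write会自动写表头
writer.close()

print(safemean([]))                                  # nan,而不是报错
print(safemean([1, 2]))                              # 1.5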

5.models

5.1__init__.py

和之前的__init__.py一样

5.2networks.py

from __future__ import  print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from misc.utils import get_action_info
#定义了TD3算法的Actor-Critic网络,和用于处理上下文信息的context网络
##################################
# Actor and Critic Network for TD3
##################################
class Actor(nn.Module):
    """
      This arch is standard based on https://github.com/sfujim/TD3/blob/master/TD3.py
    """
    def __init__(self,
                action_space,#动作空间
                hidden_sizes = [400, 300],#隐藏层大小
                input_dim = None,#输入维度
                hidden_activation = F.relu,#隐藏层激活函数
                max_action= None,#最大动作值
                enable_context = False,#是否使用上下文信息
                hiddens_dim_conext = [50],#上下文的隐藏层大小
                input_dim_context=None,#上下文输入维度
                output_conext=None,#上下文输出维度
                only_concat_context = 0,#是否只拼接上下文信息
                history_length = 1,#历史长度
                obsr_dim = None,#观测维度
                device = 'cpu'#设备选择
                ):

        super(Actor, self).__init__()
        self.hsize_1 = hidden_sizes[0]#第一个隐藏层大小
        self.hsize_2 = hidden_sizes[1]#第二个隐藏层大小
        action_dim, action_space_type = get_action_info(action_space)#获取动作维度和类型

        #定义actor网络结构
        self.actor = nn.Sequential(
                        nn.Linear(input_dim[0], self.hsize_1),#输入层到第一个隐藏层
                        nn.ReLU(),#第一个隐藏层激活函数
                        nn.Linear(self.hsize_1, self.hsize_2),#第一个隐藏层到第二个隐藏层
                        nn.ReLU()#第二个隐藏层激活函数
                        )
        self.out = nn.Linear(self.hsize_2,  action_dim)#第二个隐藏层到输出层
        self.max_action = max_action#最大动作值
        self.enable_context = enable_context#是否使用上下文信息
        self.output_conext = output_conext#上下文输出维度

        #context network
        self.context = None
        if self.enable_context == True:#如果使用了上下文信息,则初始化上下文网络
            self.context = Context(hidden_sizes=hiddens_dim_conext,
                                   input_dim=input_dim_context,
                                   output_dim = output_conext,
                                   only_concat_context = only_concat_context,
                                   history_length = history_length,
                                   action_dim = action_dim,
                                   obsr_dim = obsr_dim,
                                   device = device
                                   )

    def forward(self, x, pre_act_rew = None, state = None, ret_context = False):
        #定义actor网络的前向传播
        '''
            input (x) : B * D where B is batch size and D is input_dim
            pre_act_rew: B * (A + 1) where B is batch size and A + 1 is input_dim
        '''
        combined = None
        if self.enable_context == True:
            #获取上下文特征
            combined = self.context(pre_act_rew)
            x = torch.cat([x, combined], dim = -1)#将上下文特征与状态特征拼接

        x = self.actor(x)#拼接后的信息输入给actor网络
        x = self.max_action * torch.tanh(self.out(x))#输出动作值,经过tanh激活并缩放到最大动作值范围内

        if ret_context == True:
            return x, combined#如果需要返回上下文特征则一起返回

        else:
            return x#否则只返回动作值

    def get_conext_feats(self, pre_act_rew):
        #获取上下文特征的函数
        '''
            pre_act_rew: B * (A + 1) where B is batch size and A + 1 is input_dim
            return combine features
        '''
        combined = self.context(pre_act_rew)

        return combined

class Critic(nn.Module):#定义critic网络
    """
      This arch is standard based on https://github.com/sfujim/TD3/blob/master/TD3.py
    """
    def __init__(self,
                action_space,
                hidden_sizes = [400, 300],
                input_dim = None,
                hidden_activation = F.relu,
                enable_context = False,
                dim_others = 0,#额外维度大小
                hiddens_dim_conext = [50],
                input_dim_context=None,
                output_conext=None,
                only_concat_context = 0,
                history_length = 1,
                obsr_dim = None,
                device = 'cpu'
                ):

        super(Critic, self).__init__()
        self.hsize_1 = hidden_sizes[0]
        self.hsize_2 = hidden_sizes[1]
        action_dim, action_space_type = get_action_info(action_space)

        # handling extra dim
        self.enable_context = enable_context

        if self.enable_context == True:#如果使用上下文信息,则增加额外的维度
            self.extra_dim = dim_others # right now, we add reward + previous action

        else:
            self.extra_dim = 0#否则额外维度为0

        # It uses two different Q networks,定义两个Q网络
        # Q1 architecture
        self.q1 = nn.Sequential(
                        nn.Linear(input_dim[0] + action_dim + self.extra_dim, self.hsize_1),#输入层到第一个隐藏层
                        nn.ReLU(),
                        nn.Linear(self.hsize_1, self.hsize_2),
                        nn.ReLU(),
                        nn.Linear(self.hsize_2, 1),
                        )


        # Q2 architecture
        self.q2 = nn.Sequential(
                        nn.Linear(input_dim[0] + action_dim + self.extra_dim, self.hsize_1),
                        nn.ReLU(),
                        nn.Linear(self.hsize_1, self.hsize_2),
                        nn.ReLU(),
                        nn.Linear(self.hsize_2, 1),
                        )

        if self.enable_context == True:
            self.context = Context(hidden_sizes=hiddens_dim_conext,
                                   input_dim=input_dim_context,
                                   output_dim = output_conext,
                                   only_concat_context = only_concat_context,
                                   history_length = history_length,
                                   action_dim = action_dim,
                                   obsr_dim = obsr_dim,
                                   device = device
                                   )

    def forward(self, x, u, pre_act_rew = None, ret_context = False):
        #定义critic的前向传播
        '''
            input (x): B * D where B is batch size and D is input_dim
            input (u): B * A where B is batch size and A is action_dim
            pre_act_rew: B * (A + 1) where B is batch size and A + 1 is input_dim
        '''
        xu = torch.cat([x, u], 1)#将状态和动作特征拼接
        combined = None

        if self.enable_context == True:
            combined = self.context(pre_act_rew)#获取上下文特征
            xu = torch.cat([xu, combined], dim = -1)#将上下文特征与状态、动作特征拼接

        # Q1的输出d
        x1 = self.q1(xu)
        # Q2的输出
        x2 = self.q2(xu)

        if ret_context == True:
            return x1, x2, combined #如果需要返回上下文特征则一起返回

        else:
            return x1, x2 #否则只返回两个Q值

    def Q1(self, x, u, pre_act_rew = None, ret_context = False):
        #定义获取Q1值的函数
        '''
            input (x): B * D where B is batch size and D is input_dim
            input (u): B * A where B is batch size and A is action_dim
            pre_act_rew: B * (A + 1) where B is batch size and A + 1 is input_dim
        '''

        xu = torch.cat([x, u], 1)
        combined = None

        if self.enable_context == True:
            combined = self.context(pre_act_rew)
            xu = torch.cat([xu, combined], dim = -1)

        # Q1
        x1 = self.q1(xu)

        if ret_context == True:
            return x1, combined

        else:
            return x1

    def get_conext_feats(self, pre_act_rew):
        #获取上下文特征
        '''
            pre_act_rew: B * (A + 1) where B is batch size and A + 1 is input_dim
            return combine features
        '''
        combined = self.context(pre_act_rew)

        return combined#返回上下文特征

class Context(nn.Module):
    #定义处理上下文的context网络
    """
      This layer just does non-linear transformation(s)
    """
    def __init__(self,
                 hidden_sizes = [50],
                 output_dim = None,
                 input_dim = None,
                 only_concat_context = 0,
                 hidden_activation=F.relu,
                 history_length = 1,
                 action_dim = None,
                 obsr_dim = None,
                 device = 'cpu'
                 ):

        super(Context, self).__init__()
        self.only_concat_context = only_concat_context
        self.hid_act = hidden_activation
        self.fcs = [] # 线性层列表
        self.hidden_sizes = hidden_sizes
        self.input_dim = input_dim
        self.output_dim_final = output_dim # count the fact that there is a skip connection,最终输出维度
        self.output_dim_last_layer  = output_dim // 2#最后一层的输出维度
        self.hist_length = history_length
        self.device = device
        self.action_dim = action_dim
        self.obsr_dim = obsr_dim

        #### build LSTM or multi-layers FF
        if only_concat_context == 3:
            # use LSTM or GRU
            self.recurrent =nn.GRU(self.input_dim,
                               self.hidden_sizes[0],
                               bidirectional = False,#指定GRU是否为双向的
                               batch_first = True,#指定输入数据的维度顺序是(batch_size,sequence_length,input_dim)
                               num_layers = 1)#GRU模型的层数

    def init_recurrent(self, bsize = None):
        #初始化循环网络的隐藏状态
        '''
            init hidden states
            Batch size can't be none
        '''
        # The order is (num_layers, minibatch_size, hidden_dim)
        # LSTM ==> return (torch.zeros(1, bsize, self.hidden_sizes[0]),
        #        torch.zeros(1, bsize, self.hidden_sizes[0]))
        return torch.zeros(1, bsize, self.hidden_sizes[0]).to(self.device)

    def forward(self, data):
        '''
            pre_x : B * D where B is batch size and D is input_dim
            pre_a : B * A where B is batch size and A is input_dim
            previous_reward: B * 1 where B is batch size and 1 is input_dim
        '''
        previous_action, previous_reward, pre_x = data[0], data[1], data[2]
        
        if self.only_concat_context == 3:
            # first prepare data for LSTM
            bsize, dim = previous_action.shape # previous_action is B* (history_len * D),获取批次大小和维度
            pacts = previous_action.view(bsize, -1, self.action_dim) # view(bsize, self.hist_length, -1),将动作数据重塑为适合LSTM的格式
            prews = previous_reward.view(bsize, -1, 1) # reward dim is 1, view(bsize, self.hist_length, 1)
            pxs   = pre_x.view(bsize, -1, self.obsr_dim ) # view(bsize, self.hist_length, -1)
            pre_act_rew = torch.cat([pacts, prews, pxs], dim = -1) # input to GRU is [action, reward, state],拼接动作、奖励和状态

            # init lstm/gru
            hidden = self.init_recurrent(bsize=bsize)#初始化LSTM的隐藏状态,并将批次大小传递给它

            # lstm/gru
            _, hidden = self.recurrent(pre_act_rew, hidden) # hidden is (1, B, hidden_size)
            out = hidden.squeeze(0) # (1, B, hidden_size) ==> (B, hidden_size)压缩out维度

            return out

        else:
            raise NotImplementedError

        return None
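
Context(only_concat_context=3)做的事情本质上是:把扁平化的历史动作、奖励、观测还原成(B, hist, ·)的序列,在最后一维拼接后送入GRU,用末端隐状态作为上下文特征。可以用随机张量验证一下各处的形状(下面的B、hist、action_dim、obs_dim等数值是假设的,仅用于演示;假设在仓库根目录下运行):

import torch
from models.networks import Context

B, hist, action_dim, obs_dim, hidden = 4, 30, 6, 20, 30
ctx = Context(hidden_sizes=[hidden],
              input_dim=action_dim + 1 + obs_dim,    # 动作 + 奖励 + 观测
              output_dim=hidden,
              only_concat_context=3,
              history_length=hist,
              action_dim=action_dim,
              obsr_dim=obs_dim,
              device='cpu')

pre_actions = torch.randn(B, hist * action_dim)      # 扁平化的历史动作
pre_rewards = torch.randn(B, hist)                   # 历史奖励
pre_obs     = torch.randn(B, hist * obs_dim)         # 扁平化的历史观测

out = ctx((pre_actions, pre_rewards, pre_obs))
print(out.shape)                                     # torch.Size([4, 30]),即(B, hidden)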

6.rand_param_envs

6.1__init__.py

from rand_param_envs.base import MetaEnv
from rand_param_envs.gym.envs.registration import register
#使用gym库的register来注册三个新的环境
register(
    id='Walker2DRandParams-v0',#二维行走模型的环境
    entry_point='rand_param_envs.walker2d_rand_params:Walker2DRandParamsEnv',#指定了环境类的位置,:前是文件路径,后是环境类的名称
)

register(
    id='HopperRandParams-v0',#跳跃模型的环境
    entry_point='rand_param_envs.hopper_rand_params:HopperRandParamsEnv',
)

register(
    id='PR2Env-v0',#PR2机器人的环境
    entry_point='rand_param_envs.pr2_env_reach:PR2Env',
)
#注册环境后,可以通过gym.make('环境ID')来创建环境的实例
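
注册之后的两种用法大致如下(示意代码,前提是mujoco_py等依赖装好、这些旧环境能正常加载;gym.make那种写法假设包里自带的gym和标准gym一样提供make):

# 方式一:直接构造环境类,和下面各文件 __main__ 里的写法一致
from rand_param_envs.walker2d_rand_params import Walker2DRandParamsEnv

env = Walker2DRandParamsEnv()
tasks = env.sample_tasks(3)      # 采样3组随机物理参数
env.set_task(tasks[0])           # 切换到第一个任务
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())

# 方式二:用注册的id创建
from rand_param_envs import gym
env2 = gym.make('HopperRandParams-v0')
obs2 = env2.reset()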

6.2base.py

from rand_param_envs.gym.core import Env
from rand_param_envs.gym.envs.mujoco import MujocoEnv
import numpy as np


class MetaEnv(Env):
    #定义了元学习环境应具有的基本接口,包括step、sample_tasks、set_task、get_task和log_diagnostics
    def step(self, *args, **kwargs):
        return self._step(*args, **kwargs)

    def sample_tasks(self, n_tasks):
        """
        Samples task of the meta-environment

        Args:
            n_tasks (int) : number of different meta-tasks needed

        Returns:
            tasks (list) : an (n_tasks) length list of tasks
        """
        raise NotImplementedError

    def set_task(self, task):
        """
        Sets the specified task to the current environment

        Args:
            task: task of the meta-learning environment
        """
        raise NotImplementedError

    def get_task(self):
        """
        Gets the task that the agent is performing in the current environment

        Returns:
            task: task of the meta-learning environment
        """
        raise NotImplementedError

    def log_diagnostics(self, paths, prefix):
        """
        Logs env-specific diagnostic information

        Args:
            paths (list) : list of all paths collected with this env during this iteration
            prefix (str) : prefix for logger
        """
        pass

class RandomEnv(MetaEnv, MujocoEnv):
    #定义随机环境类,继承自元学习环境基类和Mujoco环境基类
    """
    This class provides functionality for randomizing the physical parameters of a mujoco model
    The following parameters are changed:
        - body_mass
        - body_inertia
        - damping coeff at the joints
    """
    RAND_PARAMS = ['body_mass', 'dof_damping', 'body_inertia', 'geom_friction']#随机化的物理量:质量、关节阻尼、惯性、几何摩擦
    RAND_PARAMS_EXTENDED = RAND_PARAMS + ['geom_size']#扩展的随机参数列表

    def __init__(self, log_scale_limit, file_name, *args, rand_params=RAND_PARAMS, **kwargs):
        MujocoEnv.__init__(self, file_name, 4)#调用父类的初始化方法
        assert set(rand_params) <= set(self.RAND_PARAMS_EXTENDED), \
            "rand_params must be a subset of " + str(self.RAND_PARAMS_EXTENDED)#检查随机参数是否在允许的范围内
        self.log_scale_limit = log_scale_limit  #随机参数的对数尺度限制
        self.rand_params = rand_params
        self.save_parameters()

    def sample_tasks(self, n_tasks):
        #为mujoco环境生成随机化参数集
        """
        Generates randomized parameter sets for the mujoco env

        Args:
            n_tasks (int) : number of different meta-tasks needed

        Returns:
            tasks (list) : an (n_tasks) length list of tasks
        """
        param_sets = []#初始化参数集列表

        for _ in range(n_tasks):#对每一个任务
            # body mass -> one multiplier for all body parts

            new_params = {}#初始化新的参数字典
            #随机化各种物理参数
            if 'body_mass' in self.rand_params:
                body_mass_multiplyers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit,  size=self.model.body_mass.shape)
                new_params['body_mass'] = self.init_params['body_mass'] * body_mass_multiplyers

            # body_inertia
            if 'body_inertia' in self.rand_params:
                body_inertia_multiplyers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit,  size=self.model.body_inertia.shape)
                new_params['body_inertia'] = body_inertia_multiplyers * self.init_params['body_inertia']

            # damping -> different multiplier for different dofs/joints
            if 'dof_damping' in self.rand_params:
                dof_damping_multipliers = np.array(1.3) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.dof_damping.shape)
                new_params['dof_damping'] = np.multiply(self.init_params['dof_damping'], dof_damping_multipliers)

            # friction at the body components
            if 'geom_friction' in self.rand_params:
                dof_damping_multipliers = np.array(1.5) ** np.random.uniform(-self.log_scale_limit, self.log_scale_limit, size=self.model.geom_friction.shape)
                new_params['geom_friction'] = np.multiply(self.init_params['geom_friction'], dof_damping_multipliers)

            param_sets.append(new_params) #将新的参数字典添加到参数集列表里

        return param_sets

    def set_task(self, task):
        for param, param_val in task.items():#对于任务中的每一个参数
            param_variable = getattr(self.model, param)#获取模型中对应的参数
            assert param_variable.shape == param_val.shape, 'shapes of new parameter value and old one must match'#检查新旧参数的形状是否匹配
            setattr(self.model, param, param_val)
        self.cur_params = task #保存当前的参数

    def get_task(self):
        return self.cur_params

    def save_parameters(self):
        self.init_params = {} 
        #保存各种物理参数
        if 'body_mass' in self.rand_params:
            self.init_params['body_mass'] = self.model.body_mass

        # body_inertia
        if 'body_inertia' in self.rand_params:
            self.init_params['body_inertia'] = self.model.body_inertia

        # damping -> different multiplier for different dofs/joints
        if 'dof_damping' in self.rand_params:
            self.init_params['dof_damping'] = self.model.dof_damping

        # friction at the body components
        if 'geom_friction' in self.rand_params:
            self.init_params['geom_friction'] = self.model.geom_friction
        self.cur_params = self.init_params #将初始参数设置为当前参数
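
这里的"随机化"是乘性扰动:每个参数乘上1.5^u(dof_damping用1.3^u),其中u ~ Uniform(-log_scale_limit, log_scale_limit)。比如log_scale_limit=3时,质量的缩放范围大约在1.5^-3≈0.30倍到1.5^3≈3.38倍之间。可以用numpy单独验证这个范围(独立的小示意,数值是假设的,不依赖mujoco):

import numpy as np

log_scale_limit = 3.0
base_mass = np.array([1.0, 2.0, 0.5])      # 假设的body_mass初始值

u = np.random.uniform(-log_scale_limit, log_scale_limit, size=base_mass.shape)
multipliers = np.array(1.5) ** u           # 和sample_tasks里的写法一致
new_mass = base_mass * multipliers

print(multipliers.min() >= 1.5 ** -log_scale_limit)   # True,约0.296倍
print(multipliers.max() <= 1.5 ** log_scale_limit)    # True,约3.375倍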

6.3hopper_rand_params.py

import numpy as np
from rand_param_envs.base import RandomEnv
from rand_param_envs.gym import utils
#定义了一个可以随机变化的跳跃机器人,机器人的目标是尽可能地向前跳跃
class HopperRandParamsEnv(RandomEnv, utils.EzPickle):
    def __init__(self, log_scale_limit=3.0):
        RandomEnv.__init__(self, log_scale_limit, 'hopper.xml', 4)
        utils.EzPickle.__init__(self)

    def _step(self, a):
        posbefore = self.model.data.qpos[0, 0]#获取模型的当前位置
        self.do_simulation(a, self.frame_skip)#执行模拟
        posafter, height, ang = self.model.data.qpos[0:3, 0]#获取模型的新位置、高度和角度
        alive_bonus = 1.0#存活奖励
        #计算奖励
        reward = (posafter - posbefore) / self.dt
        reward += alive_bonus
        reward -= 1e-3 * np.square(a).sum()
        s = self.state_vector()#获取状态向量
        done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and
                    (height > .7) and (abs(ang) < .2))#判断是否结束
        #获取观测
        ob = self._get_obs()
        return ob, reward, done, {}
    #定义获取观测的函数
    def _get_obs(self):
        return np.concatenate([
            self.model.data.qpos.flat[1:],
            np.clip(self.model.data.qvel.flat, -10, 10)
        ])
    #定义重置模型的函数
    def reset_model(self):
        #模型的位置设置为初始位置+一个随机的微小扰动
        qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq)
        #模型的速度设置为初始速度+一个随机的微小扰动
        qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
        #将模型的位置和速度设置为刚才计算得到的
        self.set_state(qpos, qvel)
        return self._get_obs()#返回当前观测

    def viewer_setup(self):
        #用于设置观察者(摄像机)的参数,得到合适的视角
        self.viewer.cam.trackbodyid = 2#摄像机追踪物体的id(这里是2)
        self.viewer.cam.distance = self.model.stat.extent * 0.75#摄像机与追踪物体的距离(设置为模型范围的75%)
        self.viewer.cam.lookat[2] += .8#调整摄像机在z轴方向的位置,这里是增加0.8个单位长度
        self.viewer.cam.elevation = -20#摄像机的仰角(-20度)

if __name__ == "__main__":

    env = HopperRandParamsEnv()
    tasks = env.sample_tasks(40)#生成40个任务
    while True:
        env.reset()
        env.set_task(np.random.choice(tasks))#设置一个随机任务
        print(env.model.body_mass)
        for _ in range(100):
            env.render()
            env.step(env.action_space.sample())  #选择一个随机动作
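
Hopper的单步奖励 = 前进速度(posafter - posbefore)/dt + 存活奖励1.0 - 0.001*动作平方和。拿一组假设的数值手算一下(纯示意,dt取0.008只是常见的frame_skip*仿真步长,并非从代码里读出来的):

import numpy as np

posbefore, posafter, dt = 1.20, 1.26, 0.008   # 假设的数值
a = np.array([0.5, -0.3, 0.1])                # 假设的动作

reward = (posafter - posbefore) / dt          # 前进速度项:7.5
reward += 1.0                                 # 存活奖励
reward -= 1e-3 * np.square(a).sum()           # 控制代价:0.00035
print(round(reward, 5))                       # 8.49965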

6.4pr2_env_reach.py

import numpy as np
from rand_param_envs.base import RandomEnv
from rand_param_envs.gym import utils
import os
#创建了一个PR2机器人的仿真环境
class PR2Env(RandomEnv, utils.EzPickle):

    FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets/pr2.xml')
    #定义了一个类变量FILE,它是PR2机器人的xml文件路径
    def __init__(self, log_scale_limit=1.):
        self.viewer = None#初始化一个viewer属性,用于后续创建视图窗口
        RandomEnv.__init__(self, log_scale_limit, 'pr2.xml', 4)#调用randomenv的初始化方法
        utils.EzPickle.__init__(self)#调用ezpickle的初始化方法

    def _get_obs(self):
        #定义获取观测的方法
        return np.concatenate([
            self.model.data.qpos.flat[:7],#获取模型的位置数据
            self.model.data.qvel.flat[:7],  # Do not include the velocity of the target (should be 0).
            self.get_tip_position().flat,#获取机器人末端执行器的位置
            self.get_vec_tip_to_goal().flat,#获取末端执行器到目标的向量
        ])

    def get_tip_position(self):
        return self.model.data.site_xpos[0]#返回末端执行器的位置

    def get_vec_tip_to_goal(self):
        tip_position = self.get_tip_position()#获取末端执行器的位置
        goal_position = self.goal#获取目标的位置
        vec_tip_to_goal = goal_position - tip_position#计算向量
        return vec_tip_to_goal

    @property#这是一个只读属性
    def goal(self):
        #返回目标的数据位置
        return self.model.data.qpos.flat[-3:]

    def _step(self, action):
        #定义执行动作的方法
        self.do_simulation(action, self.frame_skip)

        vec_tip_to_goal = self.get_vec_tip_to_goal()#执行器到目标的向量
        distance_tip_to_goal = np.linalg.norm(vec_tip_to_goal)#执行器到目标的距离

        reward = - distance_tip_to_goal#计算奖励(负的距离)

        state = self.state_vector()#获取当前状态向量
        notdone = np.isfinite(state).all()#检查状态向量是否全部是有限的
        done = not notdone#如果非有限,则表示仿真结束

        ob = self._get_obs()#获取当前观测

        return ob, reward, done, {}

    def reset_model(self):
        qpos = self.init_qpos#获取初始位置
        qvel = self.init_qvel#获取初始速度
        goal = np.random.uniform((0.2, -0.4, 0.5), (0.5, 0.4, 1.5))#随机生成一个目标位置
        qpos[-3:] = goal#设置目标位置
        qpos[:7] += self.np_random.uniform(low=-.005, high=.005,  size=7)#给初始位置增加随机扰动
        qvel[:7] += self.np_random.uniform(low=-.005, high=.005,  size=7)#给初始速度增加随机扰动
        self.set_state(qpos, qvel)#设置模型的状态
        return self._get_obs()#返回状态

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 2
        # self.viewer.cam.lookat[2] += .8
        self.viewer.cam.elevation = -50
        # self.viewer.cam.lookat[0] = self.model.stat.center[0]
        # self.viewer.cam.lookat[1] = self.model.stat.center[1]
        # self.viewer.cam.lookat[2] = self.model.stat.center[2]


if __name__ == "__main__":

    env = PR2Env()
    tasks = env.sample_tasks(40)
    while True:
        env.reset()
        env.set_task(np.random.choice(tasks))
        print(env.model.body_mass)
        for _ in range(100):
            env.render()
            env.step(env.action_space.sample())
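
PR2环境的观测就是把四段向量接起来:7维关节位置 + 7维关节速度 + 3维末端执行器位置 + 3维"末端到目标"的向量,一共20维;奖励是负的末端到目标距离。用numpy示意一下(数值是假设的):

import numpy as np

qpos7 = np.zeros(7)                        # 前7个关节位置
qvel7 = np.zeros(7)                        # 前7个关节速度
tip   = np.array([0.4, 0.0, 0.9])          # 假设的末端执行器位置
goal  = np.array([0.3, -0.2, 1.1])         # 假设的目标位置

vec_tip_to_goal = goal - tip
obs = np.concatenate([qpos7, qvel7, tip, vec_tip_to_goal])
print(obs.shape)                           # (20,)
print(-np.linalg.norm(vec_tip_to_goal))    # 奖励 = 负的距离,这里是-0.3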

试了一下单独运行这个文件,直接报错。先把依赖换成这套旧代码对应的版本:

pip3 install mujoco_py==0.5.7
pip3 install gym==0.9.3

参考:windows 上安装 mujoco_py - 简书、Mujoco150 windows11 下安装实践 - 知乎。

装好之后报错内容发生了变化,再参考 Stack Overflow 上的 c++ - Python binding for MuJoCo physics library using mujoco-py package,照着处理后终于能跑起来了!

6.5walker2d_rand_params.py

import numpy as np
from rand_param_envs.base import RandomEnv
from rand_param_envs.gym import utils
#定义了一个二维行走机器人
class Walker2DRandParamsEnv(RandomEnv, utils.EzPickle):
    def __init__(self, log_scale_limit=3.0):
        RandomEnv.__init__(self, log_scale_limit, 'walker2d.xml', 5)#5是帧跳过参数
        utils.EzPickle.__init__(self)

    def _step(self, a):
        posbefore = self.model.data.qpos[0, 0]
        self.do_simulation(a, self.frame_skip)
        posafter, height, ang = self.model.data.qpos[0:3, 0]
        alive_bonus = 1.0
        reward = ((posafter - posbefore) / self.dt)
        reward += alive_bonus#奖励是位置的变化量除以时间步长加上存活奖励
        reward -= 1e-3 * np.square(a).sum()#奖励减去动作的平方和的千分之一
        done = not (height > 0.8 and height < 2.0 and
                    ang > -1.0 and ang < 1.0)#如果高度不在0.8-2.0之间,或者角度不在-1到1之间,结束
        ob = self._get_obs()
        return ob, reward, done, {}

    def _get_obs(self):
        qpos = self.model.data.qpos
        qvel = self.model.data.qvel
        return np.concatenate([qpos[1:], np.clip(qvel, -10, 10)]).ravel()#返回位置和速度的连接,速度被限制在-10到10之间

    def reset_model(self):
        self.set_state(
            self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq),
            self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
        )
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.trackbodyid = 2
        self.viewer.cam.distance = self.model.stat.extent * 0.5
        self.viewer.cam.lookat[2] += .8
        self.viewer.cam.elevation = -20

if __name__ == "__main__":

    env = Walker2DRandParamsEnv()
    tasks = env.sample_tasks(40)
    while True:
        env.reset()
        env.set_task(np.random.choice(tasks))
        print(env.model.body_mass)
        for _ in range(100):
            env.render()
            env.step(env.action_space.sample())  # take a random action

虽然它们个个都很鬼畜。

7.main.py

import argparse
import torch
import os
import time
import sys
import numpy as np
from collections import deque
import random
from misc.utils import create_dir, dump_to_json, CSVWriter 
from misc.torch_utility import get_state
from misc.utils import set_global_seeds, safemean, read_json
from misc import logger
from algs.MQL.buffer import Buffer

parser = argparse.ArgumentParser()

# Optim params
parser.add_argument('--lr', type=float, default=0.0003, help = 'Learning rate')
parser.add_argument('--replay_size', type=int, default = 1e6, help ='Replay buffer size int(1e6)')
parser.add_argument('--ptau', type=float, default=0.005 , help = 'Interpolation factor in polyak averaging')
parser.add_argument('--gamma', type=float, default=0.99, help = 'Discount factor [0,1]')
parser.add_argument("--burn_in", default=1e4, type=int, help = 'How many time steps purely random policy is run for') 
parser.add_argument("--total_timesteps", default=5e6, type=float, help = 'Total number of timesteps to train on')
parser.add_argument("--expl_noise", default=0.2, type=float, help='Std of Gaussian exploration noise')
parser.add_argument("--batch_size", default=256, type=int, help = 'Batch size for both actor and critic')
parser.add_argument("--policy_noise", default=0.3, type=float, help =' Noise added to target policy during critic update')
parser.add_argument("--noise_clip", default=0.5, type=float, help='Range to clip target policy noise')
parser.add_argument("--policy_freq", default=2, type=int, help='Frequency of delayed policy updates')
parser.add_argument('--hidden_sizes', nargs='+', type=int, default = [300, 300], help = 'indicates hidden size actor/critic')

# General params
parser.add_argument('--env_name', type=str, default='ant-goal')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--alg_name', type=str, default='mql')

parser.add_argument('--disable_cuda', default=False, action='store_true')
parser.add_argument('--cuda_deterministic', default=False, action='store_true')
parser.add_argument("--gpu_id", default=0, type=int)

parser.add_argument('--log_id', default='empty')
parser.add_argument('--check_point_dir', default='./ck')
parser.add_argument('--log_dir', default='./log_dir')
parser.add_argument('--log_interval', type=int, default=10, help='log interval, one log per n updates')
parser.add_argument('--save_freq', type=int, default = 250)
parser.add_argument("--eval_freq", default=5e3, type=float, help = 'How often (time steps) we evaluate')    

# Env
parser.add_argument('--env_configs', default='./configs/pearl_envs.json')
parser.add_argument('--max_path_length', type=int, default = 200)
parser.add_argument('--enable_train_eval', default=False, action='store_true')
parser.add_argument('--enable_promp_envs', default=False, action='store_true')
parser.add_argument('--num_initial_steps',  type=int, default = 1000)
parser.add_argument('--unbounded_eval_hist', default=False, action='store_true')

#context
parser.add_argument('--hiddens_conext', nargs='+', type=int, default = [30], help = 'indicates hidden size of context next')
parser.add_argument('--enable_context', default=True, action='store_true')
parser.add_argument('--only_concat_context', type=int, default = 3, help =' use conext')
parser.add_argument('--num_tasks_sample', type=int, default = 5)
parser.add_argument('--num_train_steps', type=int, default = 500)
parser.add_argument('--min_buffer_size', type=int, default = 100000, help = 'this indicates a condition to start using num_train_steps')
parser.add_argument('--history_length', type=int, default = 30)

#other params
parser.add_argument('--beta_clip', default=1.0, type=float, help='Range to clip beta term in CSC')
parser.add_argument('--snapshot_size', type=int, default = 2000, help ='Snapshot size for a task')
parser.add_argument('--prox_coef', default=0.1, type=float, help ='Prox lambda')
parser.add_argument('--meta_batch_size', default=10, type=int, help ='Meta batch size: number of sampled tasks per itr')
parser.add_argument('--enable_adaptation', default=True, action='store_true')
parser.add_argument('--main_snap_iter_nums', default=100, type=int, help ='how many times adapt using train task but with csc')
parser.add_argument('--snap_iter_nums', default=10, type=int, help ='how many times adapt using eval task')
parser.add_argument('--type_of_training', default='td3', help = 'td3')
parser.add_argument('--lam_csc', default=0.50, type=float, help='logistic regression reg, smaller means stronger reg')
parser.add_argument('--use_ess_clipping', default=False, action='store_true')
parser.add_argument('--enable_beta_obs_cxt', default=False, action='store_true', help='if true concat obs + context')
parser.add_argument('--sampling_style', default='replay', help = 'replay')
parser.add_argument('--sample_mult',  type=int, default = 5, help ='sample multipler of main_iter for adapt method')
parser.add_argument('--use_epi_len_steps', default=True, action='store_true')
parser.add_argument('--use_normalized_beta', default=False, action='store_true', help = 'normalized beta_score')
parser.add_argument('--reset_optims', default=False, action='store_true', help = 'init optimizers at the start of adaptation')
parser.add_argument('--lr_milestone', default = -1, type=int, help = 'reduce learning rate after this epoch')
parser.add_argument('--lr_gamma', default = 0.8, type=float, help = 'learning rate decay')

def update_lr(eparams, iter_num, alg_mth):
    #######
    # initial_lr if i < reduce_lr
    # otherwise initial_lr * lr_gamma
    #######
    if iter_num > eparams.lr_milestone:
        new_lr = eparams.lr * eparams.lr_gamma

        for param_group in alg_mth.actor_optimizer.param_groups:
            param_group['lr'] = new_lr

        for param_group in alg_mth.critic_optimizer.param_groups:
            param_group['lr'] = new_lr
        print("---------")
        print("Actor (updated_lr):\n ",  alg_mth.actor_optimizer)
        print("Critic (updated_lr):\n ", alg_mth.critic_optimizer)
        print("---------")

def take_snapshot(args, ck_fname_part, model, update):
    '''
        This function just saves the current model and some other info
    '''
    fname_ck =  ck_fname_part + '.pt'
    fname_json =  ck_fname_part + '.json'
    curr_state_actor = get_state(model.actor)
    curr_state_critic = get_state(model.critic)

    print('Saving a checkpoint for iteration %d in %s' % (update, fname_ck))
    checkpoint = {
                    'args': args.__dict__,
                    'model_states_actor': curr_state_actor,
                    'model_states_critic': curr_state_critic,
                 }
    torch.save(checkpoint, fname_ck)

    del checkpoint['model_states_actor']
    del checkpoint['model_states_critic']
    del curr_state_actor
    del curr_state_critic
    
    dump_to_json(fname_json, checkpoint)

def setup_logAndCheckpoints(args):

    # create folder if not there
    create_dir(args.check_point_dir)

    fname = str.lower(args.env_name) + '_' + args.alg_name + '_' + args.log_id
    fname_log = os.path.join(args.log_dir, fname)
    fname_eval = os.path.join(fname_log, 'eval.csv')
    fname_adapt = os.path.join(fname_log, 'adapt.csv')

    return os.path.join(args.check_point_dir, fname), fname_log, fname_eval, fname_adapt

def make_env(eparams):
    '''
        This function builds env
    '''
    # since env contains np/sample function, need to set random seed here
    # set random seed
    random.seed(args.seed)
    np.random.seed(args.seed)

    ################
    # this is based on PEARL paper that fixes the set of samples
    ################
    from misc.env_meta import build_PEARL_envs
    env = build_PEARL_envs(
                           seed = eparams.seed,
                           env_name = eparams.env_name,
                           params = eparams,
                           )

    return env

def sample_env_tasks(env, eparams):
    '''
        Sample env tasks
    '''
    if eparams.enable_promp_envs == True:
        # task list created as [ train_task,..., train_task ,eval_task,..., eval_task]
        train_tasks = env.sample_tasks(eparams.n_train_tasks)
        eval_tasks  = env.sample_tasks(eparams.n_eval_tasks)

    else:
        # task list created as [ train_task,..., train_task ,eval_task,..., eval_task]
        tasks = env.get_all_task_idx()
        train_tasks = list(tasks[:eparams.n_train_tasks])
        eval_tasks = list(tasks[-eparams.n_eval_tasks:])

    return train_tasks, eval_tasks

def config_tasks_envs(eparams):
    '''
        Configure tasks parameters.
        Envs params and task parameters based on pearl paper:
        args like followings will be added:
        n_train_tasks   2
        n_eval_tasks    2
        n_tasks 2
        randomize_tasks true
        low_gear    false
        forward_backward    true
        num_evals   4
        num_steps_per_task  400
        num_steps_per_eval  400
        num_train_steps_per_itr 4000
    '''
    configs = read_json(eparams.env_configs)[eparams.env_name]
    temp_params = vars(eparams)
    for k, v in configs.items():
            temp_params[k] = v
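
# 补充说明(笔记,不属于main.py源码):./configs/pearl_envs.json里每个环境对应一个条目,
# config_tasks_envs会把条目里的键值直接写进args。按上面docstring列出的键,一个条目大概长这样
# (键名和数值取自docstring里的示例,"<env_name>"换成实际的环境名,比如args.env_name):
#
# "<env_name>": {
#     "n_train_tasks": 2,
#     "n_eval_tasks": 2,
#     "n_tasks": 2,
#     "randomize_tasks": true,
#     "low_gear": false,
#     "forward_backward": true,
#     "num_evals": 4,
#     "num_steps_per_task": 400,
#     "num_steps_per_eval": 400,
#     "num_train_steps_per_itr": 4000
# }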

def evaluate_policy(eval_env,
                    policy,
                    eps_num,
                    itr,
                    etasks,
                    eparams,
                    meta_learner = None,
                    train_tasks_buffer = None,
                    train_replay_buffer = None,
                    msg ='Evaluation'):
    '''
        runs policy for X episodes and returns average reward
    '''
    if eparams.unbounded_eval_hist == True: # increase seq length to max_path_length
        eval_hist_len = eparams.max_path_length
        print('Eval uses unbounded_eval_hist of length: ', eval_hist_len)

    else:
        eval_hist_len = eparams.history_length
        print('Eval uses history of length: ', eval_hist_len)

    if eparams.enable_promp_envs == True:
        etasks  = eval_env.sample_tasks(eparams.n_eval_tasks)

    ############# adaptation step #############
    if meta_learner and eparams.enable_adaptation == True:
        meta_learner.save_model_states()
    ############# ############### #############

    all_task_rewards = []
    dc_rewards = []

    for tidx in etasks:
        if eparams.enable_promp_envs == True:
            eval_env.set_task(tidx)

        else:
            eval_env.reset_task(tidx)

        ############# adaptation step #############
        if  meta_learner and eparams.enable_adaptation == True:
            eval_task_buffer, avg_data_collection = collect_data_for_adaptaion(eval_env, policy, tidx, eparams)
            stats_main, stats_csv = meta_learner.adapt(train_replay_buffer = train_replay_buffer,
                                                       train_tasks_buffer = train_tasks_buffer,
                                                       eval_task_buffer = eval_task_buffer,
                                                       task_id = tidx,
                                                       snap_iter_nums = eparams.snap_iter_nums,
                                                       main_snap_iter_nums = eparams.main_snap_iter_nums,
                                                       sampling_style = eparams.sampling_style,
                                                       sample_mult = eparams.sample_mult
                                                       )
            dc_rewards.append(avg_data_collection)
            print('--------Adaptation-----------')
            print('Task: ', tidx)
            print(("critic_loss: %.4f \n\ractor_loss: %.4f \n\rNo beta_score: %.4f ") %
                  (stats_csv['critic_loss'], stats_csv['actor_loss'], stats_csv['beta_score']))

            print(("\rsamples for CSC: (%d, %d) \n\rAccuracy on train: %.4f \n\rsnap_iter: %d ") %
                  (stats_csv['csc_info'][0], stats_csv['csc_info'][1], stats_csv['csc_info'][2], stats_csv['snap_iter']))
            print(("\rmain_critic_loss: %.4f \n\rmain_actor_loss: %.4f \n\rmain_beta_score: %.4f ") %
                   (stats_main['critic_loss'], stats_main['actor_loss'], stats_main['beta_score']))
            print(("\rmain_prox_critic %.4f \n\rmain_prox_actor: %.4f")%(stats_main['prox_critic'], stats_main['prox_actor']))

            if 'avg_prox_coef' in stats_main:
                print(("\ravg_prox_coef: %.4f" %(stats_main['avg_prox_coef'])))

            print('-----------------------------')
        ############# ############### #############

        avg_reward = 0
        for _ in range(eparams.num_evals):
            obs = eval_env.reset()
            done = False
            step = 0

            ### history ####
            rewards_hist = deque(maxlen=eval_hist_len)
            actions_hist = deque(maxlen=eval_hist_len)
            obsvs_hist   = deque(maxlen=eval_hist_len)

            rewards_hist.append(0)
            obsvs_hist.append(obs.copy())

            rand_action = np.random.normal(0, eparams.expl_noise, size=eval_env.action_space.shape[0])
            rand_action = rand_action.clip(eval_env.action_space.low, eval_env.action_space.high)
            actions_hist.append(rand_action.copy())

            while not done and step < eparams.max_path_length :

                np_pre_actions = np.asarray(actions_hist, dtype=np.float32).flatten() #(hist, action_dim) => (hist *action_dim,)
                np_pre_rewards = np.asarray(rewards_hist, dtype=np.float32) #(hist, )
                np_pre_obsvs  = np.asarray(obsvs_hist, dtype=np.float32).flatten() #(hist, obs_dim) => (hist *obs_dim,)
                action = policy.select_action(np.array(obs), np.array(np_pre_actions), np.array(np_pre_rewards), np.array(np_pre_obsvs))
                new_obs, reward, done, _ = eval_env.step(action)
                avg_reward += reward
                step += 1

                # new becomes old
                rewards_hist.append(reward)
                actions_hist.append(action.copy())
                obsvs_hist.append(obs.copy())

                obs = new_obs.copy()

        avg_reward /= eparams.num_evals
        all_task_rewards.append(avg_reward)

        ############# adaptation step #############
        # Roll-back
        ############# ############### #############
        if meta_learner and eparams.enable_adaptation == True:
            meta_learner.rollback()

            ############## add adapt data to a csv file
            log_data_adp = {}
            for k, v in stats_csv.items():
                if k in eparams.adapt_csv_hearder:
                    log_data_adp[k] = stats_csv[k]

            log_data_adp['csc_samples_neg'] = stats_csv['csc_info'][0]
            log_data_adp['csc_samples_pos'] = stats_csv['csc_info'][1]
            log_data_adp['train_acc'] = stats_csv['csc_info'][2]
            log_data_adp['avg_rewards'] = avg_reward
            log_data_adp['one_raw_reward'] = avg_data_collection
            log_data_adp['tidx'] = tidx
            log_data_adp['eps_num'] = eps_num
            log_data_adp['iter'] = itr

            for k in stats_main.keys():
                if k in eparams.adapt_csv_hearder:
                    log_data_adp['main_' + k] = stats_main[k]
                elif 'main_' + k in eparams.adapt_csv_hearder:
                    log_data_adp['main_' + k] = stats_main[k]

            adapt_csv_stats.write(log_data_adp)
            ##############

    if meta_learner and eparams.enable_adaptation == True:
        msg += ' *** with Adapation *** '
        print('Avg rewards (only one eval loop) for all tasks before adaptation ', np.mean(dc_rewards))

    print("---------------------------------------")
    print("%s over %d episodes of %d tasks in episode num %d and nupdates %d: %f" \
           % (msg, eparams.num_evals, len(etasks), eps_num, itr, np.mean(all_task_rewards)))
    print("---------------------------------------")
    return np.mean(all_task_rewards)

def collect_data_for_adaptaion(eval_env, policy, tidx, eparams):

    '''
        Collect data for adaptation
    '''
    ###
    # Step 0: Create eval buffers
    ###
    eval_task_buffer = MultiTasksSnapshot(max_size = args.snapshot_size)
    eval_task_buffer.init([tidx])

    ###
    # Step 1: Define some vars
    ###
    step = 0
    avg_reward = 0
    prev_reward = 0
    obs = eval_env.reset()
    done = False

    ### history ####
    rewards_hist = deque(maxlen=eparams.history_length)
    actions_hist = deque(maxlen=eparams.history_length)
    obsvs_hist   = deque(maxlen=eparams.history_length)

    next_hrews = deque(maxlen=eparams.history_length)
    next_hacts = deque(maxlen=eparams.history_length)
    next_hobvs = deque(maxlen=eparams.history_length)

    zero_action = np.zeros(eval_env.action_space.shape[0])
    zero_obs    = np.zeros(obs.shape)
    for _ in range(eparams.history_length):
        rewards_hist.append(0)
        actions_hist.append(zero_action.copy())
        obsvs_hist.append(zero_obs.copy())

        # same thing for next_h*
        next_hrews.append(0)
        next_hacts.append(zero_action.copy())
        next_hobvs.append(zero_obs.copy())

    rewards_hist.append(0)
    obsvs_hist.append(obs.copy())

    rand_action = np.random.normal(0, eparams.expl_noise, size=eval_env.action_space.shape[0])
    rand_action = rand_action.clip(eval_env.action_space.low, eval_env.action_space.high)
    actions_hist.append(rand_action.copy())

    while not done and step < eparams.max_path_length :

        np_pre_actions = np.asarray(actions_hist, dtype=np.float32).flatten() #(hist, action_dim) => (hist *action_dim,)
        np_pre_rewards = np.asarray(rewards_hist, dtype=np.float32) # (hist, )
        np_pre_obsvs  = np.asarray(obsvs_hist, dtype=np.float32).flatten() #(hist, obs_dim) => (hist *obs_dim,)
        action = policy.select_action(np.array(obs), np.array(np_pre_actions), np.array(np_pre_rewards), np.array(np_pre_obsvs))
        new_obs, reward, done, _ = eval_env.step(action)
        avg_reward += reward

        if step + 1 == args.max_path_length:
            done_bool = 0

        else:
            done_bool = float(done)

        ###############
        next_hrews.append(reward)
        next_hacts.append(action.copy())
        next_hobvs.append(obs.copy())

        # np_next_hacts and np_next_hrews are required for TD3 alg
        np_next_hacts = np.asarray(next_hacts, dtype=np.float32).flatten()  # (hist, action_dim) => (hist *action_dim,)
        np_next_hrews = np.asarray(next_hrews, dtype=np.float32) # (hist, )
        np_next_hobvs = np.asarray(next_hobvs, dtype=np.float32).flatten() # (hist, )

        eval_task_buffer.add(tidx, (obs, new_obs, action, reward, done_bool,
                            np_pre_actions, np_pre_rewards, np_pre_obsvs,
                            np_next_hacts, np_next_hrews, np_next_hobvs))

        step += 1

        # new becomes old
        rewards_hist.append(reward)
        actions_hist.append(action.copy())
        obsvs_hist.append(obs.copy())

        obs = new_obs.copy()

    return eval_task_buffer, avg_reward

def adjust_number_train_iters(buffer_size, num_train_steps, bsize, min_buffer_size,
                              episode_timesteps, use_epi_len_steps = False):
    '''
        This adjusts number of gradient updates given sometimes there
        is not enough data in buffer
    '''
    if use_epi_len_steps == True and episode_timesteps > 1 and buffer_size < min_buffer_size:
        return episode_timesteps

    if buffer_size < num_train_steps or buffer_size < min_buffer_size:
        temp = int( buffer_size/ (bsize) % num_train_steps ) + 1

        if temp < num_train_steps:
            num_train_steps = temp

    return num_train_steps
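
# 补充说明(笔记,不属于main.py源码):adjust_number_train_iters可以用一组假设的数值体会一下:
#   buffer_size=5000, num_train_steps=500, bsize=256, min_buffer_size=100000,
#   episode_timesteps=200, use_epi_len_steps=True
#   -> 命中第一个分支,直接返回episode_timesteps=200(数据少时,更新次数跟着episode长度走)
#   若use_epi_len_steps=False:
#   buffer_size(5000) < min_buffer_size, temp = int(5000/256 % 500) + 1 = int(19.53) + 1 = 20
#   -> 只做20次梯度更新,而不是500次
#   等buffer_size既不小于num_train_steps也不小于min_buffer_size后,才固定用num_train_steps=500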

if __name__ == "__main__":

    args = parser.parse_args()
    print('------------')
    print(args.__dict__)
    print('------------')

    print('Read Tasks/Env config params and Update args')
    config_tasks_envs(args)
    print(args.__dict__)

    # if using mujoco-v2, then the xml file should be ignored
    if ('-v2' in args.env_name):
        print('**** XML file is ignored since it is -v2 ****')

    ##############################
    #### Generic setups
    ##############################
    CUDA_AVAL = torch.cuda.is_available()

    if not args.disable_cuda and CUDA_AVAL: 
        gpu_id = "cuda:" + str(args.gpu_id)
        device = torch.device(gpu_id)
        print("**** Yayy we use GPU %s ****" % gpu_id)

    else:                                                   
        device = torch.device('cpu')
        print("**** No GPU detected or GPU usage is disabled, sorry! ****")

    ####
    # train and evalution checkpoints, log folders, ck file names
    create_dir(args.log_dir, cleanup = True)
    # create folder for save checkpoints
    ck_fname_part, log_file_dir, fname_csv_eval, fname_adapt = setup_logAndCheckpoints(args)
    logger.configure(dir = log_file_dir)
    wrt_csv_eval = None

    ##############################
    #### Init env, model, alg, batch generator etc
    #### Step 1: build env
    #### Step 2: Build model
    #### Step 3: Initiate Alg e.g. a2c
    #### Step 4: Initiate batch/rollout generator  
    ##############################

    ##### env setup #####
    env = make_env(args)

    ######### SEED ##########
    #  build_env already calls set seed,
    # Set seed the RNG for all devices (both CPU and CUDA)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.disable_cuda and CUDA_AVAL and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        print("****** cudnn.deterministic is set ******")

    ######### Build Networks
    max_action = float(env.action_space.high[0])
    if len(env.observation_space.shape) == 1:
        import models.networks as net

        ######
        # This part to add context network
        ######
        if args.enable_context == True:
            reward_dim = 1
            input_dim_context =   env.action_space.shape[0] + reward_dim
            args.output_dim_conext =  (env.action_space.shape[0] + reward_dim) * 2


            if args.only_concat_context == 3: # means use LSTM with action_reward_state as an input
                input_dim_context = env.action_space.shape[0] + reward_dim + env.observation_space.shape[0]
                actor_idim = [env.observation_space.shape[0] + args.hiddens_conext[0]]
                args.output_dim_conext = args.hiddens_conext[0]
                dim_others = args.hiddens_conext[0]

            else:
                raise ValueError(" %d args.only_concat_context is not supported" % (args.only_concat_context))

        else:
            actor_idim = env.observation_space.shape
            dim_others = 0
            input_dim_context = None
            args.output_dim_conext = 0

        actor_net = net.Actor(action_space = env.action_space,
                              hidden_sizes =args.hidden_sizes,
                              input_dim = actor_idim,
                              max_action = max_action,
                              enable_context = args.enable_context,
                              hiddens_dim_conext = args.hiddens_conext,
                              input_dim_context = input_dim_context,
                              output_conext = args.output_dim_conext,
                              only_concat_context = args.only_concat_context,
                              history_length = args.history_length,
                              obsr_dim = env.observation_space.shape[0],
                              device = device
                              ).to(device)

        actor_target_net = net.Actor(action_space = env.action_space,
                                    hidden_sizes =args.hidden_sizes,
                                    input_dim = actor_idim,
                                    max_action = max_action,
                                    enable_context = args.enable_context,
                                    hiddens_dim_conext = args.hiddens_conext,
                                    input_dim_context = input_dim_context,
                                    output_conext = args.output_dim_conext,
                                    only_concat_context = args.only_concat_context,
                                    history_length = args.history_length,
                                    obsr_dim = env.observation_space.shape[0],
                                    device = device
                                     ).to(device)

        critic_net = net.Critic(action_space = env.action_space,
                                hidden_sizes =args.hidden_sizes,
                                input_dim = env.observation_space.shape,
                                enable_context = args.enable_context,
                                dim_others = dim_others,
                                hiddens_dim_conext = args.hiddens_conext,
                                input_dim_context = input_dim_context,
                                output_conext = args.output_dim_conext,
                                only_concat_context = args.only_concat_context,
                                history_length = args.history_length,
                                obsr_dim = env.observation_space.shape[0],
                                device = device
                                ).to(device)

        critic_target_net = net.Critic(action_space = env.action_space,
                                        hidden_sizes =args.hidden_sizes,
                                        input_dim = env.observation_space.shape,
                                        enable_context = args.enable_context,
                                        dim_others = dim_others,
                                        hiddens_dim_conext = args.hiddens_conext,
                                        input_dim_context = input_dim_context,
                                        output_conext = args.output_dim_conext,
                                        only_concat_context = args.only_concat_context,
                                        history_length = args.history_length,
                                        obsr_dim = env.observation_space.shape[0],
                                        device = device
                                       ).to(device)

    else:
        raise ValueError("%s model is not supported for %s env" % (args.env_name, env.observation_space.shape))

    ######
    # algorithm setup
    ######

    # init replay buffer
    replay_buffer = Buffer(max_size = args.replay_size)
    
    if str.lower(args.alg_name) == 'mql':

        # tdm3 uses specific runner
        from misc.runner_multi_snapshot import Runner
        from algs.MQL.multi_tasks_snapshot import MultiTasksSnapshot
        import algs.MQL.mql as alg

        alg = alg.MQL(actor = actor_net,
                        actor_target = actor_target_net,
                        critic = critic_net,
                        critic_target = critic_target_net,
                        lr = args.lr,
                        gamma=args.gamma,
                        ptau = args.ptau,
                        policy_noise = args.policy_noise,
                        noise_clip = args.noise_clip,
                        policy_freq = args.policy_freq,
                        batch_size = args.batch_size,
                        max_action = max_action,
                        beta_clip = args.beta_clip,
                        prox_coef = args.prox_coef,
                        type_of_training = args.type_of_training,
                        lam_csc = args.lam_csc,
                        use_ess_clipping = args.use_ess_clipping,
                        enable_beta_obs_cxt = args.enable_beta_obs_cxt,
                        use_normalized_beta = args.use_normalized_beta,
                        reset_optims = args.reset_optims,
                        device = device,
                    )
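        # MQL-specific arguments, roughly: beta_clip caps the propensity (importance)
        # weights used during adaptation, lam_csc sets the regularization of the
        # logistic classifier that estimates those weights, prox_coef weights the
        # proximal term keeping adapted parameters near the meta-trained ones, and
        # use_ess_clipping adjusts the clipping via the effective sample size.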
        ##### rollout/batch generator
        tasks_buffer = MultiTasksSnapshot(max_size = args.snapshot_size)
        rollouts = Runner(env = env,
                          model = alg,
                          replay_buffer = replay_buffer,
                          tasks_buffer = tasks_buffer,
                          burn_in = args.burn_in,
                          expl_noise = args.expl_noise,
                          total_timesteps = args.total_timesteps,
                          max_path_length = args.max_path_length,
                          history_length = args.history_length,
                          device = device)
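        # The Runner interacts with one task at a time and stores the transitions in
        # the shared replay_buffer as well as in the per-task MultiTasksSnapshot
        # buffer, which the adaptation step samples from later.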

    else:
        raise ValueError("%s alg is not supported" % args.alg_name)


    ##### sample train and eval tasks
    train_tasks, eval_tasks = sample_env_tasks(env, args)

    tasks_buffer.init(train_tasks)
    alg.set_tasks_list(train_tasks)

    print('-----------------------------')
    print("Name of env:", args.env_name)
    print("Observation_space:", env.observation_space )
    print("Action space:", env.action_space )
    print("Tasks:", args.n_tasks )
    print("Train tasks:", args.n_train_tasks  )
    print("Eval tasks:", args.n_eval_tasks)
    print("######### Using Hist len %d #########" % (args.history_length))

    if args.enable_promp_envs == True:
        print("********* Using ProMp Envs *********")
    else:
        print("@@@@@@@@@ Using PEARL Envs @@@@@@@@@")
    print('----------------------------')

    ##############################
    # Train and eval
    #############################
    # define some req vars
    timesteps_since_eval = 0
    episode_num = 0
    update_iter = 0
    sampling_loop = 0

    # episode_stats for raw rewards
    epinfobuf = deque(maxlen=args.n_train_tasks)
    epinfobuf_v2 = deque(maxlen=args.n_train_tasks)

    # just to keep params
    take_snapshot(args, ck_fname_part, alg, 0)

    # Evaluate untrained policy
    eval_results = [evaluate_policy(env, alg, episode_num, update_iter, etasks=eval_tasks, eparams=args)] 
    if args.enable_train_eval:
        train_subset = np.random.choice(train_tasks, len(eval_tasks))
        train_subset_tasks_eval = evaluate_policy(env, alg, episode_num, update_iter,
                                                  etasks=train_subset,
                                                  eparams=args,
                                                  msg ='Train-Eval')
    else:
        train_subset_tasks_eval = 0

    wrt_csv_eval = CSVWriter(fname_csv_eval, {'nupdates':update_iter,
                                              'total_timesteps':update_iter,
                                              'eval_eprewmean':eval_results[0],
                                              'train_eprewmean':train_subset_tasks_eval,
                                              'episode_num':episode_num,
                                              'sampling_loop':sampling_loop
                                              })
    wrt_csv_eval.write({'nupdates':update_iter,
                      'total_timesteps':update_iter,
                      'eval_eprewmean':eval_results[0],
                      'train_eprewmean':train_subset_tasks_eval,
                      'episode_num':episode_num,
                      'sampling_loop':sampling_loop
                      })
    ## keep track of adapt stats
    if args.enable_adaptation == True:
        args.adapt_csv_hearder =  dict.fromkeys(['eps_num', 'iter','critic_loss', 'actor_loss',
                                                 'csc_samples_neg','csc_samples_pos','train_acc',
                                                 'snap_iter','beta_score','main_critic_loss',
                                                 'main_actor_loss', 'main_beta_score', 'main_prox_critic',
                                                 'main_prox_actor','main_avg_prox_coef',
                                                 'tidx', 'avg_rewards', 'one_raw_reward'])
        adapt_csv_stats = CSVWriter(fname_adapt, args.adapt_csv_hearder)

    # Start total timer
    tstart = time.time()

    ####
    # First fill up the replay buffer with all tasks
    ####
    max_cold_start = np.maximum(args.num_initial_steps * args.n_train_tasks, args.burn_in)
    print('Start burn-in: collecting at least %d steps' % max_cold_start)
    keep_sampling = True
    avg_length = 0
    while (keep_sampling == True):

        for idx in range(args.n_train_tasks):
            tidx = train_tasks[idx]
            if args.enable_promp_envs == True:
                env.set_task(tidx) # tidx for promp is task value

            else:
                # for pearl env, tidx == idx
                env.reset_task(tidx) # tidx here is an id

            data = rollouts.run(update_iter, keep_burning = True, task_id=tidx,
                                early_leave = args.max_path_length/4) # cut burn-in rollouts short so every task contributes data quickly
            timesteps_since_eval += data['episode_timesteps']
            update_iter += data['episode_timesteps']
            epinfobuf.extend(data['epinfos'])
            epinfobuf_v2.extend(data['epinfos'])
            episode_num += 1
            avg_length += data['episode_timesteps']

            if update_iter >= max_cold_start:
                keep_sampling = False
                break

    print('There are %d samples in buffer now' % replay_buffer.size_rb())
    print('Average length %.2f for %d episode_nums for %d max_cold_start steps' % (avg_length/episode_num, episode_num, max_cold_start))
    print('Episode_nums/tasks %.2f and avg_len/tasks %.2f ' % (episode_num/args.n_train_tasks, avg_length/args.n_train_tasks))
    avg_epi_length = int(avg_length/episode_num)
    # already seen all tasks once
    sampling_loop = 1
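    # Cold start finished: the replay buffer now holds at least max_cold_start
    # environment steps gathered across the training tasks.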

    ####
    # Train and eval main loop
    ####
    train_iter = 0 
    lr_updated = False
    while update_iter < args.total_timesteps:

        if args.enable_promp_envs:
            train_tasks = env.sample_tasks(args.n_train_tasks)
            train_indices = train_tasks.copy()

        else:
            #shuffle the ind
            train_indices = np.random.choice(train_tasks, len(train_tasks))

        for tidx in train_indices:


            ######
            # update learning rate
            ######
            if args.lr_milestone > -1 and lr_updated == False and update_iter > args.lr_milestone:
                update_lr(args, update_iter, alg)
                lr_updated = True

            #######
            # run training to calculate loss, run backward, and update params
            #######
            stats_csv = None

            #adjust training steps
            adjusted_no_steps = adjust_number_train_iters(buffer_size = replay_buffer.size_rb(),
                                     num_train_steps = args.num_train_steps,
                                     bsize = args.batch_size,
                                     min_buffer_size = args.min_buffer_size,
                                     episode_timesteps = avg_epi_length,
                                     use_epi_len_steps = args.use_epi_len_steps)
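            # adjusted_no_steps scales the number of gradient steps with the current
            # buffer size and, if use_epi_len_steps is set, with the average episode
            # length, so early iterations with little data do fewer updates.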

            alg_stats, stats_csv = alg.train(replay_buffer = replay_buffer,
                                      iterations = adjusted_no_steps,
                                      tasks_buffer = tasks_buffer,
                                      train_iter = train_iter,
                                      task_id = tidx
                                      )
            train_iter += 1
            #######
            # logging
            #######
            nseconds = time.time() - tstart
            # Calculate the fps (frame per second)
            fps = int(( update_iter) / nseconds)
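            # Note: % and / have equal precedence and group left to right, so the
            # second condition below reads (episode_num % len(train_tasks)) / 2 == 0,
            # i.e. it fires whenever episode_num is a multiple of len(train_tasks).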

            if ((episode_num % args.log_interval == 0 or episode_num % len(train_tasks)/2 == 0) or episode_num == 1 ):
                logger.record_tabular("nupdates", update_iter)
                logger.record_tabular("fps", fps)
                logger.record_tabular("total_timesteps", update_iter)
                logger.record_tabular("critic_loss", float(alg_stats['critic_loss']))
                logger.record_tabular("actor_loss", float(alg_stats['actor_loss']))
                logger.record_tabular("episode_reward", float(data['episode_reward']))
                logger.record_tabular('eprewmean', float(safemean([epinfo['r'] for epinfo in epinfobuf])))
                logger.record_tabular('eplenmean', float(safemean([epinfo['l'] for epinfo in epinfobuf])))
                logger.record_tabular("episode_num", episode_num)
                logger.record_tabular("sampling_loop", sampling_loop)
                logger.record_tabular("buffer_size", replay_buffer.size_rb())
                logger.record_tabular("adjusted_no_steps", adjusted_no_steps)

                if 'actor_mmd_loss' in alg_stats:
                    logger.record_tabular("critic_mmd_loss", float(alg_stats['critic_mmd_loss']))
                    logger.record_tabular("actor_mmd_loss", float(alg_stats['actor_mmd_loss']))

                if 'beta_score' in alg_stats:
                     logger.record_tabular("beta_score", float(alg_stats['beta_score']))

                logger.dump_tabular()
                print(("Total T: %d Episode Num: %d Episode Len: %d Reward: %f") %
                      (update_iter, episode_num, data['episode_timesteps'], data['episode_reward']))

                #print out some info about CSC
                if stats_csv:
                    print(("CSC info:  critic_loss: %.4f actor_loss: %.4f No beta_score: %.4f ") %
                          (stats_csv['critic_loss'], stats_csv['actor_loss'], stats_csv['beta_score']))
                    if 'csc_info' in stats_csv:
                        print(("Number of examples used for CSC, prediction accuracy on train, and snap Iter: single: %d multiple tasks: %d  acc: %.4f snap_iter: %d ") %
                            (stats_csv['csc_info'][0], stats_csv['csc_info'][1], stats_csv['csc_info'][2], stats_csv['snap_iter']))
                        print(("Prox info: prox_critic %.4f prox_actor: %.4f")%(alg_stats['prox_critic'], alg_stats['prox_actor']))

                    if 'avg_prox_coef' in alg_stats and 'csc_info' in stats_csv:
                        print(("\ravg_prox_coef: %.4f" %(alg_stats['avg_prox_coef'])))

            #######
            # run eval
            #######
            if timesteps_since_eval >= args.eval_freq:
                timesteps_since_eval %= args.eval_freq
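                # With enable_adaptation, evaluate_policy also receives the meta-learner
                # and the training buffers so each eval task can be adapted (MQL's
                # test-time update) before its reward is measured.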

                if args.enable_adaptation == True:
                    eval_temp = evaluate_policy(env, alg, episode_num, update_iter,
                                                etasks=eval_tasks, eparams=args,
                                                meta_learner = alg,
                                                train_tasks_buffer = tasks_buffer,
                                                train_replay_buffer = replay_buffer)

                else:
                    eval_temp = evaluate_policy(env, alg, episode_num, update_iter, etasks=eval_tasks, eparams=args)

                eval_results.append(eval_temp)

                # Eval subset of train tasks
                if args.enable_train_eval:

                    if args.enable_promp_envs == False:
                        train_subset = np.random.choice(train_tasks, len(eval_tasks))

                    else:
                        train_subset = None

                    train_subset_tasks_eval = evaluate_policy(env, alg, episode_num, update_iter,
                                                              etasks=train_subset,
                                                              eparams=args,
                                                              msg ='Train-Eval')
                else:
                    train_subset_tasks_eval = 0

                # dump results
                wrt_csv_eval.write({'nupdates':update_iter,
                                   'total_timesteps':update_iter,
                                   'eval_eprewmean':eval_temp,
                                   'train_eprewmean':train_subset_tasks_eval,
                                   'episode_num':episode_num,
                                   'sampling_loop':sampling_loop})

            #######
            # save for every interval-th episode or for the last epoch
            #######
            if (episode_num % args.save_freq == 0 or episode_num == args.total_timesteps - 1):
                take_snapshot(args, ck_fname_part, alg, update_iter)

            #######
            # Interact and collect data until reset
            #######
            # should reset the queue, as new trail starts
            epinfobuf = deque(maxlen=args.n_train_tasks)
            avg_epi_length = 0
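            # Collect fresh data: keep the current task for the first sample, then pick
            # num_tasks_sample - 1 more training tasks at random and run one rollout in
            # each, pushing the new transitions into the replay buffer.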

            for sl in range(args.num_tasks_sample):

                if sl > 0:
                    idx = np.random.randint(len(train_tasks))
                    tidx = train_tasks[idx]

                if args.enable_promp_envs == True:
                    env.set_task(tidx) # tidx for promp is task value

                else:
                    env.reset_task(tidx) # tidx here is an id

                data = rollouts.run(update_iter, task_id = tidx)
                timesteps_since_eval += data['episode_timesteps']
                update_iter += data['episode_timesteps']
                epinfobuf.extend(data['epinfos'])
                epinfobuf_v2.extend(data['epinfos'])
                episode_num += 1
                avg_epi_length += data['episode_timesteps']

            avg_epi_length = int(avg_epi_length/args.num_tasks_sample)

        # just to keep track of how many times all training tasks have been seen
        sampling_loop += 1

    ###############
    # Eval for the final time
    ###############
    eval_temp = evaluate_policy(env, alg, episode_num, update_iter, etasks=eval_tasks, eparams=args)
    # Eval subset of train tasks
    if args.enable_promp_envs == False:
        train_subset = np.random.choice(train_tasks, len(eval_tasks))

    else:
        train_subset = None

    train_subset_tasks_eval = evaluate_policy(env, alg, episode_num, update_iter,
                                              etasks=train_subset,
                                              eparams=args,
                                              msg ='Train-Eval')

    eval_results.append(eval_temp)
    wrt_csv_eval.write({'nupdates':update_iter,
                       'total_timesteps':update_iter,
                       'eval_eprewmean':eval_temp,
                       'train_eprewmean':train_subset_tasks_eval,
                       'episode_num':episode_num,
                       'sampling_loop':sampling_loop})
    wrt_csv_eval.close()
    print('Done')
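
After training finishes, the evaluation history sits in the CSV written by wrt_csv_eval (the path comes from fname_csv_eval, which is defined earlier in the script). The columns are the keys shown above: nupdates, total_timesteps, eval_eprewmean, train_eprewmean, episode_num and sampling_loop. Below is a minimal sketch, not part of the repository, for plotting the learning curve; it assumes pandas and matplotlib are installed and uses 'eval.csv' as a placeholder for the real file name.

import pandas as pd
import matplotlib.pyplot as plt

# 'eval.csv' is a placeholder; use the actual path stored in fname_csv_eval
df = pd.read_csv('eval.csv')

plt.plot(df['total_timesteps'], df['eval_eprewmean'], label='eval tasks')
plt.plot(df['total_timesteps'], df['train_eprewmean'], label='train subset')
plt.xlabel('environment steps')
plt.ylabel('average episode reward')
plt.legend()
plt.show()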
