MAML-RL Pytorch 代码解读 (11) – maml_rl/utils/optimization.py、reinforcement_learning.py和torch_utils.py
文章目录
基本介绍
在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。
因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。
MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。
源码链接
https://github.com/dragen1860/MAML-Pytorch-RL
文件路径
./maml_rl/utils/
optimization.py
文件
conjugategradient()
函数
import torch
def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
#### 这个函数是共额梯度法。用共额梯度法求解一个方程。
#### p:拷贝b参数,并断开求导链;r:拷贝b参数,并断开求导链;x:拷贝b参数,用全0赋值,并断开求导链;rdotr:变量r的平方。
p = b.clone().detach()
r = b.clone().detach()
x = torch.zeros_like(b).float()
rdotr = torch.dot(r, r)
for i in range(cg_iters):
#### z = A x p ,A 是对称正定矩阵
z = f_Ax(p).detach()
#### v相当于求出了直线搜索的步长,标记为 alpha_k
v = rdotr / torch.dot(p, z)
#### x自变量的更新,步长乘以当前的方向
x += v * p
#### r相当于中间变量r_k,标记为 alpha_k x z = alpha_k x A x p
r -= v * z
#### newrdotr表示新的轭模长的平方
newrdotr = torch.dot(r, r)
#### mu用于更新下一个共轭梯度的轭
mu = newrdotr / rdotr
p = r + mu * p
rdotr = newrdotr
if rdotr.item() < residual_tol:
break
return x.detach()
reinforcement_learning.py
文件
import numpy as np
#### 经典强化学习状态价值迭代的代码
def value_iteration(transitions, rewards, gamma=0.95, theta=1e-5):
rewards = np.expand_dims(rewards, axis=2)
values = np.zeros(transitions.shape[0], dtype=np.float32)
delta = np.inf
while delta >= theta:
#### transitions的维度是[Ns x Na x Ns],多维数组乘以一维数组,一维数组的维度等于多维数组最后一维的维度时,可以相乘,乘出来的结果还是[Ns x Na x Ns]维度,相当于一维的元素分别乘以多维数组中该维度对应切片数组的乘积。这个代码得到了[状态转移x(奖励+折扣x价值)]的矩阵,相当于得到了这个状态转移过程的价值,但是我们求的是Q值,所以就对最高维度求和,得到q_values矩阵。
q_values = np.sum(transitions * (rewards + gamma * values), axis=2)
#### 选取每个状态中对动作价值最大的作为新的状态价值
new_values = np.max(q_values, axis=1)
#### 计算迭代误差,并更新每个状态的价值,从而完成一次值迭代
delta = np.max(np.abs(new_values - values))
values = new_values
return values
#### 有限视野(步长)下强化学习价值迭代代码,实际上是在value_iteration()函数上迭代有限次数。
def value_iteration_finite_horizon(transitions, rewards, horizon=10, gamma=0.95):
rewards = np.expand_dims(rewards, axis=2)
values = np.zeros(transitions.shape[0], dtype=np.float32)
for k in range(horizon):
#### transitions的维度是[Ns x Na x Ns],多维数组乘以一维数组,一维数组的维度等于多维数组最后一维的维度时,可以相乘,乘出来的结果还是[Ns x Na x Ns]维度,相当于一维的元素分别乘以多维数组中该维度对应切片数组的乘积。这个代码得到了[状态转移x(奖励+折扣x价值)]的矩阵,相当于得到了这个状态转移过程的价值,但是我们求的是Q值,所以就对最高维度求和,得到q_values矩阵。
q_values = np.sum(transitions * (rewards + gamma * values), axis=2)
#### 选取每个状态中对动作价值最大的作为新的状态价值
values = np.max(q_values, axis=1)
return values
torch_utils.py
文件
import torch
from torch.distributions import Categorical, Normal
#### 这个函数应该是计算加权平均值,并做一些处理
def weighted_mean(tensor, dim=None, weights=None):
#### 没有权值的时候,输出标准的平均值张量
if weights is None:
return torch.mean(tensor)
#### 没有维度的时候,此时有权值,计算归一化权值
if dim is None:
sum_weights = torch.sum(weights)
return torch.sum(tensor * weights) / sum_weights
#### 既有维度又有权值的时候,要按照每个维度加权求平均值
#### 如果dim是整数的时候,将dim进行元组化,便于迭代
if isinstance(dim, int):
dim = (dim,)
#### 分子是张量乘以权重,分母只是权重。
#### keepdim=True保持这些维度不会被挤掉。
#### 对每个维度求和和权重最后做除法。
numerator = tensor * weights
denominator = weights
for dimension in dim:
numerator = torch.sum(numerator, dimension, keepdim=True)
denominator = torch.sum(denominator, dimension, keepdim=True)
return numerator / denominator
#### 这个函数对分布做一些处理。如果是离散的Categorical分布,则根据pi的logits数值生成一个分布;如果是连续的Normal分布,则根据pi的loc定均值,pi的scale定方差。如果不是这两个分布,那么输出异常。
def detach_distribution(pi):
if isinstance(pi, Categorical):
distribution = Categorical(logits=pi.logits.detach())
elif isinstance(pi, Normal):
distribution = Normal(loc=pi.loc.detach(), scale=pi.scale.detach())
else:
raise NotImplementedError('Only `Categorical` and `Normal` '
'policies are valid policies.')
return distribution
def weighted_normalize(tensor, dim=None, weights=None, epsilon=1e-8):
#### 没有权重,生成一个和张量大小一样的全1权重。
if weights is None:
weights = torch.ones_like(tensor)
#### 根据上面的函数weighted_mean(不同维度的加权均值,求出总体的均值)
mean = weighted_mean(tensor, dim=dim, weights=weights)
#### 中心化 = ( tensor * weights - 均值 )- 根号下的(tensor * weights - 均值)的平方。为了避免0做分母产生异常,用一个很小的数加上去做替代。
centered = tensor * weights - mean
std = torch.sqrt(weighted_mean(centered ** 2, dim=dim, weights=weights))
return centered / (std + epsilon)