强化学习的stable_baselines3/common/torch_layers有哪些特征提取器?

强化学习的stable_baselines3/common/torch_layers有哪些特征提取器

BaseFeaturesExtractor()类

最基本的特征提取器,所有其他特征提取器的基类,继承自nn.Module。有两个属性:观察空间和特征维数;有两个方法:features_dim()方法输出特征维数,forward方法用于加工。

FlattenExtractor()类

继承自BaseFeaturesExtractor,有forward方法,所做的操作就是将observation展平,这样的话便于网络输入训练。

NatureCNN()类

继承自BaseFeaturesExtractor()类,主要是为图像准备的,通过CNN和FC提取特征。(中间通过一次不计算梯度的前向传播来记录CNN 输出的维数),一个针对于观测的小小的网络。

MlpExtractor()类

继承自BaseFeaturesExtractor()类,构建一个多层感知机网络,它接受观测作为输入,为策略和价值网络输出一个表示。主要是利用全连接网络来提取特征。

CombinedExtractor()类

继承自BaseFeatureExtractor()类,是针对字典类型的observation_space的特征提取器。主要做法是:为Dict observation space空间的每一个子空间分别采用不同的特征提取器进行特征提取,再将最后得到的向量进行拼接,可能还会再输入到额外的MLP特征提取器中进行特征提取。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值