纯新手入门80行完成DQN CartPole
在b站看到的对新手很友好的视频,下面也有代码,我详细地加上了一些自己的见解,没有繁琐的各种类套娃,也没有很多花里胡哨的新手看一个懵一个数据处理。
教学视频.
import torch
import torch.nn as nn
import gym
import numpy as np
import random
class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
#定义一个神经网络,输入为四个state,输出为2个action
self.fc=nn.Sequential(
nn.Linear(4,24),
nn.ReLU(),
nn.Linear(24,24),
nn.ReLU(),
nn.Linear(24,2))
#损失函数
self.mls=nn.MSELoss()
#优化
self.opt=torch.optim.Adam(self.parameters(),lr=0.01)
def forward(self,x