强化学习 gridworld P77 模拟

import numpy as numpy
from tabulate import tabulate

class State:
    """One cell of the gridworld, identified by an integer id.

    Cells are numbered row by row with four cells per row.  Moves that land
    on position 0 or position 15 are reported as id 0, which the training
    loop in this file uses as the single terminal state.
    """

    def __init__(self, _id):
        """Create cell ``_id`` with an initial value of 0.

        The successor ids for the four moves are computed once and cached in
        ``nextS`` in the fixed order left, right, up, down.
        """
        self.value = 0
        self.id = _id
        row_start = (self.id // 4) * 4
        # Horizontal limits of this cell's row, clipped to the ids 1..14.
        self.left_bound = max(1, row_start)
        self.right_bound = min(14, row_start + 3)
        self.nextS = [self.move(direction) for direction in ('L', 'R', 'U', 'D')]

    def move(self, u):
        """Return the id reached by taking move ``u`` from this cell.

        ``u`` is one of ``'L'``, ``'R'``, ``'U'`` or ``'D'``.  A move that
        stays on the board returns the neighbouring id, a move onto the exit
        position returns 0, and any other move leaves the cell unchanged.
        An unrecognised ``u`` yields ``None``.
        """
        here = self.id
        if u == 'L':
            dest = here - 1
            on_board = dest >= self.left_bound
            exit_position = 0
        elif u == 'R':
            dest = here + 1
            on_board = dest <= self.right_bound
            exit_position = 15
        elif u == 'U':
            dest = here - 4
            on_board = dest >= 1
            exit_position = 0
        elif u == 'D':
            dest = here + 4
            on_board = dest <= 14
            exit_position = 15
        else:
            return None

        if on_board:
            return dest
        if dest == exit_position:
            return 0
        # Blocked by the edge of the grid: stay where we are.
        return here

    def update(self, S):
        """Recompute this cell's value in place from its four successors.

        ``S`` maps ids to ``State`` objects.  The new value is -1 plus a
        quarter of the summed successor values, i.e. each move has weight
        0.25 and costs 1.
        """
        successor_total = 0
        for slot in (0, 1, 2, 3):
            successor_total += S[self.nextS[slot]].value
        self.value = -1 + 0.25 * successor_total

def train(k=10):
    """Run ``k`` in-place value sweeps over cells 1..14 and print the result.

    Each sweep visits the cells either in ascending or in descending id
    order, chosen by a fresh random draw.  Afterwards the values are handed
    to ``draw`` as a 16-entry list whose first and last entries are "0".
    """
    S = {}
    for cell in range(0, 15):
        S[cell] = State(cell)

    for loop in range(k):
        # Progress report every 1000 sweeps (not for the very first one).
        if loop >= 1000 and loop % 1000 == 0:
            print("Training %s's loop.......Remaining: %s loops" % (loop, k - loop))
        if numpy.random.random() > 0.5:
            sweep_order = range(1, 15)
        else:
            sweep_order = range(14, 0, -1)
        for cell in sweep_order:
            S[cell].update(S)

    V = ["0" if t in (0, 15) else S[t].value for t in range(0, 16)]
    draw(V)


def draw(valueArray):
    """Print the 16 grid values as a 4x4 table, then list states 1..14.

    Table entries are converted with ``float`` and shown with one decimal;
    the list below the table prints the stored values unrounded, two states
    per line.
    """
    border = "----------------------"
    for row in range(4):
        print(border)
        first = row * 4
        cells = tuple(float(valueArray[first + col]) for col in range(4))
        print("| %.1f | %.1f | %.1f | %.1f " % cells)
    print(border)
    print("Accurate State Values List:")
    for odd in range(1, 14, 2):
        even = odd + 1
        print("State "+str(odd)+": "+str(valueArray[odd])+ "          State "+str(even)+": "+str(valueArray[even]))

if __name__ == '__main__':
    # Ask for the number of sweeps on stdin, then train and print the table.
    train(int(input("Specify the desired training loop count(0-10000):")))

增加一个节点 exercise 4.3


import numpy as numpy
from tabulate import tabulate

class State:
    """One cell of the gridworld, identified by an integer id.

    Cells are numbered row by row with four cells per row, plus one extra
    cell with id 17 that is reached by moving down from cell 13.  Moves that
    land on position 0 or position 15 are reported as id 0, which the
    training loop in this file uses as the single terminal state.
    """

    def __init__(self, _id):
        """Create cell ``_id`` with an initial value of 0.

        The successor ids for the four moves are computed once and cached in
        ``nextS`` in the fixed order left, right, up, down.
        """
        self.value = 0
        self.id = _id
        row_start = (self.id // 4) * 4
        # Horizontal limits of this cell's row, clipped to the ids 1..14.
        self.left_bound = max(1, row_start)
        self.right_bound = min(14, row_start + 3)
        self.nextS = [self.move(direction) for direction in ('L', 'R', 'U', 'D')]

    def move(self, u):
        """Return the id reached by taking move ``u`` from this cell.

        ``u`` is one of ``'L'``, ``'R'``, ``'U'`` or ``'D'``.  Cell 17 has
        its own fixed successors (12, 14, 13 and itself), and moving down
        from cell 13 leads to 17.  Otherwise a move that stays on the board
        returns the neighbouring id, a move onto the exit position returns
        0, and any other move leaves the cell unchanged.  An unrecognised
        ``u`` yields ``None``.
        """
        here = self.id

        # The added cell is wired by hand rather than by grid arithmetic.
        if here == 17:
            for direction, dest in (('L', 12), ('R', 14), ('U', 13), ('D', 17)):
                if u == direction:
                    return dest
        if here == 13 and u == 'D':
            return 17

        if u == 'L':
            dest = here - 1
            on_board = dest >= self.left_bound
            exit_position = 0
        elif u == 'R':
            dest = here + 1
            on_board = dest <= self.right_bound
            exit_position = 15
        elif u == 'U':
            dest = here - 4
            on_board = dest >= 1
            exit_position = 0
        elif u == 'D':
            dest = here + 4
            on_board = dest <= 14
            exit_position = 15
        else:
            return None

        if on_board:
            return dest
        if dest == exit_position:
            return 0
        # Blocked by the edge of the grid: stay where we are.
        return here

    def update(self, S):
        """Recompute this cell's value in place from its four successors.

        ``S`` maps ids to ``State`` objects.  The new value is -1 plus a
        quarter of the summed successor values, i.e. each move has weight
        0.25 and costs 1.
        """
        successor_total = 0
        for slot in (0, 1, 2, 3):
            successor_total += S[self.nextS[slot]].value
        self.value = -1 + 0.25 * successor_total

def train(k=10):
    """Run ``k`` in-place value sweeps over cells 1..14 and 17, then print.

    Each sweep visits the cells either as 1..14 followed by 17, or as 17
    followed by 14..1, chosen by a fresh random draw.  Afterwards the values
    are handed to ``draw`` as a 17-entry list: the 16 grid positions (first
    and last are "0") followed by the value of cell 17.
    """
    S = {}
    for cell in range(0, 15):
        S[cell] = State(cell)
    S[17] = State(17)

    ascending = list(range(1, 15)) + [17]
    descending = [17] + list(range(14, 0, -1))

    for loop in range(k):
        # Progress report every 1000 sweeps (not for the very first one).
        if loop >= 1000 and loop % 1000 == 0:
            print("Training %s's loop.......Remaining: %s loops" % (loop, k - loop))
        if numpy.random.random() > 0.5:
            sweep_order = ascending
        else:
            sweep_order = descending
        for cell in sweep_order:
            S[cell].update(S)

    V = ["0" if t in (0, 15) else S[t].value for t in range(0, 16)]
    V.append(S[17].value)
    draw(V)


def draw(valueArray):
    """Print the grid values as a 4x4 table, the extra entry, and a list.

    The first 16 entries are shown as a table with one decimal each, entry
    16 is printed on its own line as a float, and the list below prints the
    stored values of states 1..14 unrounded, two states per line.
    """
    border = "----------------------"
    for row in range(4):
        print(border)
        first = row * 4
        cells = tuple(float(valueArray[first + col]) for col in range(4))
        print("| %.1f | %.1f | %.1f | %.1f " % cells)
    # Value of the added cell, stored after the 16 grid positions.
    print(float(valueArray[16]))
    print(border)
    print("Accurate State Values List:")
    for odd in range(1, 14, 2):
        even = odd + 1
        print("State "+str(odd)+": "+str(valueArray[odd])+ "          State "+str(even)+": "+str(valueArray[even]))

if __name__ == '__main__':
    # Ask for the number of sweeps on stdin, then train and print the table.
    train(int(input("Specify the desired training loop count(0-10000):")))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值