gym_render_by_pygame

21 篇文章 4 订阅
5 篇文章 0 订阅

render gym by pygame and save to local

通过pygame渲染gym,并且将图片保存为视频到本地。

代码有问题,无法取消创建gym自带的渲染窗口,暂时宣布放弃这套方案。
等我找到别人成熟的方案再说。。。

前言:

gym自带的渲染虽然可以调整视角,大小什么的,但是不方便保存为视频,并且有时候我们并不需要这么完整的视角。

对于图像输入的强化学习来说,如果要分析为什么智能体学不到东西,就必须要看看智能体到底在怎么瞎折腾,因此,我们需要一个即时渲染,且能保存视频的脚本。

综合来看,pygame的即时渲染方案是最佳的,pygame库基本没什么乱七八糟的依赖,刷新一次的时间消耗非常短,比cv2和matplotlib好多了,且可以方便的加一些配套信息,比如epoch,step,reward信息,我看到好几个代码库替代渲染方案也是基于pygame。

下面贴一下当初设计的功能:

功能描述:
1.传入特定时刻的env,渲染出RGB图,可以选择,是否将其保存为一个视频流
2.需要用pygame可视化当前图
3.不需要pygame乱七八糟的功能
4.视频保存路径和当前实验log路径一致
5.视频名称需要标注好epoch

代码:

"""
功能描述:
1.传入特定时刻的env,渲染出RGB图,可以选择,是否将其保存为一个小视频
2.需要用pygame可视化当前图
3.不需要pygame乱七八糟的功能
4.视频保存路径和当前实验log路径一致
5.视频名称需要标注好epoch

"""


import pygame
import os
from pygame.locals import *
from sys import exit
import numpy as np
import cv2
import imutils


class GymRenderImageSaveVideoClass:
    def __init__(self,
                 exp_path='',
                 fps=20,
                 render_in_gym=False,
                 save_flag=True,
                 args=None,
                 ):
        self.render_in_gym = render_in_gym
        self.args = args        

        # pygame init
        self.start_epoch = 0
        self.image_width = args.image_width
        self.image_height = args.image_height
        self.font_size = args.image_width // 15

        pygame.init()
        self.break_flag = False
        self.image_count = 0
        self.position_font = pygame.font.SysFont("幼圆", self.font_size)
        # 返回一个窗口Surface对象
        self.screen = pygame.display.set_mode((self.image_width,
                                               self.image_height),
                                               0, 32)
        self.exp_name = args.exp_name

        # save video:
        self.save_flag = save_flag
        if self.save_flag:
            self.exp_path = exp_path
            self.max_store = args.max_steps-1
            try:
                os.mkdir(self.exp_path)
            except Exception as e:
                print(e)
            self.fourcc = cv2.VideoWriter_fourcc(*'XVID')  # 保存视频的编码
            self.fps = fps

    def update(self, image, epoch=0, step=0, reward=0):
        # 判断是否退出
        for event in pygame.event.get():
            if event.type == QUIT:
                self.break_flag = True
        # 在窗口标题上显示参数元组
        pygame.display.set_caption(self.exp_name)
        robot_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        # 图片可能需要旋转
        robot_image = imutils.rotate(robot_image, 90, )        
        # 图片垂直翻转
        robot_image = cv2.flip(robot_image, 0)

        robot_image = cv2.resize(robot_image,
                                 (self.image_height,
                                  self.image_width))

        my_surface = pygame.pixelcopy.make_surface(robot_image)
        self.screen.blit(my_surface, (0, 0))
        pygame_str = 'epoch:' + str(epoch) + '-step:' + str(step) + '-rew:' + str(np.round(reward, 3))
        text1 = self.position_font.render(pygame_str, True,
                                          (255, 0, 0), (0, 0, 0))
        self.screen.blit(text1,
                         [10, 40],
                         )

        pygame.display.update()
        return robot_image

    def step(self, epoch='', step='', reward=0.0,
             env=None, specfic_name='', per_epoch=False,
             img=None,
             ):    
        if self.image_count == 0 and self.save_flag:
            video_name = self.args.exp_name + '_st_ep' + str(int(epoch))
            video_path = os.path.join(self.exp_path, specfic_name+'_'+video_name + '.avi')
            self.out = cv2.VideoWriter(video_path, self.fourcc, self.fps,
                                       (self.image_width,
                                        self.image_height))
        if env is None:
            robot_image = np.uint8(np.random.random((self.image_width, self.image_height, 3)))
        else:
            if img is not None:
                robot_image = np.flip(img, 0)
                robot_image = cv2.cvtColor(robot_image, cv2.COLOR_RGB2BGR)
            else:       
                # TODO 如何只获取env.render()的图片,不创建窗口?
                # env.unwrapped.render()
                if self.render_in_gym:
                    env.render('human')
                # robot_image = np.uint8(np.random.random((self.image_width, self.image_height, 3)))
                try:
                    robot_image = env.render("rgb_array",
                                              width=self.image_width,
                                              height=self.image_height,
                                              )
                except:                    
                    robot_image = env.render("rgb_array",
                    )                                             

                # 判断是否退出
                for event in pygame.event.get():
                    if event.type == QUIT:
                        self.break_flag = True
                if self.break_flag:
                    pygame.quit()
                    import sys
                    print("exit")
                    sys.exit()
                    
                robot_image = cv2.cvtColor(robot_image, cv2.COLOR_RGB2BGR)

        self.update(image=robot_image,
                    epoch=epoch,
                    step=step,
                    reward=reward
                    )
        if self.save_flag:
            # save img to local as video stream, and add some info to img.
            post_img = cv2.resize(robot_image,
                                (self.image_width,
                                self.image_height))
            fontScale = 1
            text_thickness = 1
            fontSize = 1
            bg_color = (255, 0, 0)
            fontFace = cv2.FONT_HERSHEY_SIMPLEX
            cv2.putText(post_img, 'epoch:' + str(epoch),
                        (10, 20), fontFace, fontScale,
                        bg_color, text_thickness, fontSize)
            cv2.putText(post_img, 'step:' + str(step),
                        (10, 40), fontFace, fontScale, 
                        bg_color, text_thickness, fontSize
                        )
            cv2.putText(post_img, 'rew:' + str(np.round(reward, 3)),
                        (10, 60), fontFace, fontScale,
                        bg_color, text_thickness, fontSize)
            # cv2.imshow("post_img:", post_img)
            # cv2.waitKey(0)
            # cv2.destroyAllWindows()
            post_img = np.uint8(post_img)

            self.out.write(post_img)
            self.image_count += 1
            if per_epoch:
                if step == self.max_store - 1:
                    self.image_count = 0
                    self.out.release()


def main():
    import time
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_width', type=int, default=256)
    parser.add_argument('--image_height', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--max_steps', type=int, default=20)
    parser.add_argument('--exp_name', type=str, default='gym_render_by_pygame')
    parser.add_argument('--output_dir', type=str, default='Gym_render_by_pygame')

    args = parser.parse_args()
    logger_kwargs = {'exp_name': args.exp_name,
                     'output_dir': args.output_dir}
    dpc = GymRenderImageSaveVideoClass(exp_path=logger_kwargs['output_dir'],
                                       save_flag=True,
                                       render_in_gym=False,
                                       args=args,
                                       )
    import matplotlib.pyplot as plt
    st_list = []
    import gym
    from gym.wrappers import Monitor
    def wrap_env(env):
        env = Monitor(env, './video', force=True)
        return env
    env = gym.make("MountainCar-v0",)
    env = wrap_env(env)

    for i in range(args.epochs):
        env.reset()
        for j in range(args.max_steps):
            st = time.time()
            obs, r, done, info = env.step(env.action_space.sample())
            dpc.step(epoch=i, step=j, reward=r,
                     env=env, specfic_name='train',
                     per_epoch=False)
            
            # if dpc.break_flag:
            #     pygame.quit()
            #     dpc.out.release()
            #     cv2.destroyAllWindows()
            #     exit()
    dpc.out.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()



  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

hehedadaq

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值