- env.py
import sys
import json
import torch
import numpy as np
import argparse
import torchvision.transforms as transforms
import cv2
from DRL.ddpg import decode
from utils.util import *
from PIL import Image
from torchvision import transforms, utils
# Compute device for tensors: prefer CUDA when it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Augmentation pipeline: convert the input to a PIL image, then apply a
# random horizontal flip. There is no ToTensor step, so the output stays a
# PIL image (presumably converted back by the caller; confirm).
aug = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
])

# Canvas side length in pixels, and the total pixel count of the square canvas.
# NOTE: "convas_area" is a misspelling of "canvas_area". The name is kept
# because code elsewhere may reference it.
width = 128
convas_area = width ** 2

# Module-level image stores and their counters. Paint.load_data rebinds
# train_num/test_num through `global`. Presumably the lists are filled there
# as well (verify).
img_train, img_test = [], []
train_num = test_num = 0
class Paint:
def __init__(self, batch_size, max_step):
self.batch_size = batch_size
self.max_step = max_step
self.action_space = (13)
self.observation_space = (self.batch_size, width, width, 7)
self.test = False
def load_data(self):
global train_num, test_num
for i in range(200000):
img