OpenPose -tensorflow代码解析(2)—— 多进程数据读取 dataset.py

前言

该openpose-tensorflow的工程是自己实现的,所以有些地方写的会比较简单,但阅读性强、方便使用。

论文翻译 || openpose – Realtime Multi-Person 2D Pose Estimation using Part Affinity Fields
工程实现 || 基于opencv使用openpose完成人体姿态估计

OpenPose -tensorflow代码解析(1)——工程概述&&训练前的准备
OpenPose -tensorflow代码解析(2)—— 数据增强和处理 dataset.py
OpenPose -tensorflow代码解析(3)—— 网络结构的搭建 Net.py
OpenPose -tensorflow代码解析(4)—— 训练脚本 train.py
OpenPose -tensorflow代码解析(5)—— 预测代码解析 predict.py

1 代码概述

将openpose的数据读取定义称一个类 class dataset

里面主要实现了3个功能:

  • 图片的数据增强:随机翻转、随机裁剪、旋转平移、色彩的增强
  • 将关键点转化成 openpose的网络的输出形式:关键点的热量图、亲和域的热量图
  • 将增强好的image、转化好的label 使用多进程放进队列中。
    这是让cpu提前处理数据放入队列中,当训练时,GPU完成一次反向传播后,能够直接获取到新的数据进行新一次的迭代,避免GPU处于空闲,减少总共训练时间

2 初始化

初始化部分,就是读取配置文件(配置脚本 opt.py附在最后),已经设定一些遍变量。其中需要说明的是:

  • self.point_num = cfg.OP.cpm_num 设置的是关键点的数量,不包含背景
  • self.paf_num = cfg.OP.paf_num 设置的是关键点之间的连接数量*2
  • self.shuffle_ref 设置的是关键点的具体连接方式
  • self.LR_morrir 如果关键点是对称的,对于关键点[0,1,2,3,4,5],假设它们对应的镜像点的索引为 [5,4,3,2,1,0]。如果关键点不是对称的,这里为空
  • self.Q_name = Queue(self.num_samples) 读取验证数据的名字的队列
  • self.Q_data = Queue(1000) 读取训练数据的队列

    这里的队列长度的设置是有讲究的。因为 Dataset 会被实例化成 【trainset、testset】,所以一定要保证当单个队列 Q_data 满数据时,不能占满整个内存,大约50%即可,否则可能发生另外一个队列无法放置数据
 class Dataset:
   def __init__(self, dataset_type):

       self.fortest = False

       self.annot_path  = cfg.TRAIN.annot_path if dataset_type == 'train' else cfg.TEST.annot_path
       if not os.path.exists(self.annot_path):
           print(self.annot_path+" 文件不存在!")
           exit()
       self.input_sizes = cfg.TRAIN.input_size if dataset_type == 'train' else cfg.TEST.input_size
       self.batch_size  = cfg.TRAIN.batch_size if dataset_type == 'train' else cfg.TEST.batch_size
       self.data_aug    = cfg.TRAIN.data_aug   if dataset_type == 'train' else cfg.TEST.data_aug

       self.WH_ratio = cfg.OP.WH_ratio
       self.stride = cfg.OP.strides
       self.point_num = cfg.OP.cpm_num
       self.paf_num = cfg.OP.paf_num
       self.shuffle_ref = [[0, 1], [1, 2], [2, 3], [3, 4],
                           [0, 5], [5, 6], [6, 7], [7, 8],
                           [0, 9], [9, 10], [10, 11], [11, 12],
                           [0, 13], [13, 14], [14, 15], [15, 16],
                           [0, 17], [17, 18], [18, 19], [19, 20],
                           [0, 21]]
      self.LR_morrir = []
      self.sigma = 0.8

       self.annotations = self.load_annotations()
       self.num_samples = len(self.annotations)    # 样本的数量
       self.num_step_one_epoch = int(np.ceil(self.num_samples / self.batch_size))  # 一轮读取批数
       self.batch_count = 0

       self.num_trianepoch = cfg.TRAIN.first_stage_epoch + cfg.TRAIN.second_stage_epoch

       self.Q_name = Queue(self.num_samples)  # 读取验证数据的名字的队列
       self.Q_data = Queue(1000)  # 读取训练数据的队列


然后

  • 定义 len(dataset) = num_step_one_epoch,也就是一轮训练的步数
  • 加载 train.txt 或者 test.txt 文件,获取到数据集的 image-label 的路径
  • 根据batch 的大小,设定好一批数据的numpy 数组。
    其中值得注意的是,self.input_size 的计算。当我们想要设置多尺度图片进行训练,只需要在opt.py 文件中,设置多个尺寸,这里就会每个batch的数据,随机获取一个尺寸的大小进行处理数据。所以 Prepare() 函数,是要在每个batch都进行调用,就不能放在 __init__() 中。
   def __len__(self):
       return self.num_step_one_epoch

   def load_annotations(self,):
       with open(self.annot_path, 'r') as f:
           txt = f.readlines()
           annotations = [line.split() for line in txt ]
       return annotations

   def Prepare(self):
       size = random.choice(self.input_sizes)

       self.input_size = [size, int(size//self.WH_ratio)]
       self.output_size = [int(self.input_size[0]//self.stride), int(self.input_size[1]//self.stride)]

       self.batch_image = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=float)
       self.batch_label_heatmap = np.zeros((
           self.batch_size, self.output_size[0], self.output_size[1], self.point_num+1), dtype=float)
       self.batch_label_vectmap = np.zeros((
           self.batch_size, self.output_size[0], self.output_size[1], len(self.shuffle_ref)*2), dtype=float)

2 数据增强

数据增强:色彩增强、随机翻转、随机旋转、随机平移
数据处理:给图片填充和缩放,图片的内容保持原本的长宽比例

   def load_data(self, image_path, label_path):

       if not os.path.exists(image_path):
           print(image_path+" 图片不存在")
           raise KeyError("%s does not exist ... " %image_path)
       
       image = np.array(cv2.imread(image_path))
       joint = np.loadtxt(label_path)

       show_image("image_or", image)  if self.fortest else None

       if self.data_aug:

           image = self.change_img(image)
           show_image("change_img", image) if self.fortest else None

           image,joint = self.random_horizontal_flip(image, joint)
           show_image("random_horizontal_flip", image) if self.fortest else None

           image,joint = self.random_horizontal_rotation(image, joint)
           show_image("random_horizontal_rotation", image) if self.fortest else None

           image, joint = self.random_translate(image, joint)
           show_image("random_horizontal_flip", image) if self.fortest else None

       image, joint = image_preporcess(image, self.input_size, joint)
       show_image("random_horizontal_flip", image) if self.fortest else None

       return image, joint

def show_image(name, image):
   cv2.namedWindow(name, 0)  # 0 窗口可伸缩
   cv2.resizeWindow(name, 500, 500)  # 初始窗口大小
   cv2.imshow(name, image)  # 展示图片
   cv2.waitKey(0)  # 保持展示
   # cv2.destroyAllWindows()  # 注销窗口
   

2.1 色彩增强

色彩的改变,在 Pillow 库中,有很方便的api


   def change_img(self,img):

       p = random.randint(0, 3)
       a1 = random.uniform(0.8, 2)
       a2 = random.uniform(0.8, 1.4)
       a3 = random.uniform(0.8, 1.7)
       a4 = random.uniform(0.8, 2.5)
       img = Image.fromarray(img)

       img = ImageEnhance.Color(img).enhance(a1) if p == 0 else img
       img = ImageEnhance.Brightness(img).enhance(a2) if p == 1 else img
       img = ImageEnhance.Contrast(img).enhance(a3) if p == 2 else img
       img = ImageEnhance.Sharpness(img).enhance(a4) if p == 3 else img
       img = np.array(img)

       return img
       


在这里插入图片描述

2.2 随机水平翻转

如果标注的关节点是镜像的,如人体的关节点,在做水平翻转时,主要关键点的位置和索引,都要进行镜像处理,也就是 joint = joint[self.LR_morrir,:]

   def random_horizontal_flip(self, image, joint):
       if random.random() < 0.5:
           _, w, _ = image.shape
           image = image[:, ::-1, :]
           joint[:, 0] = w - joint[:, 0]
           # joint = joint[self.LR_morrir,:]
       return image, joint
       


在这里插入图片描述

2.3 随机旋转

思路:
根据随机获取的角度值,得到相应的旋转矩阵;
用这个旋转矩阵,以及opencv中的api,对图片进行旋转;
用这个旋转矩阵,对关键点进行相应的旋转。


   def random_horizontal_rotation(self, image, joint):
       if random.random() < 0.7:
           # 设置旋转矩阵
           transform_matrix = affine_rotation_matrix(angle=(-10,10), x=self.input_size[1]//2, y=self.input_size[0]//2)
           # 使用旋转矩阵旋转图片
           image = affine_transform_cv2(image, transform_matrix)
           #  使用旋转矩阵旋转关键点
           joint = affine_transform_keypoints(joint, transform_matrix)
       return image, joint

# 设置旋转矩阵
def affine_rotation_matrix(angle, x, y):

   if isinstance(angle, tuple):
       theta = np.pi / 180 * np.random.uniform(angle[0], angle[1])
   else:
       theta = np.pi / 180 * angle
   rotation_matrix = np.array([[np.cos(theta), np.sin(theta), 0],
                               [-np.sin(theta), np.cos(theta), 0],
                               [0, 0, 1]])
   o_x = (x - 1) / 2.0
   o_y = (y - 1) / 2.0
   offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
   reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
   transform_matrix = np.dot(np.dot(offset_matrix, rotation_matrix), reset_matrix)
   return transform_matrix

# 使用旋转矩阵旋转图片
def affine_transform_cv2(x, transform_matrix, flags=None, border_mode='constant'):

   rows, cols = x.shape[0], x.shape[1]
   if flags is None:
       flags = cv2.INTER_AREA
   if border_mode is 'constant':
       border_mode = cv2.BORDER_CONSTANT
   elif border_mode is 'replicate':
       border_mode = cv2.BORDER_REPLICATE
   else:
       raise Exception("unsupport border_mode, check cv.BORDER_ for more details.")
   return cv2.warpAffine(x, transform_matrix[0:2, :], (cols, rows), flags=flags, borderMode=border_mode)

#  使用旋转矩阵旋转关键点
def affine_transform_keypoints(coords_list, transform_matrix):

   coords = coords_list.transpose([1, 0])
   coords = np.insert(coords, 2, 1, axis=0)

   coords_result = np.matmul(transform_matrix, coords)
   coords_result = coords_result[0:2, :].transpose([1, 0])
   return coords_result


在这里插入图片描述

2.4 随机平移

进行随机平移的操作,一定要保证不能将标签所在区域 平移超出图片的范围。
所以需要先计算关键点的最小凸集,然后用这个参数,来设定平移的范围。


   def random_translate(self, image, joint):

       if random.random() < 0.5:
           h, w, _ = image.shape

           # 求图片中所有点的最小凸集框的左上角和右下角
           max_bbox = np.concatenate([np.min(joint, axis=0), np.max(joint, axis=0)], axis=-1)

           # 获取最小凸集与图片的最上角的距离
           max_l_trans = max_bbox[0]
           max_u_trans = max_bbox[1]
           max_r_trans = w - max_bbox[2]
           max_d_trans = h - max_bbox[3]

           tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1))
           ty = random.uniform(-(max_u_trans - 1), (max_d_trans - 1))

           M = np.array([[1, 0, tx], [0, 1, ty]])
           image = cv2.warpAffine(image, M, (w, h))

           joint = joint + np.array([tx,ty])

       return image, joint


在这里插入图片描述

2.5 数据尺寸处理

我们需要将图片处理成 神经网络输入的尺寸。
原则是,填充短边 使长宽比例与神经网络输入长款比例一样,然后再进行缩放,保证图片没有被拉伸或压缩。具体实现的方式很多种,只要实现没有改变长款比例就行。

def image_preporcess(image, target_size, joint=None):

   # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

   ih, iw    = target_size
   h,  w, _  = image.shape

   scale = min(iw/w, ih/h)
   nw, nh  = int(scale * w), int(scale * h)
   image_resized = cv2.resize(image, (nw, nh))

   image_paded = np.full(shape=[ih, iw, 3], fill_value=128, dtype=np.uint8)
   dw, dh = (iw - nw) // 2, (ih-nh) // 2
   image_paded[dh:nh+dh, dw:nw+dw, :] = image_resized
   # image_paded = image_paded / 255.


   if joint is None:
       return image_paded
   else:
       joint = joint * scale + + np.array([dw,dh])
       return image_paded, joint

3 关键点转化成热量图

3.1 生成关键点的热量图 heatmap

  • 生成热量图的数量为 num_keypoint + 1(背景类)。
  • 对于每一个索引的关键点,都生成一张热量图。关键点的坐标在热量图中相应的位置,会生成一个二元正态分布的数据。该函数实现的方式比较多
  • 背景类的热量图的数值,存在关键点的位置的像素值为0,其他为1。

   def get_heatmap(self, joint_list,sign=True):
       # 该函数当中的joint_list,需要的是关节点的坐标,与图片像素的索引的围度数据相反
       heatmap = np.zeros((self.point_num+1, self.output_size[0], self.output_size[1]), dtype=np.float32)
       for idx in range(self.point_num):
           joints = joint_list[idx]
           if joints[0] < 0 or joints[1] < 0:
               continue
           # print("==")
           self.put_heatmap(heatmap, idx, joints, self.sigma)

       heatmap = heatmap.transpose((1, 2, 0))
       heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)  # background

       return heatmap

   def put_heatmap(self, heatmap, plane_idx, center, sigma):
       center_x, center_y = center
       _, height, width = heatmap.shape[:3]

       th = 4.6052 
       delta = math.sqrt(th * 2)
       x0 = int(max(0, center_x - delta * sigma))
       y0 = int(max(0, center_y - delta * sigma))
       x1 = int(min(width, center_x + delta * sigma))
       y1 = int(min(height, center_y + delta * sigma))
       for y in range(y0, y1):
           for x in range(x0, x1):
               d = (x - center_x) ** 2 + (y - center_y) ** 2
               exp = d / 2.0 / sigma / sigma
               if exp > th:
                   continue
               heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp))
               heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)
       return heatmap

我们知道 一维的正太分布的公式为 f ( x ) = 1 2 π σ e x p − ( x − μ ) 2 2 ∗ σ f(x)=\frac{1}{\sqrt{2 \pi }\sigma}exp^{-\frac{(x-\mu)^2}{2*\sigma}} f(x)=2π σ1exp2σ(xμ)2
下面的图为二元正太分布示意图
在这里插入图片描述

3.2 生成亲和域的热量图 vectmap

  • 有连接关系的关键点对 n 组,会生成亲和域的热量图 n*2 组。
  • 在关键点对的连线上,一定宽度的像素值,都进行赋值。一张热量图中赋值为点对的方向向量的x分量,一张赋值为方向向量的 y 分量。
  • 多对连接点对 如果存在交叉重叠,那么重叠的位置的像素值,为多对连接点对的分量的平均值。

   def get_vectormap(self, joint_list,sign = True):
       # 该函数当中的joint_list,需要的是关节点的坐标,与图片像素的索引的围度数据相反
       vectormap = np.zeros((len(self.shuffle_ref)*2, self.output_size[0], self.output_size[1]), dtype=np.float32)
       countmap = np.zeros((len(self.shuffle_ref), self.output_size[0], self.output_size[1]), dtype=np.int16)
       for plane_idx, (j_idx1, j_idx2) in enumerate(self.shuffle_ref):
               center_from = joint_list[j_idx1]
               center_to = joint_list[j_idx2]
               # print("ceter from: ", center_from)
               # print("ceter to: ", center_to)
               if center_from[0] < -100 or center_from[1] < -100 or center_to[0] < -100 or center_to[1] < -100:
                   continue
               self.put_vectormap(vectormap, countmap, plane_idx, center_from, center_to)

       vectormap = vectormap.transpose((1, 2, 0))
       nonzeros = np.nonzero(countmap)
       for p, y, x in zip(nonzeros[0], nonzeros[1], nonzeros[2]):
           if countmap[p][y][x] <= 0:
               continue
           vectormap[y][x][p * 2 + 0] /= countmap[p][y][x]
           vectormap[y][x][p * 2 + 1] /= countmap[p][y][x]
       return vectormap.astype(np.float16)

   def put_vectormap(self, vectormap, countmap, plane_idx, center_from, center_to, threshold=1):
       _, height, width = vectormap.shape[:3]

       vec_x = center_to[0] - center_from[0]
       vec_y = center_to[1] - center_from[1]
       min_x = max(0, int(min(center_from[0], center_to[0]) - threshold))
       min_y = max(0, int(min(center_from[1], center_to[1]) - threshold))
       max_x = min(width, int(max(center_from[0], center_to[0]) + threshold))
       max_y = min(height, int(max(center_from[1], center_to[1]) + threshold))

       norm = math.sqrt(vec_x ** 2 + vec_y ** 2)
       if norm == 0:
           return
       vec_x /= norm
       vec_y /= norm
       for y in range(min_y, max_y):
           for x in range(min_x, max_x):
               bec_x = x - center_from[0]
               bec_y = y - center_from[1]
               dist = abs(bec_x * vec_y - bec_y * vec_x)

               if dist > threshold:
                   continue
               countmap[plane_idx][y][x] += 1
               vectormap[plane_idx * 2 + 0][y][x] = vec_x
               vectormap[plane_idx * 2 + 1][y][x] = vec_y

4 将data 多进程放入队列

  • 设置操作1:获取处理后的input、label,组成一个batch,将batch 数据放入到 Q_data 的队列中
  • 设置操作2:获取所有数据路径,打乱后放入 Q_name队列中
  • 设置多进程:多进程进行 操作1/2,

  def readdata(self,  image_path, label_path, num):

       image, joint = self.load_data(image_path, label_path)

       image = image.astype(np.float32)
       image = (image - np.mean(image, axis=(0,1)))/(np.std(image, axis=(0, 1))+1e-8)

       self.batch_image[num,:,:,:] = image
       self.batch_label_heatmap[num,:,:,:] = self.get_heatmap(joint / self.stride)
       self.batch_label_vectmap[num] = self.get_vectormap(joint / self.stride)
       return image, joint

   def Q_getname(self):
       for i in range(self.num_trianepoch):
           if self.data_aug: random.shuffle(self.annotations)
           for j in range(self.num_samples):
               if not os.path.exists(self.annotations[j][0]) or not os.path.exists(self.annotations[j][1]):
                   continue
               self.Q_name.put(self.annotations[j])

   def Q_getData(self, thread):
       self.Prepare()

       name = []
       while 1:
           if self.batch_count < self.num_step_one_epoch:   # 当【读取了几批】小于【一轮总批数】
               num = 0 # 统计批内读取个数
               while num < self.batch_size:        #【批内读取数据个数】小于【一个batch数值】
                   namefile = self.Q_name.get()
                   self.readdata(namefile[0], namefile[1], num)
                   name.append(namefile[0])
                   num += 1
               self.batch_count += 1 # 统计一轮的训练,读取了几个批次
               # print(name)
               # print(thread )
               self.Q_data.put([name, thread,
                                   self.batch_image,
                                    self.batch_label_heatmap,
                                    self.batch_label_vectmap])
               name = []
           else:
               self.batch_count = 0

   def start(self, P1):
       Process(target=self.Q_getname, args=()).start()
       for thread in range(P1):
           Process(target=self.Q_getData, args=(thread,)).start()
       return self.Q_data

5 全面测试数据读取是否正确

当我们编写好了数据读取的脚本,需要进行两方面的测试:

  • case1:单张图:输入图片的数据增强;神经网络输出相应的label的制作
  • case2:多进程的数据读取是否正确:避免出现多进程重复读取等情况
if __name__ == '__main__':
   
   case = 1 # 1:测试单张图片的数据增强 2:测试队列的获取图片的重复性的问题
   
   if case:

       humandata = Dataset("dotest")
       humandata.data_aug = True # 是否进行数据增强
      humandata.fortest = False  # 是否显示过程中每种增强后的图片

       humandata.Prepare()

       for s in range(len(humandata.annotations)):
           # r[s] = "590DSC_0165.png"
           image_path = humandata.annotations[s][0]
           label_path = humandata.annotations[s][1]

           print(image_path)
           print(label_path)
           image, joint = humandata.readdata(image_path, label_path ,0)
           print(joint.shape)

          c1 = []
          for ii in range(len(joint)):
              c1.append((int(joint[ii][0]), int(joint[ii][1])))
          for cc in range(len(joint)):
              cv2.circle(image, c1[cc], 2, (255, 0, 0), thickness=1)
           show_image("random_horizontal_flip", image)
           
           img_heatmap = np.zeros((humandata.output_size[0], humandata.output_size[1], 3))
           for i in range(humandata.batch_label_heatmap.shape[3]-1):
               H =humandata.batch_label_heatmap[0,:,:,i]
               H = np.array([H,H,H]).transpose([1,2,0])
               img_heatmap = img_heatmap + H
               # img_test1 = cv2.resize(H, (humandata.input_size[1], humandata.input_size[0]))
               # img_heatmap1 = cv2.resize(img_heatmap, (humandata.output_size[1], humandata.output_size[0]))
               # cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
               # cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
               # cv2.imshow("demo5", H)  # 展示图片
               # cv2.waitKey(0)  # 保持展示
           cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
           cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
           cv2.imshow("demo5", img_heatmap)  # 展示图片
           cv2.waitKey(0)  # 保持展示


           img_heatmap = np.zeros((humandata.output_size[0], humandata.output_size[1], 3))
           for i in range(humandata.batch_label_vectmap.shape[3]-1):
               H = abs(humandata.batch_label_vectmap[0,:,:,i] * 255)
               H = np.array([H,H,H]).transpose([1,2,0])
               img_heatmap = img_heatmap + H
               # img_test1 = cv2.resize(H, (humandata.input_size[1], humandata.input_size[0]))
               # img_heatmap1 = cv2.resize(img_heatmap, (humandata.input_size[1], humandata.input_size[0]))
               # cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
               # cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
              # cv2.imshow("demo5", H)  # 展示图片
               # cv2.waitKey(0)  # 保持展示
           cv2.namedWindow('demo5', 0)  # 0 窗口可伸缩
           cv2.resizeWindow('demo5', 500, 500)  # 初始窗口大小
           cv2.imshow("demo5", img_heatmap)  # 展示图片
           cv2.waitKey(0)  # 保持展示
           
   else:
       humandata = Dataset("train")
       Q_traindata = humandata.start(3)
       for i in range(10):
           A = Q_traindata.get()
           print(A[1],A[0])      # 打印出队列存储的名字,以及来源的进程 id

在这里插入图片描述

6 附 opt.py 脚本

from easydict import EasyDict as edict
print("read config  ====================================")
cfg                             = edict()
cfg.OP                        = edict()
# Set the class name
cfg.OP.strides                = 8
cfg.OP.WH_ratio               = 1
cfg.OP.cpm_num = 22
cfg.OP.paf_num = 21*2

# Train options
cfg.TRAIN                       = edict()
cfg.TRAIN.annot_path            = "../data/train.txt"
cfg.TRAIN.batch_size            = 8
cfg.TRAIN.input_size            = [512]
# cfg.TRAIN.INPUT_SIZE            = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
cfg.TRAIN.data_aug              = True
cfg.TRAIN.learn_rate_init       = 1e-4
cfg.TRAIN.learn_rate_end        = 1e-6

cfg.TRAIN.warmup_epoch         = 2
cfg.TRAIN.first_stage_epoch    = 100
cfg.TRAIN.second_stage_epoch   = 30
cfg.TRAIN.initial_weights        = None
cfg.TRAIN.ckpt_path        = "./model/checkpoint0/"
cfg.TRAIN.log_path = './model/log0/'
#

# TEST options
cfg.TEST                        = edict()
cfg.TEST.annot_path             = "../data/test.txt"
cfg.TEST.batch_size             = 8
cfg.TEST.input_size             = [512]
cfg.TEST.data_aug               = False
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值