在Tensorflow中把Tensor转换为ndarray时,循环中不断调用run或者eval函数,代码运行越来越慢!

问题

  我有一个这样的需求:我目前有一个已经训练好的encoder模型,它的输出是Tensor类型,我想把它转换成ndarray类型。通过查询资料,我发现可以利用sess.run()Tensor转换为ndarray,于是在我的代码里调用sess.run()成功转换了数据类型。
  但是,我这个数据转换在每一次的循环中都会调用,也就是循环中一直调用sess.run(),于是问题来了,每循环一次,sess.run的用时都比上一次要久,导致后面训练越来越慢。从第一次调用用时0.17s到后面第100次调用时0.27s,而且这才是100次,如果训练10000次,那不知道要等多久,所以这个问题必须解决!

问题原因

  如果在某一个循环里不断建立tensorflow图节点再运行的话,会导致tensorflow运行越来越慢。具体问题请看代码注释,没有注释的代码行可以不用关注,问题代码如下:

import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time

class MyWrapper(gym.ObservationWrapper):
    def __init__(self, env, encoder, latent_dim = 2):
        super().__init__(env)
        self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
        self.observation_space = self._observation_space
        self.encoder = encoder # 这是我已经提前训练好的模型
        tf.InteractiveSession()
        self.sess = tf.get_default_session()
		self.sess.run(tf.global_variables_initializer())
	
    def observation(self, obs):
        obs = np.reshape(obs, (1, -1))
        latent_z_tensor = self.encoder(obs)[2] # 问题就在与这里,这行代码在调用run时,会不断的创建图节点,所以越来越慢
        
        t=time.time() # 测试运行用时
        latent_z_arr = sels.sess.run(latent_z_tensor) # 每次run时,就会把上面的图重新构建一次
        print(time.time()-t) # 测试运行用时

        obs = np.reshape(obs, (-1,))
        latent_z_arr = np.reshape(latent_z_arr, (-1,))

        obs = obs.tolist()
        obs.extend(latent_z_arr.tolist())
        obs = np.array(obs)
        return obs

解决思路

在初始化时,就建立好图结构,使用tf.placeholder占位符表示obs这个变量,具体方案示例如下(可以只关注带有注释的行):

import gym
from gym.spaces import Box
import numpy as np
from tensorflow import keras
import tensorflow as tf
import time

class MyWrapper(gym.ObservationWrapper):
    def __init__(self, env, encoder, latent_dim = 2):
        super().__init__(env)
        self._observation_space = Box(-np.inf, np.inf, shape=(7 + latent_dim,), dtype=np.float32)
        self.observation_space = self._observation_space
        self.encoder = encoder
        tf.InteractiveSession()
        self.sess = tf.get_default_session()
        self.obs=tf.placeholder(dtype=tf.float32,shape=(1,7)) # 重点在于这两行代码,初始化时先构建好图,先用占位符表示obs,实际运行时只需喂数据obs就好了
        self.latent_z_tensor = self.encoder(self.obs)[2] # 在初始化时构建图
        self.sess.run(tf.global_variables_initializer())

    def observation(self, obs):
        obs = np.reshape(obs, (1, -1))
        t=time.time() # 测试运行用时
        latent_z_arr = self.sess.run(self.latent_z_tensor, feed_dict={self.obs:obs}) # 这里只需喂数据,不会重新构建图了。
        print(time.time()-t) # 测试运行用时

        obs = np.reshape(obs, (-1,))
        latent_z_arr = np.reshape(latent_z_arr, (-1,))

        obs = obs.tolist()
        obs.extend(latent_z_arr.tolist())
        obs = np.array(obs)
        return obs

现在,数据类型转换完成,代码运行慢也解决了!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
如果你已经训练好了一个目标检测模型,可以使用 `detect.py` 脚本来进行对象检测。可以将该脚本作为一个函数调用,以便在其他代码使用。 以下是一个示例代码,展示如何在 Python 调用 `detect.py` 脚本: ```python import argparse import torch from models import * from utils.datasets import * from utils.utils import * def detect(): parser = argparse.ArgumentParser() parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path') parser.add_argument('--names', type=str, default='data/coco.names', help='*.names path') parser.add_argument('--weights', type=str, default='weights/yolov3.weights', help='weights path') parser.add_argument('--source', type=str, default='data/samples', help='source') # input file/folder, 0 for webcam parser.add_argument('--output', type=str, default='output', help='output folder') # output folder parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--classes', nargs='+', type=int, help='filter by class') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--augment', action='store_true', help='augmented inference') opt = parser.parse_args() with torch.no_grad(): if opt.classes: opt.filter_classes = True device = torch_utils.select_device(opt.device) print(f'Using {device}') # Initialize model model = Darknet(opt.cfg, img_size=opt.img_size) attempt_download(opt.weights) if opt.weights.endswith('.pt'): # pytorch format model.load_state_dict(torch.load(opt.weights, map_location=device)['model']) else: # darknet format load_darknet_weights(model, opt.weights) model.to(device).eval() # Get dataloader dataset = LoadImages(opt.source, img_size=opt.img_size) # Run inference t0 = time.time() for path, img, im0s, vid_cap in dataset: img = torch.from_numpy(img).to(device) img = img.float() / 255.0 if img.ndimension() == 3: img = img.unsqueeze(0) # Inference pred = model(img)[0] # Apply NMS pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) # Process detections for i, det in enumerate(pred): # detections for image i p, s, im0 = path[i], f'{i}: ', im0s[i].copy() save_path = str(Path(opt.output) / Path(p).name) txt_path = str(Path(opt.output) / Path(p).stem) + ('' if opt.save_txt else '.nobak') + '.txt' s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh if det is not None and len(det): # Rescale boxes from img_size to im0 size det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() # Print results for c in det[:, -1].unique(): n = (det[:, -1] == c).sum() # detections per class s += f'{n} {names[int(c)]}s, ' # add to string # Write results for *xyxy, conf, cls in det: if opt.save_txt: # Write to file xyxy = torch.tensor(xyxy).view(1, 4) # convert to tensor xyxy = xyxy / gn[:2] + torch.tensor([0, 0, 1, 1]).float().cuda() # normalized to img0 0-1 line = (cls, *xyxy[0], conf) if opt.save_txt else (cls, conf) with open(txt_path, 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') if opt.view_img: # Add bbox to image c = int(cls) # integer class label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}') plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness) # Print time (inference + NMS) print(f'{s}Done. ({time.time() - t0:.3f}s)') # Save results (image with detections) if opt.save_img: cv2.imwrite(save_path, im0) print(f'Done. ({time.time() - t0:.3f}s)') ``` 上述代码 `detect()` 函数包含了 `detect.py` 的全部代码。在这个函数,我们使用 `argparse` 模块来处理命令行参数,并在模型加载后对传入的图片进行目标检测。最后,我们可以选择将结果保存到文件或显示在屏幕上。 你可以在你自己的代码调用函数,以便在其他应用程序使用目标检测模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值