程序测速:两种方法测试结果
方法一:
def time_synchronized():
    """Return the current wall-clock time after pending CUDA work finishes.

    When CUDA is available, ``torch.cuda.synchronize()`` is called first so
    that GPU work queued before this call has completed by the time the
    timestamp is taken. Without CUDA the function is just ``time.time()``.

    Returns:
        float: The value of ``time.time()`` sampled after synchronization.
    """
    # A plain ``if`` replaces the original conditional expression, which was
    # evaluated only for its side effect and discarded its value.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
方法二:
# Time a region of GPU work with a pair of CUDA events.
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

starter.record()
# model.forward()  # placeholder: the workload being timed goes here
ender.record()

# Wait for the queued GPU work to finish before reading the elapsed time.
torch.cuda.synchronize()
# elapsed_time() reports the interval between the two recorded events
# (milliseconds, per the PyTorch documentation).
curr_timing = starter.elapsed_time(ender)
print("推理时间:", curr_timing)
完整代码及结果:
""" Speed test """
import torch
import torchvision
import cv2
import numpy as np
import time
# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
class STN(torch.nn.Module):
def __init__(self):
super(STN, self).__init__()
self.localization0 = torch.nn.Sequential(
torch.nn.Conv2d(2, 16