模型推理之前要先进行warm up预热
1.方法1:
参考:torch.cuda.synchronize()同步统计pytorch调用cuda运行时间_torch.xpu.synchorize()-CSDN博客
import time
with torch.no_grad():
    # Warm up the model first: these forward passes are deliberately not timed
    # so that one-time start-up cost stays out of the measurement.
    for _ in range(3):
        image_features = model.encode_image(image)
    # Block until all queued CUDA work has finished, so pending warm-up kernels
    # are not counted in the timed run below.
    torch.cuda.synchronize()

    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump if the system clock is
    # adjusted.
    start = time.perf_counter()
    image_features = model.encode_image(image)
    # Wait for the GPU to finish before reading the clock; otherwise only the
    # time needed to enqueue the work would be measured.
    torch.cuda.synchronize()
    print(f"image infer time is:{time.perf_counter() - start:.5f}")
2.方法2
参考:Chinese-CLIP中的代码
https://github.com/OFA-Sys/Chinese-CLIP/blob/master/cn_clip/deploy/speed_benchmark.py
import time
from contextlib import contextmanager
from typing import Iterator, List

import numpy as np
def print_timings(name: str, timings: List[float]) -> None:
    """
    Format and print inference latencies.

    The latencies are taken in seconds, summarised with NumPy and printed on a
    single line in milliseconds (mean, standard deviation, min, max, median,
    95th and 99th percentile).

    :param name: inference engine name, printed inside square brackets
    :param timings: latencies measured during the inference, in seconds
    :raises ValueError: if ``timings`` is empty
    """
    # len() instead of truthiness so that NumPy arrays are accepted as well.
    if len(timings) == 0:
        raise ValueError("timings must contain at least one latency measurement")

    # Every statistic is scaled by 1e3 to convert seconds to milliseconds.
    mean_time = 1e3 * np.mean(timings)
    std_time = 1e3 * np.std(timings)
    min_time = 1e3 * np.min(timings)
    max_time = 1e3 * np.max(timings)
    median, percent_95_time, percent_99_time = 1e3 * np.percentile(timings, [50, 95, 99])
    print(
        f"[{name}] "
        f"mean={mean_time:.2f}ms, "
        f"sd={std_time:.2f}ms, "
        f"min={min_time:.2f}ms, "
        f"max={max_time:.2f}ms, "
        f"median={median:.2f}ms, "
        f"95p={percent_95_time:.2f}ms, "
        f"99p={percent_99_time:.2f}ms"
    )
@contextmanager
def track_infer_time(buffer: List[float]) -> Iterator[None]:
    """
    A context manager to perform latency measures.

    The duration of the ``with`` body, measured with ``time.perf_counter``
    (fractional seconds), is appended to ``buffer`` once the body finishes
    normally. If the body raises, the exception propagates and nothing is
    appended, so failed runs do not end up in the collected latencies.

    :param buffer: a List where to save latencies for each input
    """
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    buffer.append(end - start)
with torch.no_grad():
    # NOTE(review): the synchronisation below assumes pt_model runs on the
    # current CUDA device whenever CUDA is available -- confirm for your setup.
    use_cuda = torch.cuda.is_available()

    # Warm-up: a few untimed forward passes so that one-time start-up cost
    # does not end up in the measured latencies.
    for _ in range(3):
        pytorch_output = pt_model(image=image, text=text)
    if use_cuda:
        # Make sure the warm-up work has finished before timing starts.
        torch.cuda.synchronize()

    time_buffer = []
    for _ in range(100):
        with track_infer_time(time_buffer):
            pytorch_output = pt_model(image=image, text=text)
            if use_cuda:
                # Wait for the GPU inside the timed region; otherwise the
                # recorded latency may cover only the time to enqueue the work.
                torch.cuda.synchronize()

    # Output of the last timed run: index 0 is used as the image feature and
    # index 1 as the text feature.
    image_feature = pytorch_output[0]
    text_feature = pytorch_output[1]
    text_probs = (100.0 * image_feature @ text_feature.T).softmax(dim=-1)

print_timings(name=f"Pytorch text inference speed (batch-size: {args.batch_size}):",
              timings=time_buffer)
# Drop the reference to the model once the benchmark is done.
del pt_model
#最后的输出为:
[Pytorch text inference speed:] mean=156.97ms, sd=383.19ms, min=115.44ms, max=3969.65ms, median=118.11ms, 95p=122.01ms, 99p=164.81ms
其中mean是平均耗时,sd是标准差,95p表示百分之95的推理耗时都在122.01毫秒之内(以上数值的单位均为毫秒)