RealBasicVSR 的 inference_realbasicvsr.py
main() 函数代码解析:
def main():
    """Run RealBasicVSR inference on a video file or a directory of frames.

    All settings come from ``parse_args()``. The attributes read here are
    ``config``, ``checkpoint``, ``input_dir``, ``output_dir``,
    ``max_seq_len``, ``fps`` and ``is_save_as_png``.

    The input is loaded as a list of frames, stacked into one tensor of
    shape (1, T, C, H, W), passed through the model (optionally in chunks
    of ``max_seq_len`` frames), and the result is written either as a video
    file or as one image per frame, depending on the extension of
    ``output_dir``.

    Raises:
        ValueError: if ``input_dir`` is neither a video file nor a
            directory (judged by its extension), or if no frames could be
            read from it.
    """
    import time  # local import: only used for the timing report below

    args = parse_args()

    # initialize the model
    model = init_model(args.config, args.checkpoint)

    # read images
    # ``input_paths`` stays None for video input; the image-saving branch
    # below uses that to decide how to name the output files.
    input_paths = None
    file_extension = os.path.splitext(args.input_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:  # input is a video file
        # mmcv.VideoReader exposes the video as an iterable of frames.
        video_reader = mmcv.VideoReader(args.input_dir)
        # Reverse the channel axis so the frames use the same channel order
        # as the directory branch, which reads with channel_order='rgb'
        # (VideoReader frames are presumably BGR, as decoded by OpenCV).
        inputs = [np.flip(frame, axis=2) for frame in video_reader]
    elif file_extension == '':  # input is a directory
        input_paths = sorted(glob.glob(f'{args.input_dir}/*'))
        inputs = [
            mmcv.imread(input_path, channel_order='rgb')
            for input_path in input_paths
        ]
    else:
        raise ValueError('"input_dir" can only be a video or a directory.')

    # Fail early with a clear message instead of an opaque torch.stack error.
    if not inputs:
        raise ValueError(f'No frames could be read from "{args.input_dir}".')

    # Convert in place (one frame at a time) so each source array can be
    # freed as soon as its tensor replaces it in the list.
    for i, img in enumerate(inputs):
        # Scale to [0, 1], reorder (H, W, C) -> (C, H, W), cast to float32.
        img = torch.from_numpy(img / 255.).permute(2, 0, 1).float()
        inputs[i] = img.unsqueeze(0)  # (1, C, H, W)
    # Stack along a new dimension 1: (1, T, C, H, W).
    inputs = torch.stack(inputs, dim=1)

    # map to cuda, if available
    cuda_flag = False
    if torch.cuda.is_available():
        print("====================")
        torch.cuda.empty_cache()  # release cached GPU memory
        model = model.cuda()
        print("====================")
        torch.cuda.empty_cache()
        cuda_flag = True

    # No gradients are needed for inference; tensors computed inside this
    # context do not track gradients.
    with torch.no_grad():
        if isinstance(args.max_seq_len, int):
            # Process the sequence in chunks of at most max_seq_len frames,
            # moving each chunk's result back to the CPU before the next.
            outputs = []
            for i in range(0, inputs.size(1), args.max_seq_len):
                imgs = inputs[:, i:i + args.max_seq_len, :, :, :]
                if cuda_flag:
                    imgs = imgs.cuda()
                outputs.append(model(imgs, test_mode=True)['output'].cpu())
            outputs = torch.cat(outputs, dim=1)
        else:
            # Process the whole sequence in a single forward pass.
            if cuda_flag:
                inputs = inputs.cuda()
            t1 = time.time()
            # Move the result to the CPU so it can be converted to numpy.
            outputs = model(inputs, test_mode=True)['output'].cpu()
            t2 = time.time()
            print(t2-t1)

    if os.path.splitext(args.output_dir)[1] in VIDEO_EXTENSIONS:
        # Output is a video file: make sure its parent directory exists.
        output_dir = os.path.dirname(args.output_dir)
        mmcv.mkdir_or_exist(output_dir)

        h, w = outputs.shape[-2:]
        print(h, w)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # codec for the output file
        # Arguments: output path, codec, frame rate, frame size as (w, h).
        video_writer = cv2.VideoWriter(args.output_dir, fourcc, args.fps,
                                       (w, h))
        try:
            for i in range(outputs.size(1)):
                # tensor2img converts the frame tensor to an image array.
                img = tensor2img(outputs[:, i, :, :, :])
                video_writer.write(img.astype(np.uint8))
        finally:
            # Always finalize the file, even if writing a frame fails.
            # No cv2.destroyAllWindows() here: no window is ever opened and
            # the call raises on headless OpenCV builds.
            video_writer.release()
    else:
        # Output is a directory: write one image per frame.
        mmcv.mkdir_or_exist(args.output_dir)
        for i in range(outputs.size(1)):
            output = tensor2img(outputs[:, i, :, :, :])
            if input_paths is None:
                # Video input has no per-frame source names; use the index.
                filename = f'{i:08d}.png'
            else:
                filename = os.path.basename(input_paths[i])
                if args.is_save_as_png:
                    # Swap only the trailing extension; str.replace would
                    # also rewrite matches elsewhere in the name.
                    filename = os.path.splitext(filename)[0] + '.png'
            mmcv.imwrite(output, f'{args.output_dir}/{filename}')