realbasicvsr代码解析part1

RealBasicVSR的Inference_realbasicvsr.py
main():代码解析:

def main():
    args = parse_args()

    # initialize the model    
    model = init_model(args.config, args.checkpoint)
    
    # read images
    file_extension = os.path.splitext(args.input_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:  # input is a video file
        video_reader = mmcv.VideoReader(args.input_dir)#mmcv.VideoReader( )将视频读取成frame序列
        inputs = []
        for frame in video_reader:            
            inputs.append(np.flip(frame, axis=2))    #将图像的rgb  r与b进行交换,颜色通道改变。       
    elif file_extension == '':  # input is a directory
        inputs = []
        input_paths = sorted(glob.glob(f'{args.input_dir}/*'))
        for input_path in input_paths:
            img = mmcv.imread(input_path, channel_order='rgb')
            inputs.append(img)
    else:
        raise ValueError('"input_dir" can only be a video or a directory.')

    for i, img in enumerate(inputs):
        img = torch.from_numpy(img / 255.).permute(2, 0, 1).float()
        #permute()张量的维度转换      from_numpy:将数组转化为张量
        #print(img.shape)    torch.Size([3, 360, 480])
        inputs[i] = img.unsqueeze(0)
        #print(inputs[i].shape)  torch.Size([1, 3, 360, 480])    
    inputs = torch.stack(inputs, dim=1)#stack()对序列数据内部的张量进行扩维拼接
    #print(inputs.shape)    torch.Size([1, 100, 3, 360, 480])

    # map to cuda, if available
    cuda_flag = False
    if torch.cuda.is_available():
    
        print("====================")    
        torch.cuda.empty_cache()    #释放显存            
        model = model.cuda()
        print("====================")    
        torch.cuda.empty_cache()          
        cuda_flag = True
          
        
    #在with模块下,所有计算得出的tensor的requires_gard都自动设置为False
    with torch.no_grad():
        if isinstance(args.max_seq_len, int):
            outputs = []
            for i in range(0, inputs.size(1), args.max_seq_len):
                imgs = inputs[:, i:i + args.max_seq_len, :, :, :]
                if cuda_flag:
                    imgs = imgs.cuda()
                outputs.append(model(imgs, test_mode=True)['output'].cpu())
            outputs = torch.cat(outputs, dim=1)
        else:
            if cuda_flag:
                inputs = inputs.cuda()
                #print(inputs.shape) torch.Size([1, 100, 3, 360, 480])
                
            import time
            t1 = time.time()
            outputs = model(inputs, test_mode=True)['output'].cpu()#cpu()是因为numpy只能在cpu处理
            #print(outputs.shape)    torch.Size([1, 100, 3, 1440, 1920])
            t2 = time.time()
            print(t2-t1)
            
            # time.sleep(100)
            
    #print(os.path.splitext(args.output_dir)[1]) .mp4
    #print(os.path.splitext(args.output_dir)[0]) results/demo_001   

    if os.path.splitext(args.output_dir)[1] in VIDEO_EXTENSIONS:
        output_dir = os.path.dirname(args.output_dir)   #去掉文件名,返回目录
        mmcv.mkdir_or_exist(output_dir)

        h, w = outputs.shape[-2:]
        print(h,w)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')    # cv2.VideoWriter_fourcc输出视频的格式要求    
        video_writer = cv2.VideoWriter(args.output_dir, fourcc, args.fps,
                                       (w, h))  #写视频,输出路径,编码器,保存的视频帧率,画面尺寸
        for i in range(0, outputs.size(1)):
            img = tensor2img(outputs[:, i, :, :, :])    #tensor2img()将tensor转换为image
            video_writer.write(img.astype(np.uint8))
        cv2.destroyAllWindows() #关闭窗口
        video_writer.release()
    else:
        mmcv.mkdir_or_exist(args.output_dir)
        #print(outputs.size(1))  31
        for i in range(0, outputs.size(1)):
            output = tensor2img(outputs[:, i, :, :, :])
            filename = os.path.basename(input_paths[i])
            # print(filename) 返回最后一级的文件名
            if args.is_save_as_png:
                file_extension = os.path.splitext(filename)[1]
                filename = filename.replace(file_extension, '.png')                
            mmcv.imwrite(output, f'{args.output_dir}/{filename}')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值