博客摘录「 PyTorch 深度学习实践 第6讲」2023年4月4日

分布的差异:KL散度,cross-entropy交叉熵

import os 
import torch
import cv2
import argparse
import warnings
import numpy as np
from utils import PSNR, validation, LossNetwork
from model.IAT_main import IAT
from torchvision.transforms import Normalize
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument('--file_name', type=str, default='/home/s3090/zzh/Illumination-Adaptive-Transformer-main/IAT_enhance/Your_Path/')
parser.add_argument('--normalize', type=bool, default=False)
parser.add_argument('--task', type=str, default='enhance', help='Choose from exposure or enhance')
config = parser.parse_args()

# Weights path
exposure_pretrain = r'best_Epoch_exposure.pth'
enhance_pretrain = r'/home/s3090/zzh/Illumination-Adaptive-Transformer-main/IAT_enhance/workdirs/grad3/best_Epoch.pth'

normalize_process = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

## Load Pre-train Weights
model = IAT().cuda()
if config.task == 'exposure':
    model.load_state_dict(torch.load(exposure_pretrain))
elif config.task == 'enhance':
    model.load_state_dict(torch.load(enhance_pretrain))
else:
    warnings.warn('Only could be exposure or enhance')
model.eval()

# 设置视频路径和输出路径
video_path = 'video.mp4'
output_path = 'output_video.mp4'

# 打开视频文件
cap = cv2.VideoCapture(video_path)

# 获取视频的宽度和高度
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# 设置输出视频的编解码器和帧率与输入视频相同
output_codec = cv2.VideoWriter_fourcc(*'mp4v')
output_fps = cap.get(cv2.CAP_PROP_FPS)

# 创建输出视频的写入器
out = cv2.VideoWriter(output_path, output_codec, output_fps, (frame_width, frame_height))

while True:
    # 读取视频帧
    ret, frame = cap.read()

    # 判断是否读取到视频帧
    if not ret:
        break

    # 将BGR图像转换为RGB图像
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # 将NumPy数组转换为PyTorch Tensor
    input = torch.from_numpy(frame_rgb / 255.0).float().cuda()
    input = input.permute(2, 0, 1).unsqueeze(0)

    if config.normalize:
        input = normalize_process(input)

    # 使用模型进行处理
    _, _, enhanced_img = model(input)
    enhanced_img = enhanced_img.cpu().squeeze(0).permute(1, 2, 0)
    enhanced_img = enhanced_img.detach().numpy()
    enhanced_img = (enhanced_img * 255.0).astype(np.uint8)

    # 将RGB图像转换为BGR图像
    enhanced_img_bgr = cv2.cvtColor(enhanced_img, cv2.COLOR_RGB2BGR)

    # 写入增强后的帧到输出视频
    out.write(enhanced_img_bgr)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值