qwen_vl_utils代码分析

函数列表:

序号函数名简要说明
1round_by_factor(number: int, factor: int) -> int返回最接近 number,且能被 factor 整除的整数。
2ceil_by_factor(number: int, factor: int) -> int返回大于等于 number,且能被 factor 整除的最小整数。
3floor_by_factor(number: int, factor: int) -> int返回小于等于 number,且能被 factor 整除的最大整数。
4smart_resize(height: int, width: int, ...) -> tuple[int, int]根据给定的高和宽,调整图像尺寸,使其满足特定条件(如可被因数整除、像素数在范围内、保持长宽比)。
5to_rgb(pil_image: Image.Image) -> Image.Image将 PIL 图像转换为 RGB 模式,处理 RGBA 图像的透明通道。
6fetch_image(ele: dict, size_factor: int = IMAGE_FACTOR) -> Image.Image从各种输入(URL、本地路径、Base64、PIL.Image)获取图像,并进行尺寸调整。
7smart_nframes(ele: dict, total_frames: int, video_fps: int) -> int计算用于模型输入的视频帧数,确保帧数满足特定因数要求,并在最小和最大帧数范围内。
8_read_video_torchvision(ele: dict) -> (torch.Tensor, float)使用 torchvision 库读取视频,返回视频帧和帧率。
9is_decord_available() -> bool检查是否安装了 decord 库。
10_read_video_decord(ele: dict) -> (torch.Tensor, float)使用 decord 库读取视频,返回视频帧和帧率。
11get_video_reader_backend() -> str获取用于读取视频的后端库名称,优先使用 decord
12fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor读取并处理视频,返回处理后的视频帧。
13extract_vision_info(conversations: list) -> list[dict]从对话中提取与视觉相关的信息,如图像或视频。
14process_vision_info(conversations: list, return_video_kwargs: bool = False) -> tuple处理视觉信息,获取图像和视频数据,以供模型使用。

调用关系图示:

process_vision_info
├── extract_vision_info
├── fetch_image (对于图像)
│   ├── to_rgb
│   └── smart_resize
│       ├── round_by_factor
│       ├── ceil_by_factor
│       └── floor_by_factor
└── fetch_video (对于视频)
    ├── get_video_reader_backend
    │   └── is_decord_available
    ├── _read_video_torchvision 或 _read_video_decord
    │   └── smart_nframes
    │       ├── round_by_factor
    │       ├── ceil_by_factor
    │       └── floor_by_factor
    └── smart_resize
        ├── round_by_factor
        ├── ceil_by_factor
        └── floor_by_factor

说明:

  • process_vision_info 是核心函数,它根据对话内容,分别处理图像和视频。
  • 在处理图像时,主要通过 fetch_imageto_rgbsmart_resize 来获取并调整图像。
  • 在处理视频时,主要通过 fetch_videoget_video_reader_backend_read_video_torchvision_read_video_decordsmart_nframessmart_resize 来获取并处理视频帧。
  • 数学计算函数 round_by_factorceil_by_factorfloor_by_factor 被多次调用,用于确保尺寸和帧数满足特定的因数要求。
常量定义:
  • IMAGE_FACTOR = 28:这是图像尺寸调整的因数,图像的高度和宽度都将被调整为 28 的倍数。

  • MIN_PIXELS = 4 * 28 * 28:图像的最小像素数,确保图像不小于特定大小。

  • MAX_PIXELS = 16384 * 28 * 28:图像的最大像素数,限制图像的最大尺寸,防止过大的图像占用过多内存。

  • MAX_RATIO = 200:图像的最大宽高比,用于防止过度拉伸或压缩的图像。

  • VIDEO_MIN_PIXELS = 128 * 28 * 28:视频帧的最小像素数。

  • VIDEO_MAX_PIXELS = 768 * 28 * 28:视频帧的最大像素数。

  • FRAME_FACTOR = 2:视频帧数需要是此因数的倍数。

  • FPS = 2.0:默认的视频采样帧率。

  • FPS_MIN_FRAMES = 4:视频采样的最小帧数。

  • FPS_MAX_FRAMES = 768:视频采样的最大帧数。

  • VIDEO_TOTAL_PIXELS:从环境变量 VIDEO_MAX_PIXELS 中获取视频的总像素数,如果未设置,则默认使用 128000 * 28 * 28 * 0.9,并将其转换为整数。这是对视频输入尺寸的限制。

1. extract_vision_info 函数

功能:

extract_vision_info 函数用于从对话内容(conversations)中提取所有与视觉相关的信息(如图像或视频),并将这些信息以列表的形式返回。

参数:

  • conversations: 类型为 list[dict]list[list[dict]],表示对话的列表。每个对话可以是一个包含消息字典的列表,或者直接是消息字典。

返回值:

  • vision_infos: 类型为 list[dict],包含所有提取的视觉信息的字典。

代码解析:

  1. 初始化空列表 vision_infos

    vision_infos = []
    
  2. 确保 conversations 是列表的列表格式:

    if isinstance(conversations[0], dict):
        conversations = [conversations]
    
    • 如果 conversations 的第一个元素是字典,说明传入的是单个对话,而不是对话的列表。为了统一处理,将其包装成列表的列表形式。
  3. 遍历每个对话和消息:

    for conversation in conversations:
        for message in conversation:
    
  4. 检查消息的内容是否为列表:

    if isinstance(message["content"], list):
    
    • 如果消息的内容是列表,说明它可能包含多个视觉元素。
  5. 提取视觉信息:

    for ele in message["content"]:
        if (
            "image" in ele
            or "image_url" in ele
            or "video" in ele
            or ele["type"] in ("image", "image_url", "video")
        ):
            vision_infos.append(ele)
    
    • 遍历消息内容中的每个元素 ele
    • 如果元素包含 "image""image_url""video" 键,或者其类型(ele["type"])是 "image""image_url""video",则将该元素添加到 vision_infos 列表中。
  6. 返回提取的视觉信息列表:

    return vision_infos
    
2. process_vision_info 函数

功能:

process_vision_info 函数用于处理从对话内容中提取的视觉信息,包括读取和处理图像和视频数据,最终返回处理后的结果。

参数:

  • conversations: 类型为 list[dict]list[list[dict]],表示对话的列表。
  • return_video_kwargs: 类型为 bool,默认为 False。如果为 True,则在返回值中包含视频的额外参数(如帧率)。

返回值:

  • 根据 return_video_kwargs 的值,返回不同的内容:

    • 如果 return_video_kwargsFalse

      • (image_inputs, video_inputs)
        • image_inputs: 处理后的图像列表(list[Image.Image]),如果没有图像,则为 None
        • video_inputs: 处理后的视频列表(list[torch.Tensor]list[list[Image.Image]]),如果没有视频,则为 None
    • 如果 return_video_kwargsTrue

      • (image_inputs, video_inputs, {'fps': video_sample_fps_list})
        • 除了上述两个返回值外,额外返回一个包含视频帧率的字典。

代码解析:

  1. 提取视觉信息:

    vision_infos = extract_vision_info(conversations)
    
    • 调用 extract_vision_info 函数,从对话中提取所有的视觉信息,得到 vision_infos 列表。
  2. 初始化存储变量:

    image_inputs = []
    video_inputs = []
    video_sample_fps_list = []
    
    • image_inputs: 用于存储处理后的图像数据。
    • video_inputs: 用于存储处理后的视频数据。
    • video_sample_fps_list: 用于存储每个视频的采样帧率。
  3. 处理每个视觉信息:

    for vision_info in vision_infos:
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info))
        elif "video" in vision_info:
            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
            video_sample_fps_list.append(video_sample_fps)
            video_inputs.append(video_input)
        else:
            raise ValueError("image, image_url or video should in content.")
    
    • 遍历 vision_infos 列表,对每个视觉信息进行处理。

    • 处理图像:

      • 如果 vision_info 中包含 "image""image_url" 键,调用 fetch_image 函数处理图像。
      • 将处理后的图像对象添加到 image_inputs 列表中。
    • 处理视频:

      • 如果 vision_info 中包含 "video" 键,调用 fetch_video 函数处理视频,参数 return_video_sample_fps=True 表示需要返回视频的采样帧率。
      • 得到处理后的视频数据 video_input 和视频帧率 video_sample_fps
      • 将视频数据添加到 video_inputs 列表,将帧率添加到 video_sample_fps_list 列表。
    • 异常处理:

      • 如果既不包含图像也不包含视频,抛出 ValueError,提示内容中应包含 "image""image_url""video"
  4. 处理可能的空列表:

    if len(image_inputs) == 0:
        image_inputs = None
    if len(video_inputs) == 0:
        video_inputs = None
    
    • 如果 image_inputsvideo_inputs 列表为空,则将其设置为 None
  5. 根据参数返回结果:

    if return_video_kwargs:
        return image_inputs, video_inputs, {'fps': video_sample_fps_list}
    return image_inputs, video_inputs
    
    • 如果 return_video_kwargsTrue,则返回包含视频帧率信息的字典。
    • 如果为 False(默认情形),则只返回图像和视频数据。

示例:

假设有如下对话内容:

conversations = [
    # 第一个对话
    [
        {'role': 'user', 'content': [
            {'type': 'text', 'data': '请查看这张图片。'},
            {'type': 'image', 'image_url': 'http://example.com/image1.jpg'}
        ]},
        {'role': 'assistant', 'content': '好的,我正在查看。'}
    ],
    # 第二个对话
    [
        {'role': 'user', 'content': [
            {'type': 'text', 'data': '这是一个视频。'},
            {'type': 'video', 'video': 'http://example.com/video1.mp4'}
        ]},
        {'role': 'assistant', 'content': '我正在处理视频。'}
    ]
]

调用 process_vision_info(conversations)

  1. 提取视觉信息:

    • extract_vision_info 函数遍历对话,找到包含视觉信息的元素。
    • 得到 vision_infos 列表,包含两个元素:
      • 第一个是图像信息:{'type': 'image', 'image_url': 'http://example.com/image1.jpg'}
      • 第二个是视频信息:{'type': 'video', 'video': 'http://example.com/video1.mp4'}
  2. 处理视觉信息:

    • 对于第一个元素,调用 fetch_image 读取并处理图像,结果添加到 image_inputs 列表。
    • 对于第二个元素,调用 fetch_video 读取并处理视频,结果添加到 video_inputs 列表,同时帧率添加到 video_sample_fps_list
  3. 返回结果:

    • image_inputs 是一个包含处理后图像的列表。
    • video_inputs 是一个包含处理后视频的列表。
    • 如果 return_video_kwargsTrue,还会返回视频帧率信息。
tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]

用于描述函数的返回值类型。让我们逐步解析这一复杂的类型注解,理解每个部分的含义。

  1. 元组类型 tuple[...]
  • tuple[...] 表示一个元组类型,元组中的每个元素的位置和类型都是固定的。
  • 元组的元素按照顺序依次对应。
  1. 第一个元素:list[Image.Image] | None
  • list[Image.Image]:表示一个 Image.Image 对象(来自 PIL 库)的列表,即图像对象的列表。
  • | None:符号 | 表示类型的联合(Union),即该值可以是前面的类型或后面的类型。
  • list[Image.Image] | None:表示该元素要么是一个 Image.Image 对象的列表,要么是 None

解释:函数可能返回一个包含图像的列表,如果没有图像,则返回 None

  1. 第二个元素:list[torch.Tensor | list[Image.Image]] | None
  • 内部类型 torch.Tensor | list[Image.Image]
    • torch.Tensor:表示一个 PyTorch 的张量,一般用于表示视频数据(如视频帧序列)。
    • list[Image.Image]:表示一个 Image.Image 对象的列表,即图像对象的列表。
    • torch.Tensor | list[Image.Image]:表示该元素可以是 torch.Tensor 或者 list[Image.Image]
  • 外部列表 list[...]:表示上述类型的列表,即列表中的每个元素可以是 torch.Tensorlist[Image.Image]
  • | None:表示该值也可以是 None
  • 组合起来 list[torch.Tensor | list[Image.Image]] | None:表示该元素要么是一个列表,列表中的每个元素是 torch.Tensorlist[Image.Image],要么是 None

解释:函数可能返回一个视频数据的列表,如果没有视频,则返回 None

  1. 第三个元素:Optional[dict]
  • Optional[dict]Optionaltyping 模块中的一个泛型类型,用于表示可选类型,即类型可以是指定的类型或 None
  • Optional[dict] 等价于 dict | None

解释:函数可能返回一个字典(如视频的额外参数),如果没有额外参数,则返回 None

示例

假设

  • 函数提取并处理了两张图像和一个视频。
  • 图像处理后得到一个 Image.Image 对象的列表。
  • 视频处理后得到一个 torch.Tensor,表示视频帧数据。
  • 视频的额外参数是帧率 fps

返回值可能是

(
    [image1, image2],                 # list[Image.Image]
    [video_tensor],                   # list[torch.Tensor]
    {'fps': [video_fps_value]}        # dict
)

或者,如果没有图像,只有视频

(
    None,                             # 没有图像
    [video_tensor],                   # list[torch.Tensor]
    {'fps': [video_fps_value]}        # dict
)

或者,如果只有图像,没有视频

(
    [image1, image2],                 # list[Image.Image]
    None,                             # 没有视频
    None                              # 没有额外参数
)
3. round_by_factor(number: int, factor: int) -> int

功能:将给定的数字 number 调整为最接近的、能被 factor 整除的整数。

实现

def round_by_factor(number: int, factor: int) -> int:
    return round(number / factor) * factor
4. ceil_by_factor(number: int, factor: int) -> int

功能:将给定的数字 number 调整为大于或等于它的、能被 factor 整除的最小整数。

实现

def ceil_by_factor(number: int, factor: int) -> int:
    return math.ceil(number / factor) * factor
5. floor_by_factor(number: int, factor: int) -> int

功能:将给定的数字 number 调整为小于或等于它的、能被 factor 整除的最大整数。

实现

def floor_by_factor(number: int, factor: int) -> int:
    return math.floor(number / factor) * factor
6. smart_resize(...) -> tuple[int, int]

功能:根据给定的高度和宽度,智能地调整图像尺寸,使其满足以下条件:

  1. 高度和宽度都能被指定的 factor 整除。
  2. 图像的总像素数在 min_pixelsmax_pixels 之间。
  3. 尽可能保持图像的宽高比。

参数

  • height: 原始高度。
  • width: 原始宽度。
  • factor: 因数,默认为 IMAGE_FACTOR(28)。
  • min_pixels: 最小像素数,默认为 MIN_PIXELS
  • max_pixels: 最大像素数,默认为 MAX_PIXELS

实现

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    # 检查宽高比是否过大
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    # 调整高度和宽度,使其能被factor整除
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    # 调整像素数在指定范围内
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

解释

  • 首先检查图像的宽高比是否超过了 MAX_RATIO,如果超过,则抛出错误,防止图像过于拉伸或压缩。
  • 然后使用 round_by_factor 将高度和宽度调整为最接近的、能被 factor 整除的值,且不小于 factor
  • 接下来,根据图像总像素数与 max_pixelsmin_pixels 的关系,调整高度和宽度:
    • 如果调整后的像素数超过了 max_pixels,则计算一个缩放系数 beta,通过 floor_by_factor 函数减少高度和宽度。
    • 如果调整后的像素数小于了 min_pixels,则计算一个放大系数 beta,通过 ceil_by_factor 函数增加高度和宽度。
  • 最终返回调整后的高度和宽度。
7. to_rgb(pil_image: Image.Image) -> Image.Image

功能:将给定的 PIL 图像对象转换为 RGB 模式。如果图像是带有透明度的 RGBA 模式,则将其转换为 RGB 模式,并填充白色背景。

实现

def to_rgb(pil_image: Image.Image) -> Image.Image:
      if pil_image.mode == 'RGBA':
          white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
          white_background.paste(pil_image, mask=pil_image.split()[3])  # 使用alpha通道作为掩码
          return white_background
      else:
          return pil_image.convert("RGB")

解释

  • 检查图像的模式:
    • 如果是 RGBA 模式,表示图像带有透明度通道,需要将透明部分填充为白色。
    • 创建一个白色背景的 RGB 图像 white_background,大小与原图相同。
    • 使用 paste 方法,将原始图像粘贴到白色背景上,使用 alpha 通道作为掩码,以保留透明度信息。
    • 返回合成后的 RGB 图像。
  • 如果图像已经是其他模式,直接转换为 RGB 模式并返回。
8. fetch_image(ele: dict, size_factor: int = IMAGE_FACTOR) -> Image.Image

功能:根据给定的图像信息,从多种来源(如 URL、本地路径、Base64 编码、PIL.Image 对象)获取图像,并进行预处理,包括转换为 RGB 模式和调整尺寸。

参数

  • ele: 包含图像信息的字典。
  • size_factor: 调整尺寸的因数,默认为 IMAGE_FACTOR(28)。

实现

def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    # 获取图像数据
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    # 根据图像数据类型进行处理
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    # 转换为RGB模式
    image = to_rgb(image_obj)
    ## 调整尺寸
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))
    return image

解释

  1. 获取图像数据

    • 从传入的字典 ele 中获取图像信息,优先使用键 "image",否则使用 "image_url"
    • 初始化 image_objNone
  2. 根据图像数据的类型进行处理

    • 如果 image 是一个 Image.Image 对象,直接赋值给 image_obj
    • 如果 image 是以 "http://""https://" 开头的字符串,表示是网络 URL:
      • 使用 requests 库获取图像内容。
      • 使用 Image.open 读取图像。
    • 如果 image 是以 "file://" 开头的字符串,表示是本地文件路径:
      • 去除前缀 "file://",然后使用 Image.open 读取图像。
    • 如果 image"data:image" 开头,表示是 Base64 编码的图像数据:
      • 解析 Base64 数据,解码后使用 Image.open 读取图像。
    • 否则,假设 image 是本地文件路径,直接使用 Image.open 读取。
  3. 检查图像是否成功读取

    • 如果 image_obj 仍为 None,则抛出错误,提示无法识别的图像输入格式。
  4. 转换为 RGB 模式

    • 调用 to_rgb 函数,将图像转换为 RGB 模式,处理透明度问题。
  5. 调整图像尺寸

    • 如果 ele 中提供了 "resized_height""resized_width",则使用这些值进行尺寸调整,调用 smart_resize 函数。
    • 否则,使用图像的原始尺寸,并获取 min_pixelsmax_pixels(如果未提供,则使用默认值)。
    • 调用 smart_resize 函数,根据原始尺寸、因数和像素范围,计算新的高度和宽度。
    • 使用 image.resize 方法调整图像尺寸。
  6. 返回处理后的图像

    • 最终返回调整尺寸后的图像对象。
9. smart_nframes
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    ...

功能:

smart_nframes 函数用于计算用于模型输入的视频帧数,确保帧数满足一定的条件和限制。

参数:

  • ele: 包含视频配置信息的字典,支持以下键:
    • nframes: 希望提取的帧数。
    • fps: 希望以多少帧率来提取帧。
    • min_frames: 当使用 fps 时,指定最小帧数。
    • max_frames: 当使用 fps 时,指定最大帧数。
  • total_frames: 视频的总帧数。
  • video_fps: 视频的原始帧率。

流程:

  1. 检查冲突参数:

    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    

    这一步确保 ele 字典中不能同时既有 fps 又有 nframes,否则抛出断言错误。

  2. 根据配置计算帧数:

    • 如果提供了 nframes

      if "nframes" in ele:
          nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
      

      使用 round_by_factor 函数将 nframes 四舍五入到最近的 FRAME_FACTOR 的倍数,确保帧数是特定因子的整数倍。

    • 如果提供了 fps

      else:
          fps = ele.get("fps", FPS)
          min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
          max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
          nframes = total_frames / video_fps * fps
          if nframes > total_frames:
              logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
          nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
          nframes = floor_by_factor(nframes, FRAME_FACTOR)
      
      • 获取期望的 fps,如果未提供则使用默认值 FPS
      • 计算 min_framesmax_frames,确保它们是 FRAME_FACTOR 的倍数。
      • 根据原始总帧数、原始帧率和期望的帧率计算需要的帧数 nframes
      • 发出警告如果计算的 nframes 超过了总帧数。
      • nframes 限制在 min_framesmax_frames 之间,并确保不超过总帧数。
      • 使用 floor_by_factornframes 向下取整到最近的 FRAME_FACTOR 的倍数。
  3. 验证帧数是否合理:

    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    

    确保计算出的 nframes 在有效范围内,否则抛出 ValueError

  4. 返回计算的帧数:

    return nframes
    

10. _read_video_torchvision
def _read_video_torchvision(
    ele: dict,
) -> (torch.Tensor, float):
    ...

功能:

使用 torchvision 库的 io.read_video 函数读取视频文件,并返回视频帧的张量和采样后的帧率。

参数:

  • ele: 包含视频配置信息的字典,支持以下键:
    • video: 视频路径,支持本地路径、file://http://https://
    • video_start: 视频起始时间(秒)。
    • video_end: 视频结束时间(秒)。

流程:

  1. 处理视频路径:

    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
        if "file://" in video_path:
            video_path = video_path[7:]
    

    如果 torchvision 版本低于 0.19.0:

    • 不支持通过 http://https:// 读取视频,提示用户升级。
    • 如果视频路径以 file:// 开头,去掉前面的 file://
  2. 读取视频:

    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    

    使用 io.read_video 读取视频,指定起始和结束时间,输出格式为 (T, C, H, W),即帧数、通道数、高度、宽度。

  3. 获取视频信息:

    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    

    获取视频的总帧数和原始帧率,记录读取时间。

  4. 计算需要的帧数:

    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    

    调用之前的 smart_nframes 函数计算需要的帧数。

  5. 从视频中采样帧:

    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    
    • 使用 torch.linspace 生成一个索引列表,从视频帧中均匀采样 nframes 帧。
    • 计算采样后的帧率 sample_fps
    • 根据索引提取对应的帧。
  6. 返回视频张量和采样帧率:

    return video, sample_fps
    

11. is_decord_available
def is_decord_available() -> bool:
    import importlib.util
    return importlib.util.find_spec("decord") is not None

功能:

检查 decord 库是否可用。

流程:

  • 使用 importlib.util.find_spec("decord") 检查是否可以找到 decord 模块的规格(spec)。
  • 如果找到了则返回 True,否则返回 False

12. _read_video_decord
def _read_video_decord(
    ele: dict,
) -> (torch.Tensor, float):
    ...

功能:

使用 decord 库的 VideoReader 读取视频文件,并返回视频帧的张量和采样后的帧率。

参数:

  • ele: 包含视频配置信息的字典,支持以下键:
    • video: 视频路径,支持本地路径、file://http://https://
    • video_start: 视频起始时间(暂不支持)。
    • video_end: 视频结束时间(暂不支持)。

流程:

  1. 导入 decord 库:

    import decord
    
  2. 处理视频路径:

    video_path = ele["video"]
    st = time.time()
    

    获取视频路径,记录开始时间。

  3. 创建 VideoReader 实例:

    vr = decord.VideoReader(video_path)
    

    使用 decordVideoReader 读取视频。

  4. 暂不支持起始和结束时间:

    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")
    

    目前暂不支持通过 decord 指定起始和结束时间,如果发现有这样的参数,抛出 NotImplementedError

  5. 获取视频信息:

    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    

    获取视频的总帧数和平均帧率,记录读取时间。

  6. 计算需要的帧数:

    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    

    调用 smart_nframes 计算需要的帧数。

  7. 从视频中采样帧:

    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    
    • 使用 torch.linspace 生成索引列表,均匀采样 nframes 帧。
    • 使用 vr.get_batch(idx) 获取对应帧,转换为 NumPy 数组。
    • 将 NumPy 数组转换为 PyTorch 张量,并调整维度顺序为 (T, C, H, W)
    • 计算采样后的帧率。
  8. 返回视频张量和采样帧率:

    return video, sample_fps
    

13. get_video_reader_backend
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
    elif is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
    return video_reader_backend

功能:

根据环境变量或库的可用性,确定使用哪个视频读取后端。

流程:

  1. 检查环境变量:

    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
    

    如果环境变量 FORCE_QWENVL_VIDEO_READER 被设置,则强制使用该后端。

  2. 检查 decord 库是否可用:

    elif is_decord_available():
        video_reader_backend = "decord"
    

    如果 decord 库可用,则使用 decord

  3. 默认使用 torchvision

    else:
        video_reader_backend = "torchvision"
    

    如果不满足上述条件,默认使用 torchvision

  4. 输出使用的后端信息并返回:

    print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
    return video_reader_backend
    

    打印使用的后端信息,返回后端名称。

注解:

  • 使用了 @lru_cache(maxsize=1) 装饰器,表示函数的返回值会被缓存,当再次调用时直接返回缓存值,避免重复计算。

14. fetch_video
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    ...

功能:

根据提供的配置,获取并处理视频数据,返回适用于模型输入的视频张量或图像列表。

参数:

  • ele: 包含视频配置信息的字典,支持以下键:
    • video: 视频路径,或包含一系列图像的列表。
    • 其他参数如 min_pixelsmax_pixelsresized_heightresized_width 等,用于调整视频尺寸。
  • image_factor: 调整尺寸时使用的因子,默认值为 IMAGE_FACTOR
  • return_video_sample_fps: 是否返回采样后的帧率,布尔值。

流程:

  1. 判断 ele["video"] 的类型:

    if isinstance(ele["video"], str):
        ...
    else:
        ...
    
    • 如果是字符串,表示视频路径,需要读取视频文件。
    • 如果是列表或元组,表示已经提供了帧图像的列表。
  2. 处理视频文件:

    video_reader_backend = get_video_reader_backend()
    try:
        video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
    except Exception as e:
        logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
        video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
    
    • 使用 get_video_reader_backend() 确定后端,然后调用对应的读取函数获取视频张量和采样帧率。
    • 如果发生异常,记录警告信息,默认使用 torchvision 读取视频。
  3. 获取视频尺寸信息和像素限制:

    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
    max_pixels_supposed = ele.get("max_pixels", max_pixels)
    if max_pixels_supposed > max_pixels:
        logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
    max_pixels = min(max_pixels_supposed, max_pixels)
    
    • 获取视频的帧数、高度和宽度。
    • 计算 min_pixelsmax_pixels,以限制视频的总像素数,避免内存占用过大。
  4. 调整视频帧尺寸:

    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=image_factor,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=image_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    
    • 如果提供了 resized_heightresized_width,则使用这些值进行尺寸调整。
    • 否则,使用 smart_resize 根据原始尺寸和像素限制计算新的高度和宽度。
    • 使用 transforms.functional.resize 调整视频帧尺寸。
  5. 返回结果:

    if return_video_sample_fps:
        return video, sample_fps
    return video
    
    • 如果需要返回采样帧率,则返回 (video, sample_fps)
    • 否则,只返回视频张量。
  6. 处理帧图像列表:

    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
            for video_element in ele["video"]
        ]
        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        if return_video_sample_fps:
            return images, process_info.pop("fps", 2.0)
        return images
    
    • 如果 ele["video"] 是一个图像列表,遍历每一帧图像,调用 fetch_image 处理。
    • 确保总帧数是 FRAME_FACTOR 的倍数,不足的话用最后一帧填充。
    • 根据是否需要返回采样帧率,返回结果。

decordtorchvision

1. Decord

Decord 是一个专为深度学习和视频处理设计的高性能视频读取库。它旨在提供高效、简洁、易用的视频数据加载接口,方便在深度学习模型中使用视频数据。

主要特点:
  • 高性能: Decord 使用多线程和高效的解码技术,能够快速读取和解码视频数据,大大提高了视频数据处理的效率。
  • 易于集成: 提供了与主流深度学习框架(如 PyTorch、MXNet 等)兼容的接口,可以直接将视频数据转换为框架支持的张量格式。
  • 随机访问: 支持对视频帧的随机访问,方便进行数据增强和批量处理。
  • 轻量级: Decord 旨在提供最小的依赖和轻量级的包装,以减少安装和使用的复杂性。
使用示例:
import decord
from decord import VideoReader
decord.bridge.set_bridge('torch')  # 设置与 PyTorch 兼容的桥接

# 创建视频读取器
vr = VideoReader('path/to/your/video.mp4')

# 获取视频的总帧数
total_frames = len(vr)

# 读取特定帧,例如第10帧
frame_10 = vr[9]  # 索引从0开始

# 批量读取帧
indices = [0, 5, 10, 15, 20]
frames = vr.get_batch(indices)  # 返回指定帧的批量数据

2. Torchvision

Torchvision 是 PyTorch 官方的计算机视觉工具包,提供了常用的数据集、模型和图像视频处理工具。它是 PyTorch 生态系统中处理视觉数据的核心库。

主要组件:
  • torchvision.datasets 提供常用的计算机视觉数据集,如 MNIST、CIFAR10、ImageNet 等的下载和加载接口。
  • torchvision.models 包含预训练的深度学习模型,如 ResNet、AlexNet、VGG 等,可用于迁移学习和特征提取。
  • torchvision.transforms 提供一系列图像预处理和数据增强的方法,如裁剪、缩放、翻转、归一化等。
  • torchvision.io 提供读取和写入图像、视频数据的接口,包括 read_imageread_video 等方法。
使用示例:

图像处理:

from torchvision import transforms
from PIL import Image

# 定义图像转换方法
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # 将图像转换为张量,并将像素值归一化到 [0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                         std=[0.229, 0.224, 0.225])
])

# 加载和处理图像
image = Image.open('path/to/your/image.jpg')
image_tensor = transform(image)

视频处理:

import torchvision.io as io

# 读取视频
video_path = 'path/to/your/video.mp4'
video, audio, info = io.read_video(video_path, pts_unit='sec')

# video 是形状为 [T, H, W, C] 的张量,T 是帧数
# 可以进行帧采样或其他处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值