几种优雅地使用tqdm进度条和joblib的方法

tqdm 是一个非常有用的 Python 库,用于在迭代过程中显示进度条。它的使用非常简单,可以在长时间运行的循环中提供实时的进度反馈。以下是一些常见的使用场景和示例代码:

基本使用

from tqdm import tqdm
import time

# Example: wrap the iterable of a plain for loop in tqdm().
progress_iter = tqdm(range(100))
for i in progress_iter:
    time.sleep(0.1)  # stand-in for a slow operation

在这个示例中,tqdm 将在标准输出上显示一个进度条,指示循环的进度。

与 list 结合使用

# Same pattern, this time with a concrete list as the iterable.
my_list = [*range(100)]

list_progress = tqdm(my_list)
for item in list_progress:
    time.sleep(0.1)  # stand-in for per-item work

joblib 则是一个用于在 Python 中进行并行计算和磁盘缓存的库。它的主要功能包括并行执行任务、序列化 Python 对象以及缓存计算结果。

并行计算

joblib 的 Parallel 和 delayed 函数可以方便地进行并行计算。例如,假设我们有一个需要并行处理的函数:

from joblib import Parallel, delayed
import time

def process_item(item):
    """Return ``item * 2`` after a short artificial delay.

    The 0.1 s sleep stands in for an expensive computation so that the
    difference between serial and parallel execution is noticeable.
    """
    time.sleep(0.1)  # simulated workload
    doubled = item * 2
    return doubled

# Serial baseline: process the items one after another.
results = list(map(process_item, range(10)))

# Parallel version: the same work spread over 4 joblib workers.
worker_pool = Parallel(n_jobs=4)
results = worker_pool(delayed(process_item)(i) for i in range(10))

print(results)

在joblib 并行时使用tqdm

在joblib 并行时使用tqdm,可以替代joblib 原本自带的verbose输出,给人一个比较直观的进度条。一些文章讨论了这一方面的实现,例如《Tracking progress of joblib.Parallel execution》和《How can we use tqdm in a parallel execution with joblib?》。我将这些整理如下:

1.最简单的执行

将tqdm直接嵌套入delayed 的生成器表达式中(即用tqdm包裹输入的迭代器),缺点是tqdm记录的是任务开始(被取出分发)的状态,而不是任务完成的状态

from math import sqrt
from joblib import Parallel, delayed  
from tqdm import tqdm  
# tqdm wraps the *input* range here, so the bar advances as items are pulled
# for dispatch rather than when the corresponding tasks finish.
tracked_inputs = tqdm(range(100000))
result = Parallel(n_jobs=2)(delayed(sqrt)(i ** 2) for i in tracked_inputs)

2.包装 Parallel 的代码

包装一个新的 ParallelTqdm 类用以取代 Parallel,从而实现嵌入tqdm。注意:在 n_jobs=1 时可能会报错

from joblib import Parallel, delayed
import tqdm
class ParallelTqdm(Parallel):
    """joblib.Parallel, but with a tqdm progressbar

    Additional parameters:
    ----------------------
    total_tasks: int, default: None
        the number of expected jobs. Used in the tqdm progressbar.
        If None, try to infer from the length of the called iterator, and
        fallback to use the number of remaining items as soon as we finish
        dispatching.
        Note: use a list instead of an iterator if you want the total_tasks
        to be inferred from its length.

    desc: str, default: None
        the description used in the tqdm progressbar.

    disable_progressbar: bool, default: False
        If True, a tqdm progressbar is not used.

    show_joblib_header: bool, default: False
        If True, show joblib header before the progressbar.

    Removed parameters:
    -------------------
    verbose: not supported; passing it raises ValueError. Use
        disable_progressbar and show_joblib_header instead.


    Usage:
    ------
    >>> from joblib import delayed
    >>> from time import sleep
    >>> ParallelTqdm(n_jobs=-1)([delayed(sleep)(.1) for _ in range(10)])
    80%|████████  | 8/10 [00:02<00:00,  3.12tasks/s]

    """

    def __init__(
        self,
        *,
        # NOTE: the ``X | None`` annotations below need Python >= 3.10. On
        # older interpreters they raise TypeError at definition time; drop
        # them or use ``typing.Optional[X]`` instead.
        total_tasks: int | None = None,
        desc: str | None = None,
        disable_progressbar: bool = False,
        show_joblib_header: bool = False,
        **kwargs
    ):
        if "verbose" in kwargs:
            # joblib's verbose flag is driven by show_joblib_header below, so
            # accepting it here would be ambiguous.
            raise ValueError(
                "verbose is not supported. "
                "Use disable_progressbar and show_joblib_header instead."
            )
        super().__init__(verbose=(1 if show_joblib_header else 0), **kwargs)
        self.total_tasks = total_tasks
        self.desc = desc
        self.disable_progressbar = disable_progressbar
        self.progress_bar: tqdm.tqdm | None = None
        # True when total_tasks was inferred during a call rather than
        # supplied by the user; lets __call__ discard it on the next call.
        self._tqdm_total_inferred = False

    def __call__(self, iterable):
        # Start every call from a clean slate so the instance can be reused:
        # drop the (already closed) bar of a previous call and forget a total
        # that was only inferred for that call.
        self.progress_bar = None
        if self._tqdm_total_inferred:
            self.total_tasks = None
            self._tqdm_total_inferred = False
        try:
            if self.total_tasks is None:
                # try to infer total_tasks from the length of the called iterator
                try:
                    self.total_tasks = len(iterable)
                except (TypeError, AttributeError):
                    # no usable len() (e.g. a generator): print_progress
                    # fills the total in once dispatching has finished
                    pass
                else:
                    self._tqdm_total_inferred = True
            # call parent function
            return super().__call__(iterable)
        finally:
            # close tqdm progress bar
            if self.progress_bar is not None:
                self.progress_bar.close()

    __call__.__doc__ = Parallel.__call__.__doc__

    def _ensure_progress_bar(self):
        """Create the tqdm bar on first use and return it."""
        if self.progress_bar is None:
            self.progress_bar = tqdm.tqdm(
                desc=self.desc,
                total=self.total_tasks,
                disable=self.disable_progressbar,
                unit="tasks",
            )
        return self.progress_bar

    def dispatch_one_batch(self, iterator):
        # start progress_bar, if not started yet.
        self._ensure_progress_bar()
        # call parent function
        return super().dispatch_one_batch(iterator)

    dispatch_one_batch.__doc__ = Parallel.dispatch_one_batch.__doc__

    def print_progress(self):
        """Display the process of the parallel execution using tqdm"""
        # print_progress can be reached without dispatch_one_batch having run
        # first (reported with n_jobs=1), so never assume the bar exists.
        progress_bar = self._ensure_progress_bar()
        # if we finish dispatching, find total_tasks from the number of remaining items
        if self.total_tasks is None and self._original_iterator is None:
            self.total_tasks = self.n_dispatched_tasks
            self._tqdm_total_inferred = True
            progress_bar.total = self.total_tasks
            progress_bar.refresh()
        # update progressbar by the number of tasks completed since last time
        progress_bar.update(self.n_completed_tasks - progress_bar.n)

3.上下文管理器

包装出上下文管理器,代码比较简洁,但当 n_jobs=1 时不管用

import contextlib

import joblib
from tqdm.autonotebook import tqdm


@contextlib.contextmanager
def tqdm_joblib(*args, **kwargs):
    """Temporarily make joblib report batch completions to a tqdm bar.

    All positional and keyword arguments are forwarded to ``tqdm`` to build
    the bar, which is yielded to the ``with`` body. While the context is
    active, ``joblib.parallel.BatchCompletionCallBack`` is swapped for a
    subclass that advances the bar by ``batch_size`` on every invocation.
    On exit the original class is put back and the bar is closed.
    """
    progress_bar = tqdm(*args, **kwargs)
    original_callback_cls = joblib.parallel.BatchCompletionCallBack

    class TqdmBatchCompletionCallback(original_callback_cls):
        def __call__(self, *call_args, **call_kwargs):
            # Advance the bar first, then let joblib do its own bookkeeping.
            progress_bar.update(n=self.batch_size)
            return super().__call__(*call_args, **call_kwargs)

    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield progress_bar
    finally:
        # Always undo the patch, even if the body raised.
        joblib.parallel.BatchCompletionCallBack = original_callback_cls
        progress_bar.close()



def ParallelPbar(desc=None, **tqdm_kwargs):
    """Return a ``joblib.Parallel`` subclass whose calls show a tqdm bar.

    ``desc`` and any extra ``tqdm_kwargs`` are forwarded to ``tqdm_joblib``
    on every call of the returned class's instances.
    """

    class Parallel(joblib.Parallel):
        def __call__(self, it):
            # Materialize the tasks so their count can be used as the total.
            tasks = list(it)
            with tqdm_joblib(total=len(tasks), desc=desc, **tqdm_kwargs):
                outcome = super().__call__(tasks)
            return outcome

    return Parallel


#用法

from math import sqrt
from joblib import Parallel, delayed

# Option 1: wrap an ordinary joblib.Parallel call in the context manager.
with tqdm_joblib(total=10, desc="My calculation") as progress_bar:
    runner = Parallel(n_jobs=16)
    runner(delayed(sqrt)(i**2) for i in range(10))

# or

# Option 2: build a Parallel class that already carries the progress bar.
ProgressParallel = ParallelPbar("My calculation")
ProgressParallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值