引子
在开发fastnlp过程遇到一个需求:嵌套使用tdqm进度条,能够在任意位置使用print而不发生冲突。为解决print打印内容乱序的问题,我们将内置的write函数置换为tqdm.write函数。
import sys
import time
import contextlib
from tqdm import tqdm
class DummyFile:
def __init__(self, file):
if file is None:
file = sys.stderr
self.file = file
def write(self, x):
if len(x.rstrip()) > 0:
tqdm.write(str(x), file=self.file)
@contextlib.contextmanager
def redirect_stdout(file=None):
if file is None:
file = sys.stderr
old_stdout = file
sys.stdout.write = DummyFile(file).write
yield
sys.stdout.write = old_stdout.write
with redirect_stdout():
with tqdm(total=10, desc="epoch tqdm") as epoch_tqdm:
for ep in range(10):
with tqdm(total=100, desc="batch tqdm", leave=True) as batch_tqdm:
for batch in range(100):
print(batch)
batch_tqdm.update()
time.sleep(0.001)
epoch_tqdm.update()
而实际运行出来的结果见下图,batch tqdm被取消了而epoch tqdm没有被取消且多出来一个空白行。其出现问题的原因见下面的debug视频。
tqdm debug视频
从视频中可以推理出进度条取消的原理:只要在写时不使用\n的换行符,那么输出光标就在当前行,写入\r可以跳到当前行的首字符处,然后计算出当前行已经写的字符个数,写入相同个数的空字符和\r就可以覆盖之前的内容并跳到当前行首字符处。这里不妨自己写代码测试一下:
import sys
sys.stdout.write("12345")
sys.stdout.flush()
sys.stdout.write("\r"+" "*5+"\r")
sys.stdout.flush()
sys.stdout.write("hello")
sys.stdout.flush()
从图可以知道最终在终端输出内容为hello;这是因为输入了"\r"+" “*5+”\r",将首先输入的内容覆盖了。这也是tqdm能够取消当前行的原理。
rich.progress介绍
由于tqdm存在问题且要解决有点麻烦,在他人的推荐下发现一款新的进度条展示的库rich,其功能强大且完全满足项目的需求。故在最新版的fastnlp会采用rich代替tqdm。
按照官方介绍:Rich is a Python library for rich text and beautiful formatting in the terminal.
其中rich的progress模块提供了类比tqdm的进度条展示方法且更加强大,如图所示是官方给出的图片展示
对于引子提到的问题,使用rich的一种实现方法为:
from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TimeRemainingColumn
import time
with Progress(TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
TimeElapsedColumn()) as progress:
epoch_tqdm = progress.add_task(description="epoch progress", total=10)
batch_tqdm = progress.add_task(description="batch progress", total=100)
for ep in range(10):
for batch in range(100):
print("ep: {} batch: {}".format(ep, batch))
progress.advance(batch_tqdm, advance=1)
time.sleep(0.1)
progress.advance(epoch_tqdm, advance=1)
progress.reset(batch_tqdm)
具体而言,可以将rich的Progress当成tqdm,rich的track当成tnrange;不同的是Progress是一个调度器,管理着不同的task,每个task是一个进度条任务,task之间可以视作是不同的线程,线程间使用RLOCK锁互斥使用终端资源。
rich.progress功能模块
从下图中可以看出进度条组件模块包括:RenderableColumn,SpinnerColumn,TextColumn,BarColumn,TimeElapsedColumn,TimeRemainingColumn,FileSizeColumn,TotalFileSizeColumn,TransferSpeedColumn,DownloadColumn。这些模块都是组成进度条的一部分,基本满足了所有需求,如果有特别需求还可以自己定制。其使用方法也很简单,只需要在Progress中传入具体模块就行,见下面代码和图:
progress = Progress(TextColumn("[progress.description]{task.description}"),
SpinnerColumn(),
BarColumn(),
FileSizeColumn(),
TotalFileSizeColumn(),
DownloadColumn(),
TransferSpeedColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
TimeElapsedColumn())
见下图,progress中最重要的对象是Task和Progress,其中Task用来存储进度条展示所需的数据如description,completed,finished_time等,最重要的东西是Task拥有一个线程锁RLock,是线程安全的。Progress对象是一个Task的调度器,保存了进度条的组件,task池,线程锁RLock,还有Live实例,通过Live来控制进度条输出到终端的内容以及重定向问题,使得任意位置使用print而不出现乱序问题。具体细节可以自行浏览源码。
还有一个特别提醒,在使用progress时如果不用with语法,一定要记得加上progress.start(),不然进度条不会展示出来。这是因为进度条展示需要live的配合,而live的启动progress是使用了__enter__魔法配置了,若不用with语法则需要手动调用start方法,见下图。
progress = Progress(TextColumn("[progress.description]{task.description}"),
SpinnerColumn(),
BarColumn(),
FileSizeColumn(),
TotalFileSizeColumn(),
DownloadColumn(),
TransferSpeedColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
TimeElapsedColumn())
epoch_tqdm = progress.add_task(description="epoch progress", total=10)
batch_tqdm = progress.add_task(description="batch progress", total=100)
progress.start() ## 开启
for ep in range(10):
for batch in range(100):
print("ep: {} batch: {}".format(ep, batch))
progress.advance(batch_tqdm, advance=1)
time.sleep(0.1)
progress.advance(epoch_tqdm, advance=1)
progress.reset(batch_tqdm)