tqdm.cpp 功能拓展
之前我们分析了tqdm.cpp
的实现,但是这个库实现的功能并不完善,这里我们对这个库进行功能拓展,实现一个简易版的tqdm
。
首先我们来分析一下python版本的tqdm
的实现,然后结合之前的tqdm.cpp
的实现,来扩展一下tqdm.cpp
的功能,实现一个简易cpp版本的tqdm
的功能。
tqdm python实现
tqdm
的主要功能实现在std.py
中的tqdm
类中,现在我们来分析一下tqdm
类的实现。
在使用tqdm
的时候,我们一般会这样使用:
from tqdm import tqdm
for i in tqdm(range(100)):
pass
根据python的语法,我们可以知道,tqdm
是一个类,tqdm(range(100))
返回一个可迭代对象,我们可以使用for
循环来遍历这个可迭代对象,也可以使用next
函数来获取这个可迭代对象的下一个元素。我们来看一下tqdm
类的关于可迭代对象的实现:__iter__
class tqdm(Comparable):
... # 省略部分代码
def __iter__(self):
iterable = self.iterable
if self.disable:
for obj in iterable:
yield obj
return
... # 省略部分代码
try:
for obj in iterable:
yield obj
n += 1
if n - last_print_n >= self.miniters:
cur_t = time()
dt = cur_t - last_print_t
if dt >= mininterval and cur_t >= min_start_t:
self.update(n - last_print_n)
last_print_n = self.last_print_n
last_print_t = self.last_print_t
finally:
self.n = n
self.close()
从__iter__
实现中我们可以看到,这个方法通过yield
关键字返回一个可迭代对象,这个可迭代对象的每一个元素都是iterable
的元素,同时,每次迭代的时候,判断迭代次数是否大于miniters
,迭代时间是否大于mininterval
,如果满足条件,则调用update
方法,那update
方法应该有关于进度条的实现,我们来分析一下update
方法实现。
update函数实现
class tqdm(Comparable):
def update(self, n=1):
... # 省略部分代码,条件判断
if self.n - self.last_print_n >= self.miniters:
cur_t = self._time()
dt = cur_t - self.last_print_t
if dt >= self.mininterval and cur_t >= self.start_t + self.delay:
... # 省略部分代码 ,smoothing 判断
self.refresh(lock_args=self.lock_args)
... # 省略部分代码, dynamic_miniters 判断
self.last_print_n = self.n
self.last_print_t = cur_t
return True
update
方法中同样是对于迭代次数和迭代时间的判断,如果满足条件,则调用refresh
方法,并且更新last_print_n
和last_print_t
。我们来分析一下refresh
方法实现。
refresh函数实现
class tqdm(Comparable):
def refresh(self, nolock=False, lock_args=None):
if self.disable:
return
if not nolock:
if lock_args:
if not self._lock.acquire(*lock_args):
return False
else:
self._lock.acquire()
self.display()
if not nolock:
self._lock.release()
return True
refresh
方法中,代码比较简单,主要是对于self._lock
的获取和释放,以及调用display
方法。我们来分析一下display
方法实现。
display函数实现
class tqdm(Comparable):
def display(self, msg=None, pos=None):
... # 省略部分代码, 主要是对于msg和pos的处理
if pos:
self.moveto(pos)
self.sp(self.__str__() if msg is None else msg)
if pos:
self.moveto(-pos)
return True
display
方法中,主要是对于msg
和pos
的处理,然后调用sp
方法,sp
的参数是__str__
方法的返回值或msg
的值。此外,在display
方法中,还调用了moveto
方法,这个方法的作用是移动光标,这个方法一般用作多线程的时候,用来控制多个进度条的位置,这里我们不做分析。
我们先来看一下sp
方法的实现。
sp函数实现
if not gui:
self.sp = self.status_printer(self.fp)
sp
函数中,主要是对于gui
的判断,如果gui
为False
,则调用status_printer
方法,这个方法的作用是初始化屏幕打印器,我们来分析一下status_printer
方法的实现。
class tqdm(Comparable):
def status_printer(file):
fp = file
fp_flush = getattr(fp, 'flush', lambda: None) # pragma: no cover
if fp in (sys.stderr, sys.stdout):
getattr(sys.stderr, 'flush', lambda: None)()
getattr(sys.stdout, 'flush', lambda: None)()
def fp_write(s):
fp.write(str(s))
fp_flush()
last_len = [0]
def print_status(s):
len_s = disp_len(s)
fp_write('\r' + s + (' ' * max(last_len[0] - len_s, 0)))
last_len[0] = len_s
return print_status
status_printer
方法中,主要是对于fp
的判断,如果fp
是sys.stderr
或sys.stdout
,则调用flush
方法,然后定义了fp_write
方法,这个方法的作用是将字符串写入到fp
中,然后定义了print_status
方法,这个方法的作用是打印字符串。
这里涉及到python的一个特性,就是函数的嵌套,我们可以在一个函数中定义另一个函数,这个函数可以访问外部函数的变量,这里的fp
和last_len
就是外部函数的变量,print_status
函数可以访问这两个变量。这个和CPP中的lambda表达式类似,有兴趣的同学可以自行了解。
self.sp(self.__str__() if msg is None else msg)
这一句代码等价于
print_status(self.__str__() if msg is None else msg)
到这里,我们已经分析完了display
方法的实现,现在来分析一下__str__
方法的实现。
str 实现
__str__
是python中的一个magic method,当我们使用print
函数打印一个对象的时候,会调用这个方法,这个方法的返回值就是打印的字符串。我们来看一下tqdm
类中__str__
方法的实现。
class tqdm(Comparable):
def __str__(self):
return self.format_meter(**self.format_dict)
__str__
方法中只有一行代码,调用了format_meter
方法,这个方法的作用是格式化字符串,我们来分析一下format_meter
方法的实现。
class tqdm(Comparable):
def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it',unit_scale=False, rate=None, bar_format=None, postfix=None,
unit_divisor=1000, initial=0, colour=None, **extra_kwargs):
if total and n >= (total + 0.5):
total = None
if unit_scale and unit_scale not in (True, 1):
if total:
total *= unit_scale
n *= unit_scale
if rate:
rate *= unit_scale
unit_scale = False
elapsed_str = tqdm.format_interval(elapsed)
if rate is None and elapsed:
rate = (n - initial) / elapsed
inv_rate = 1 / rate if rate else None
format_sizeof = tqdm.format_sizeof
rate_noinv_fmt = ((format_sizeof(rate) if unit_scale else
'{0:5.2f}'.format(rate)) if rate else '?') + unit + '/s'
rate_inv_fmt = (
(format_sizeof(inv_rate) if unit_scale else '{0:5.2f}'.format(inv_rate))
if inv_rate else '?') + 's/' + unit
rate_fmt = rate_inv_fmt if inv_rate and inv_rate > 1 else rate_noinv_fmt
if unit_scale:
n_fmt = format_sizeof(n, divisor=unit_divisor)
total_fmt = format_sizeof(total, divisor=unit_divisor) if total is not None else '?'
else:
n_fmt = str(n)
total_fmt = str(total) if total is not None else '?'
try:
postfix = ', ' + postfix if postfix else ''
except TypeError:
pass
remaining = (total - n) / rate if rate and total else 0
remaining_str = tqdm.format_interval(remaining) if rate else '?'
try:
eta_dt = (datetime.now() + timedelta(seconds=remaining)
if rate and total else datetime.utcfromtimestamp(0))
except OverflowError:
eta_dt = datetime.max
if prefix:
bool_prefix_colon_already = (prefix[-2:] == ": ")
l_bar = prefix if bool_prefix_colon_already else prefix + ": "
else:
l_bar = ''
r_bar = f'| {n_fmt}/{total_fmt} [{elapsed_str}<{remaining_str}, {rate_fmt}{postfix}]'
format_dict = {
'n': n, 'n_fmt': n_fmt, 'total': total, 'total_fmt': total_fmt,
'elapsed': elapsed_str, 'elapsed_s': elapsed,
'ncols': ncols, 'desc': prefix or '', 'unit': unit,
'rate': inv_rate if inv_rate and inv_rate > 1 else rate,
'rate_fmt': rate_fmt, 'rate_noinv': rate,
'rate_noinv_fmt': rate_noinv_fmt, 'rate_inv': inv_rate,
'rate_inv_fmt': rate_inv_fmt,
'postfix': postfix, 'unit_divisor': unit_divisor,
'colour': colour,
'remaining': remaining_str, 'remaining_s': remaining,
'l_bar': l_bar, 'r_bar': r_bar, 'eta': eta_dt,
**extra_kwargs}
if total:
frac = n / total
percentage = frac * 100
l_bar += '{0:3.0f}%|'.format(percentage)
if ncols == 0:
return l_bar[:-1] + r_bar[1:]
format_dict.update(l_bar=l_bar)
if bar_format:
format_dict.update(percentage=percentage)
if not prefix:
bar_format = bar_format.replace("{desc}: ", '')
else:
bar_format = "{l_bar}{bar}{r_bar}"
full_bar = FormatReplace()
nobar = bar_format.format(bar=full_bar, **format_dict)
if not full_bar.format_called:
return nobar
full_bar = Bar(frac,
max(1, ncols - disp_len(nobar)) if ncols else 10,
charset=Bar.ASCII if ascii is True else ascii or Bar.UTF,
colour=colour)
if not _is_ascii(full_bar.charset) and _is_ascii(bar_format):
bar_format = str(bar_format)
res = bar_format.format(bar=full_bar, **format_dict)
return disp_trim(res, ncols) if ncols else res
elif bar_format:
l_bar += '|'
format_dict.update(l_bar=l_bar, percentage=0)
full_bar = FormatReplace()
nobar = bar_format.format(bar=full_bar, **format_dict)
if not full_bar.format_called:
return nobar
full_bar = Bar(0,
max(1, ncols - disp_len(nobar)) if ncols else 10,
charset=Bar.BLANK, colour=colour)
res = bar_format.format(bar=full_bar, **format_dict)
return disp_trim(res, ncols) if ncols else res
else:
return (f'{(prefix + ": ") if prefix else ""}'
f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]')
format_meter
方法中的代码比较多,但是逻辑比较清晰,基本上都是对参数的格式化,不再一一赘述,我们来看一下这些参数的含义:
参数名 | 含义 |
---|---|
n | 当前迭代次数 |
total | 总迭代次数 |
elapsed | 已经迭代的时间 |
ncols | 进度条的宽度 |
prefix | 进度条前缀 |
ascii | 进度条的字符集 |
unit | 进度条的单位 |
unit_scale | 进度条的单位是否缩放 |
rate | 进度条的速度 |
bar_format | 进度条的格式 |
postfix | 进度条后缀 |
unit_divisor | 进度条单位的除数 |
initial | 进度条的初始值 |
colour | 进度条的颜色 |
extra_kwargs | 其他参数 |
format_meter
的作用就是根据这些参数,格式化字符串,然后返回这个字符串。
代码后半部分,对于total
和bar_format
的判断,是对于进度条的处理。
if total:
... # 省略部分代码
full_bar = Bar(frac,
max(1, ncols - disp_len(nobar)) if ncols else 10,
charset=Bar.ASCII if ascii is True else ascii or Bar.UTF,
colour=colour)
if not _is_ascii(full_bar.charset) and _is_ascii(bar_format):
bar_format = str(bar_format)
res = bar_format.format(bar=full_bar, **format_dict)
return disp_trim(res, ncols) if ncols else res
elif bar_format:
... # 省略部分代码
full_bar = Bar(0,
max(1, ncols - disp_len(nobar)) if ncols else 10,
charset=Bar.BLANK, colour=colour)
res = bar_format.format(bar=full_bar, **format_dict)
return disp_trim(res, ncols) if ncols else res
else:
return (f'{(prefix + ": ") if prefix else ""}'
f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]')
这里的Bar
类是一个生成进度条的类,我们来分析一下这个类的实现。
Bar类实现
class Bar(object):
ASCII = " 123456789#"
UTF = u" " + u''.join(map(chr, range(0x258F, 0x2587, -1)))
BLANK = " "
COLOUR_RESET = '\x1b[0m'
COLOUR_RGB = '\x1b[38;2;%d;%d;%dm'
COLOURS = {'BLACK': '\x1b[30m', 'RED': '\x1b[31m', 'GREEN': '\x1b[32m',
'YELLOW': '\x1b[33m', 'BLUE': '\x1b[34m', 'MAGENTA': '\x1b[35m',
'CYAN': '\x1b[36m', 'WHITE': '\x1b[37m'}
def __init__(self, frac, default_len=10, charset=UTF, colour=None):
if not 0 <= frac <= 1:
warn("clamping frac to range [0, 1]", TqdmWarning, stacklevel=2)
frac = max(0, min(1, frac))
assert default_len > 0
self.frac = frac
self.default_len = default_len
self.charset = charset
self.colour = colour
@property
def colour(self):
return self._colour
@colour.setter
def colour(self, value):
if not value:
self._colour = None
return
try:
if value.upper() in self.COLOURS:
self._colour = self.COLOURS[value.upper()]
elif value[0] == '#' and len(value) == 7:
self._colour = self.COLOUR_RGB % tuple(
int(i, 16) for i in (value[1:3], value[3:5], value[5:7]))
else:
raise KeyError
except (KeyError, AttributeError):
warn("Unknown colour (%s); valid choices: [hex (#00ff00), %s]" % (
value, ", ".join(self.COLOURS)),
TqdmWarning, stacklevel=2)
self._colour = None
def __format__(self, format_spec):
if format_spec:
_type = format_spec[-1].lower()
try:
charset = {'a': self.ASCII, 'u': self.UTF, 'b': self.BLANK}[_type]
except KeyError:
charset = self.charset
else:
format_spec = format_spec[:-1]
if format_spec:
N_BARS = int(format_spec)
if N_BARS < 0:
N_BARS += self.default_len
else:
N_BARS = self.default_len
else:
charset = self.charset
N_BARS = self.default_len
nsyms = len(charset) - 1
bar_length, frac_bar_length = divmod(int(self.frac * N_BARS * nsyms), nsyms)
res = charset[-1] * bar_length
if bar_length < N_BARS: # whitespace padding
res = res + charset[frac_bar_length] + charset[0] * (N_BARS - bar_length - 1)
return self.colour + res + self.COLOUR_RESET if self.colour else res
Bar
类中,定义了一些常量,然后定义了__init__
方法,这个方法的作用是初始化进度条,然后定义了colour
属性,这个属性的作用是设置进度条的颜色,然后定义了__format__
方法,这个方法的作用是格式化进度条,然后返回这个进度条的字符串。
到这里,我们已经分析完了tqdm
类的实现,我们来总结一下tqdm
的调用过程。
__iter__ -> update -> refresh -> display -> print_status
__str__ -> format_meter -> Bar
print_status(self.__str__() if msg is None else msg)
tqdm cpp实现
分析完了python版本的tqdm
,我们结合之前的tqdm.cpp
的实现,来扩展一下tqdm.cpp
的功能,实现一个简易cpp版本的tqdm
的功能。
在python的实现中,tqdm类中有很多属性用于记录迭代的状态,这里我们只记录上次打印的时间,上次打印的迭代次数,当前迭代次数,总迭代次数,然后定义一个process
结构体,用于记录这些状态。
struct process
{
using time_point = std::chrono::_V2::system_clock::time_point;
time_point start;
time_point last_print_t; // 上次打印的时间
size_t last_print_n = 0; // 上次打印的迭代次数
size_t iterations = 0; // 当前迭代次数
size_t total = 0; // 总迭代次数
double rate = 0; // 速率
size_t last_len; // 上次打印的长度
};
为了更灵活的控制进度条显示,我们需要修改Tqdm
类的定义和构造函数,为构造函数添加一个默认参数Params()
,这样就可以实现更灵活的控制进度条的显示。
template <typename _Iterator>
class Tqdm : public IteratorWrapper<_Iterator>
{
private:
using TQDM_IT = IteratorWrapper<_Iterator>;
_Iterator m_endIt;
mutable process m_process;
Params m_params;
public:
explicit Tqdm(_Iterator begin, _Iterator end, const Params &s = Params())
: TQDM_IT(begin), m_endIt(end), m_params(s), m_process()
{
m_process.start = std::chrono::system_clock::now();
m_process.last_print_t = m_process.start;
m_process.total = std::distance(begin, end);
}
... // 省略部分代码
};
其他的类和模板函数同样添加Params()
,这里不再赘述。
上一次我们分析了Tqdm
的打印功能是在_incr()
函数中实现的,我们只需要修改这个函数,就可以实现进度条的打印功能。
为了更灵活的拓展进度条的功能,我们可以将打印功能抽离出来,用一个单例类或者全局对象来实现,这里我们使用单例类来实现,_incr()
只需要通过单例类来打印进度条即可。
virtual void _incr() const override
{
TQDM_IT::_incr();
printProgress::getInstance()(m_process, m_params);
}
根据上文分析的python代码,我们需要在这个单例类中实现参数的格式化以及进度条的显示。
printProgress
类的定义如下:
class printProgress
{
private:
std::function<void(process &p, const Params &s)> m_funcPrintProgress;
printProgress()
: m_funcPrintProgress(defaultPrintProgress)
{
}
printProgress(const printProgress &) = delete;
printProgress &operator=(const printProgress &) = delete;
printProgress(printProgress &&) = delete;
printProgress &operator=(printProgress &&) = delete;
~printProgress() = default;
public:
static printProgress &getInstance()
{
static printProgress instance;
return instance;
}
public:
void setPrintProgress(std::function<void(process &p, const Params &s)> printProgress)
{
m_funcPrintProgress = printProgress;
}
void resetPrintProgress()
{
m_funcPrintProgress = defaultPrintProgress;
}
void operator()(process &p, const Params &s)
{
m_funcPrintProgress(p, s);
}
static void defaultPrintProgress(process &p, const Params &s);
static std::string formatMeter(process p, Params s);
static void display(process &p, const Params &s, int pos = 0);
};
printProgress
类中包含一个std::function
对象,我们可以通过这个对象来设置打印进度条的函数,这样就可以实现更灵活的控制进度条的显示。
默认的打印进度条的函数是defaultPrintProgress
,这个函数的实现如下:
static void defaultPrintProgress(process &p, const Params &s)
{
if (s.disable)
return;
p.iterations++;
if (p.iterations - p.last_print_n >= s.miniters)
{
auto now = std::chrono::system_clock::now();
auto dt = std::chrono::duration_cast<std::chrono::milliseconds>(now - p.last_print_t).count();
if (dt >= s.mininterval)
{
if (p.iterations - p.last_print_n >= s.miniters)
{
auto dn = p.iterations - p.last_print_n;
p.rate = dn / (dt / 1000.0);
display(p, s);
}
p.last_print_n = p.iterations;
p.last_print_t = now;
}
}
}
defaultPrintProgress
的实现是在python版本中的__iter__ -> update -> refresh
三个方法的实现的基础上的简化。
在这个函数中,我们只需要判断是否满足打印条件,如果满足,则调用display
方法,这个方法的实现如下:
static void display(process &p, const Params &s, int pos = 0)
{
std::string msg = formatMeter(p, s);
std::cout<<"\r"<<msg;
std::cout.flush();
p.last_len = msg.length();
}
display
方法的代码很少,主要功能在于调用formatMeter
格式化参数,并将结果显示出来。
formatMeter
方法的实现如下:
static std::string formatMeter(const process &p, const Params &s)
{
size_t total = p.total;
double rate = p.rate;
size_t n = p.iterations;
int unit_scale = s.unit_scale;
if (unit_scale > 0)
{
if (total > 0)
{
total *= unit_scale;
}
n *= unit_scale;
if (abs(rate) > 1e-5)
{
rate *= unit_scale;
}
unit_scale = 0;
}
std::string rate_fmt;
if (abs(rate) > 1e-5)
{
rate_fmt = format_sizeof(rate, s.unit) + "/s";
}
else
{
rate_fmt = format_sizeof(0, s.unit) + "/s";
}
double frac = (double)n / total;
double percentage = frac * 100;
std::string l_bar = fmt::vformat("{} {:.1f}% [", fmt::make_format_args(s.desc, percentage));
int fill_width = std::max(0, int(s.ncols * frac) - 1);
std::string fill_str = std::string(fill_width, s.ascii[s.ascii.length() - 1]);
fill_str += s.ascii[n % 11];
if (fill_str.length() < s.ncols)
{
fill_str += std::string(s.ncols - fill_str.length(), ' ');
}
else
{
fill_str[fill_str.length() - 1] = s.ascii[s.ascii.length() - 1];
}
size_t remaining = (total - n) / rate;
size_t used_time = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - p.start).count();
std::string remaining_str = formatTime(remaining);
std::string used_time_str = formatTime(used_time);
std::string bar = fmt::vformat("{} ]", fmt::make_format_args(fill_str));
std::string r_bar = fmt::vformat("{}/{} [{}/{}] {}", fmt::make_format_args(n, total, used_time_str, remaining_str, rate_fmt));
std::string msg = fmt::vformat("{}{}{}", fmt::make_format_args(l_bar, bar, r_bar));
if (n == total)
{
msg += !s.leave ? "\n" : "\r";
}
if (p.last_len > msg.length())
{
msg += std::string(p.last_len - msg.length(), ' ');
}
return msg;
}
formatMeter
方法的实现是在python版本中的format_meter
方法的实现的基础上的简化。
这里我们使用了fmt
库来格式化字符串,这个库的是C++20中的std::format
的实现,这个库的使用方法可以参考fmtlib。
formatMeter
方法只处理了部分参数,实现的功能比较简单,我们在后续的版本中会对这个方法进行优化。
主要逻辑如下:
首先判断是否需要单位缩放,如果需要,则对参数进行缩放。
然后计算进度条的百分比,时间等参数。
格式为l_bar + bar + r_bar
,
l_bar
是进度条的前缀+进度条的百分比+[
bar
是进度条
r_bar
是当前迭代次数+总迭代次数+已用时间+剩余时间+速率
最后将这些参数拼接起来,就是最终的结果。
以上就是一个简易版的tqdm
的实现,我们可以通过这个类来实现进度条的显示,这个类的使用方法如下:
int main()
{
std::vector<int> v(3000,1);
for(auto i:tqdm::tqdm(v))
{
std::this_thread::sleep_for(std::chrono::milliseconds(10));
i++;
}
}
运行结果如下:
12.7% [#6 ]38000/300000 [0:3/0:26] 10.0Kit/s
之后我们会对这个类进行优化,实现更多的功能。