C++ 新手进阶之tqdm4cpp拓展

tqdm.cpp 功能拓展

之前我们分析了tqdm.cpp的实现,但是这个库实现的功能并不完善,这里我们对这个库进行功能拓展,实现一个简易版的tqdm
首先我们来分析一下python版本的tqdm的实现,然后结合之前的tqdm.cpp的实现,来扩展一下tqdm.cpp的功能,实现一个简易cpp版本的tqdm的功能。

tqdm python实现

tqdm的主要功能实现在std.py 中的tqdm类中,现在我们来分析一下tqdm类的实现。
在使用tqdm的时候,我们一般会这样使用:

from tqdm import tqdm
for i in tqdm(range(100)):
    pass

根据python的语法,我们可以知道,tqdm是一个类,tqdm(range(100))返回一个可迭代对象,我们可以使用for循环来遍历这个可迭代对象,也可以使用next函数来获取这个可迭代对象的下一个元素。我们来看一下tqdm类的关于可迭代对象的实现:__iter__

class tqdm(Comparable):
    ... # 省略部分代码

    def __iter__(self):
       
        iterable = self.iterable

        if self.disable:
            for obj in iterable:
                yield obj
            return

        ... # 省略部分代码

        try:
            for obj in iterable:
                yield obj
                n += 1

                if n - last_print_n >= self.miniters:
                    cur_t = time()
                    dt = cur_t - last_print_t
                    if dt >= mininterval and cur_t >= min_start_t:
                        self.update(n - last_print_n)
                        last_print_n = self.last_print_n
                        last_print_t = self.last_print_t
        finally:
            self.n = n
            self.close()

__iter__实现中我们可以看到,这个方法通过yield关键字返回一个可迭代对象,这个可迭代对象的每一个元素都是iterable的元素,同时,每次迭代的时候,判断迭代次数是否大于miniters,迭代时间是否大于mininterval,如果满足条件,则调用update方法,那update方法应该有关于进度条的实现,我们来分析一下update方法实现。

update函数实现

class tqdm(Comparable):
    def update(self, n=1):
        ... # 省略部分代码,条件判断

        if self.n - self.last_print_n >= self.miniters:
            cur_t = self._time()
            dt = cur_t - self.last_print_t
            if dt >= self.mininterval and cur_t >= self.start_t + self.delay:
                ... # 省略部分代码 ,smoothing 判断

                self.refresh(lock_args=self.lock_args)
                
                ... # 省略部分代码, dynamic_miniters 判断

                self.last_print_n = self.n
                self.last_print_t = cur_t
                return True

update方法中同样是对于迭代次数和迭代时间的判断,如果满足条件,则调用refresh方法,并且更新last_print_nlast_print_t。我们来分析一下refresh方法实现。

refresh函数实现

class tqdm(Comparable):
    def refresh(self, nolock=False, lock_args=None):
        if self.disable:
            return

        if not nolock:
            if lock_args:
                if not self._lock.acquire(*lock_args):
                    return False
            else:
                self._lock.acquire()
        self.display()
        if not nolock:
            self._lock.release()
        return True

refresh方法中,代码比较简单,主要是对于self._lock的获取和释放,以及调用display方法。我们来分析一下display方法实现。

display函数实现

class tqdm(Comparable):
    def display(self, msg=None, pos=None):
        ... # 省略部分代码, 主要是对于msg和pos的处理

        if pos:
            self.moveto(pos)
        self.sp(self.__str__() if msg is None else msg)
        if pos:
            self.moveto(-pos)
        return True

display方法中,主要是对于msgpos的处理,然后调用sp方法,sp的参数是__str__方法的返回值或msg的值。此外,在display方法中,还调用了moveto方法,这个方法的作用是移动光标,这个方法一般用作多线程的时候,用来控制多个进度条的位置,这里我们不做分析。
我们先来看一下sp方法的实现。

sp函数实现

    if not gui:
        self.sp = self.status_printer(self.fp)

sp函数中,主要是对于gui的判断,如果guiFalse,则调用status_printer方法,这个方法的作用是初始化屏幕打印器,我们来分析一下status_printer方法的实现。

class tqdm(Comparable):
    def status_printer(file):
        fp = file
        fp_flush = getattr(fp, 'flush', lambda: None)  # pragma: no cover
        if fp in (sys.stderr, sys.stdout):
            getattr(sys.stderr, 'flush', lambda: None)()
            getattr(sys.stdout, 'flush', lambda: None)()

        def fp_write(s):
            fp.write(str(s))
            fp_flush()

        last_len = [0]

        def print_status(s):
            len_s = disp_len(s)
            fp_write('\r' + s + (' ' * max(last_len[0] - len_s, 0)))
            last_len[0] = len_s

        return print_status

status_printer方法中,主要是对于fp的判断,如果fpsys.stderrsys.stdout,则调用flush方法,然后定义了fp_write方法,这个方法的作用是将字符串写入到fp中,然后定义了print_status方法,这个方法的作用是打印字符串。
这里涉及到python的一个特性,就是函数的嵌套,我们可以在一个函数中定义另一个函数,这个函数可以访问外部函数的变量,这里的fplast_len就是外部函数的变量,print_status函数可以访问这两个变量。这个和CPP中的lambda表达式类似,有兴趣的同学可以自行了解。

    self.sp(self.__str__() if msg is None else msg)

这一句代码等价于

    print_status(self.__str__() if msg is None else msg)

到这里,我们已经分析完了display方法的实现,现在来分析一下__str__方法的实现。

str 实现

__str__ 是python中的一个magic method,当我们使用print函数打印一个对象的时候,会调用这个方法,这个方法的返回值就是打印的字符串。我们来看一下tqdm类中__str__方法的实现。

class tqdm(Comparable):
    def __str__(self):
        return self.format_meter(**self.format_dict)

__str__方法中只有一行代码,调用了format_meter方法,这个方法的作用是格式化字符串,我们来分析一下format_meter方法的实现。

class tqdm(Comparable):
    def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it',unit_scale=False, rate=None, bar_format=None, postfix=None,
                unit_divisor=1000, initial=0, colour=None, **extra_kwargs):
   
    if total and n >= (total + 0.5): 
        total = None

    if unit_scale and unit_scale not in (True, 1):
        if total:
            total *= unit_scale
        n *= unit_scale
        if rate:
            rate *= unit_scale  
        unit_scale = False

    elapsed_str = tqdm.format_interval(elapsed)

    if rate is None and elapsed:
        rate = (n - initial) / elapsed
    inv_rate = 1 / rate if rate else None
    format_sizeof = tqdm.format_sizeof
    rate_noinv_fmt = ((format_sizeof(rate) if unit_scale else
                        '{0:5.2f}'.format(rate)) if rate else '?') + unit + '/s'
    rate_inv_fmt = (
        (format_sizeof(inv_rate) if unit_scale else '{0:5.2f}'.format(inv_rate))
        if inv_rate else '?') + 's/' + unit
    rate_fmt = rate_inv_fmt if inv_rate and inv_rate > 1 else rate_noinv_fmt

    if unit_scale:
        n_fmt = format_sizeof(n, divisor=unit_divisor)
        total_fmt = format_sizeof(total, divisor=unit_divisor) if total is not None else '?'
    else:
        n_fmt = str(n)
        total_fmt = str(total) if total is not None else '?'

    try:
        postfix = ', ' + postfix if postfix else ''
    except TypeError:
        pass

    remaining = (total - n) / rate if rate and total else 0
    remaining_str = tqdm.format_interval(remaining) if rate else '?'
    try:
        eta_dt = (datetime.now() + timedelta(seconds=remaining)
                    if rate and total else datetime.utcfromtimestamp(0))
    except OverflowError:
        eta_dt = datetime.max

    
    if prefix:
        
        bool_prefix_colon_already = (prefix[-2:] == ": ")
        l_bar = prefix if bool_prefix_colon_already else prefix + ": "
    else:
        l_bar = ''

    r_bar = f'| {n_fmt}/{total_fmt} [{elapsed_str}<{remaining_str}, {rate_fmt}{postfix}]'

    
    format_dict = {
       
        'n': n, 'n_fmt': n_fmt, 'total': total, 'total_fmt': total_fmt,
        'elapsed': elapsed_str, 'elapsed_s': elapsed,
        'ncols': ncols, 'desc': prefix or '', 'unit': unit,
        'rate': inv_rate if inv_rate and inv_rate > 1 else rate,
        'rate_fmt': rate_fmt, 'rate_noinv': rate,
        'rate_noinv_fmt': rate_noinv_fmt, 'rate_inv': inv_rate,
        'rate_inv_fmt': rate_inv_fmt,
        'postfix': postfix, 'unit_divisor': unit_divisor,
        'colour': colour,
        
        'remaining': remaining_str, 'remaining_s': remaining,
        'l_bar': l_bar, 'r_bar': r_bar, 'eta': eta_dt,
        **extra_kwargs}

   
    if total:
       
        frac = n / total
        percentage = frac * 100

        l_bar += '{0:3.0f}%|'.format(percentage)

        if ncols == 0:
            return l_bar[:-1] + r_bar[1:]

        format_dict.update(l_bar=l_bar)
        if bar_format:
            format_dict.update(percentage=percentage)

            
            if not prefix:
                bar_format = bar_format.replace("{desc}: ", '')
        else:
            bar_format = "{l_bar}{bar}{r_bar}"

        full_bar = FormatReplace()
        nobar = bar_format.format(bar=full_bar, **format_dict)
        if not full_bar.format_called:
            return nobar  

        
        full_bar = Bar(frac,
                        max(1, ncols - disp_len(nobar)) if ncols else 10,
                        charset=Bar.ASCII if ascii is True else ascii or Bar.UTF,
                        colour=colour)
        if not _is_ascii(full_bar.charset) and _is_ascii(bar_format):
            bar_format = str(bar_format)
        res = bar_format.format(bar=full_bar, **format_dict)
        return disp_trim(res, ncols) if ncols else res

    elif bar_format:
        l_bar += '|'
        format_dict.update(l_bar=l_bar, percentage=0)
        full_bar = FormatReplace()
        nobar = bar_format.format(bar=full_bar, **format_dict)
        if not full_bar.format_called:
            return nobar
        full_bar = Bar(0,
                        max(1, ncols - disp_len(nobar)) if ncols else 10,
                        charset=Bar.BLANK, colour=colour)
        res = bar_format.format(bar=full_bar, **format_dict)
        return disp_trim(res, ncols) if ncols else res
    else:
        
        return (f'{(prefix + ": ") if prefix else ""}'
                f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]')

format_meter方法中的代码比较多,但是逻辑比较清晰,基本上都是对参数的格式化,不再一一赘述,我们来看一下这些参数的含义:

参数名含义
n当前迭代次数
total总迭代次数
elapsed已经迭代的时间
ncols进度条的宽度
prefix进度条前缀
ascii进度条的字符集
unit进度条的单位
unit_scale进度条的单位是否缩放
rate进度条的速度
bar_format进度条的格式
postfix进度条后缀
unit_divisor进度条单位的除数
initial进度条的初始值
colour进度条的颜色
extra_kwargs其他参数

format_meter的作用就是根据这些参数,格式化字符串,然后返回这个字符串。
代码后半部分,对于totalbar_format的判断,是对于进度条的处理。

    if total:
       
        ... # 省略部分代码
        full_bar = Bar(frac,
                        max(1, ncols - disp_len(nobar)) if ncols else 10,
                        charset=Bar.ASCII if ascii is True else ascii or Bar.UTF,
                        colour=colour)
        if not _is_ascii(full_bar.charset) and _is_ascii(bar_format):
            bar_format = str(bar_format)
        res = bar_format.format(bar=full_bar, **format_dict)
        return disp_trim(res, ncols) if ncols else res

    elif bar_format:
        ... # 省略部分代码
        full_bar = Bar(0,
                        max(1, ncols - disp_len(nobar)) if ncols else 10,
                        charset=Bar.BLANK, colour=colour)
        res = bar_format.format(bar=full_bar, **format_dict)
        return disp_trim(res, ncols) if ncols else res
    else:
        
        return (f'{(prefix + ": ") if prefix else ""}'
                f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]')

这里的Bar类是一个生成进度条的类,我们来分析一下这个类的实现。

Bar类实现

class Bar(object):
    ASCII = " 123456789#"
    UTF = u" " + u''.join(map(chr, range(0x258F, 0x2587, -1)))
    BLANK = "  "
    COLOUR_RESET = '\x1b[0m'
    COLOUR_RGB = '\x1b[38;2;%d;%d;%dm'
    COLOURS = {'BLACK': '\x1b[30m', 'RED': '\x1b[31m', 'GREEN': '\x1b[32m',
               'YELLOW': '\x1b[33m', 'BLUE': '\x1b[34m', 'MAGENTA': '\x1b[35m',
               'CYAN': '\x1b[36m', 'WHITE': '\x1b[37m'}

    def __init__(self, frac, default_len=10, charset=UTF, colour=None):
        if not 0 <= frac <= 1:
            warn("clamping frac to range [0, 1]", TqdmWarning, stacklevel=2)
            frac = max(0, min(1, frac))
        assert default_len > 0
        self.frac = frac
        self.default_len = default_len
        self.charset = charset
        self.colour = colour

    @property
    def colour(self):
        return self._colour

    @colour.setter
    def colour(self, value):
        if not value:
            self._colour = None
            return
        try:
            if value.upper() in self.COLOURS:
                self._colour = self.COLOURS[value.upper()]
            elif value[0] == '#' and len(value) == 7:
                self._colour = self.COLOUR_RGB % tuple(
                    int(i, 16) for i in (value[1:3], value[3:5], value[5:7]))
            else:
                raise KeyError
        except (KeyError, AttributeError):
            warn("Unknown colour (%s); valid choices: [hex (#00ff00), %s]" % (
                 value, ", ".join(self.COLOURS)),
                 TqdmWarning, stacklevel=2)
            self._colour = None

    def __format__(self, format_spec):
        if format_spec:
            _type = format_spec[-1].lower()
            try:
                charset = {'a': self.ASCII, 'u': self.UTF, 'b': self.BLANK}[_type]
            except KeyError:
                charset = self.charset
            else:
                format_spec = format_spec[:-1]
            if format_spec:
                N_BARS = int(format_spec)
                if N_BARS < 0:
                    N_BARS += self.default_len
            else:
                N_BARS = self.default_len
        else:
            charset = self.charset
            N_BARS = self.default_len

        nsyms = len(charset) - 1
        bar_length, frac_bar_length = divmod(int(self.frac * N_BARS * nsyms), nsyms)

        res = charset[-1] * bar_length
        if bar_length < N_BARS:  # whitespace padding
            res = res + charset[frac_bar_length] + charset[0] * (N_BARS - bar_length - 1)
        return self.colour + res + self.COLOUR_RESET if self.colour else res

Bar类中,定义了一些常量,然后定义了__init__方法,这个方法的作用是初始化进度条,然后定义了colour属性,这个属性的作用是设置进度条的颜色,然后定义了__format__方法,这个方法的作用是格式化进度条,然后返回这个进度条的字符串。

到这里,我们已经分析完了tqdm类的实现,我们来总结一下tqdm的调用过程。

__iter__ -> update -> refresh -> display -> print_status 
__str__ -> format_meter -> Bar
print_status(self.__str__() if msg is None else msg)

tqdm cpp实现

分析完了python版本的tqdm,我们结合之前的tqdm.cpp的实现,来扩展一下tqdm.cpp的功能,实现一个简易cpp版本的tqdm的功能。

在python的实现中,tqdm类中有很多属性用于记录迭代的状态,这里我们只记录上次打印的时间,上次打印的迭代次数,当前迭代次数,总迭代次数,然后定义一个process结构体,用于记录这些状态。

    struct process
    {
        using time_point = std::chrono::_V2::system_clock::time_point;
        time_point start;
        time_point last_print_t; // 上次打印的时间
        size_t last_print_n = 0; // 上次打印的迭代次数
        size_t iterations = 0;   // 当前迭代次数
        size_t total = 0;        // 总迭代次数
        double rate = 0;         // 速率
        size_t last_len;        // 上次打印的长度
    };

为了更灵活的控制进度条显示,我们需要修改Tqdm类的定义和构造函数,为构造函数添加一个默认参数Params(),这样就可以实现更灵活的控制进度条的显示。

    template <typename _Iterator>
    class Tqdm : public IteratorWrapper<_Iterator>
    {
    private:
        using TQDM_IT = IteratorWrapper<_Iterator>;
        _Iterator m_endIt;
        mutable process m_process;
        Params m_params;
    public:
        explicit Tqdm(_Iterator begin, _Iterator end, const Params &s = Params())
            : TQDM_IT(begin), m_endIt(end), m_params(s), m_process()
        {
            m_process.start = std::chrono::system_clock::now();
            m_process.last_print_t = m_process.start;
            m_process.total = std::distance(begin, end);
        }
        ... // 省略部分代码
    };

其他的类和模板函数同样添加Params(),这里不再赘述。

上一次我们分析了Tqdm的打印功能是在_incr()函数中实现的,我们只需要修改这个函数,就可以实现进度条的打印功能。
为了更灵活的拓展进度条的功能,我们可以将打印功能抽离出来,用一个单例类或者全局对象来实现,这里我们使用单例类来实现,_incr()只需要通过单例类来打印进度条即可。

    virtual void _incr() const override
    {
        TQDM_IT::_incr();
        printProgress::getInstance()(m_process, m_params);
    }

根据上文分析的python代码,我们需要在这个单例类中实现参数的格式化以及进度条的显示。
printProgress类的定义如下:

    class printProgress
    {

    private:
        std::function<void(process &p, const Params &s)> m_funcPrintProgress;

        printProgress()
            : m_funcPrintProgress(defaultPrintProgress)
        {
        }
        printProgress(const printProgress &) = delete;
        printProgress &operator=(const printProgress &) = delete;
        printProgress(printProgress &&) = delete;
        printProgress &operator=(printProgress &&) = delete;
        ~printProgress() = default;

    public:
        static printProgress &getInstance()
        {
            static printProgress instance;
            return instance;
        }

    public:
        void setPrintProgress(std::function<void(process &p, const Params &s)> printProgress)
        {
            m_funcPrintProgress = printProgress;
        }
        void resetPrintProgress()
        {
            m_funcPrintProgress = defaultPrintProgress;
        }

        void operator()(process &p, const Params &s)
        {
            m_funcPrintProgress(p, s);
        }


        static void defaultPrintProgress(process &p, const Params &s);
        static std::string formatMeter(process p, Params s);
        static void display(process &p, const Params &s, int pos = 0);
       
    };

printProgress类中包含一个std::function对象,我们可以通过这个对象来设置打印进度条的函数,这样就可以实现更灵活的控制进度条的显示。
默认的打印进度条的函数是defaultPrintProgress,这个函数的实现如下:

    static void defaultPrintProgress(process &p, const Params &s)
    {

        if (s.disable)
            return;

        p.iterations++;

        if (p.iterations - p.last_print_n >= s.miniters)
        {
            auto now = std::chrono::system_clock::now();
            auto dt = std::chrono::duration_cast<std::chrono::milliseconds>(now - p.last_print_t).count();
            if (dt >= s.mininterval)
            {
                
                if (p.iterations - p.last_print_n >= s.miniters)
                {
                    auto dn = p.iterations - p.last_print_n;
                    p.rate = dn / (dt / 1000.0); 
                    
                    display(p, s);
                }

                p.last_print_n = p.iterations;
                p.last_print_t = now;
            }
        }
    }

defaultPrintProgress的实现是在python版本中的__iter__ -> update -> refresh 三个方法的实现的基础上的简化。
在这个函数中,我们只需要判断是否满足打印条件,如果满足,则调用display方法,这个方法的实现如下:

    static void display(process &p, const Params &s, int pos = 0)
    {
        std::string msg = formatMeter(p, s);
        std::cout<<"\r"<<msg;
        std::cout.flush();
        p.last_len = msg.length();
    }

display方法的代码很少,主要功能在于调用formatMeter格式化参数,并将结果显示出来。

formatMeter方法的实现如下:

    static std::string formatMeter(const process &p, const Params &s)
    {
        size_t total = p.total;
        double rate = p.rate;
        size_t n = p.iterations;
        int unit_scale = s.unit_scale;

        if (unit_scale > 0)
        {
            if (total > 0)
            {
                total *= unit_scale;
            }
            n *= unit_scale;
            if (abs(rate) > 1e-5)
            {
                rate *= unit_scale;
            }
            unit_scale = 0;
        }

        std::string rate_fmt;
        if (abs(rate) > 1e-5)
        {
            rate_fmt = format_sizeof(rate, s.unit) + "/s";
        }
        else
        {
            rate_fmt = format_sizeof(0, s.unit) + "/s";
        }

        double frac = (double)n / total;
        double percentage = frac * 100;

        std::string l_bar = fmt::vformat("{} {:.1f}% [", fmt::make_format_args(s.desc, percentage));

        int fill_width = std::max(0, int(s.ncols * frac) - 1);
        std::string fill_str = std::string(fill_width, s.ascii[s.ascii.length() - 1]);
        fill_str += s.ascii[n % 11];

        if (fill_str.length() < s.ncols)
        {
            fill_str += std::string(s.ncols - fill_str.length(), ' ');
        }
        else
        {
            fill_str[fill_str.length() - 1] = s.ascii[s.ascii.length() - 1];
        }

        size_t remaining = (total - n) / rate;
        size_t used_time = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - p.start).count();
        std::string remaining_str = formatTime(remaining);
        std::string used_time_str = formatTime(used_time);

        std::string bar = fmt::vformat("{} ]", fmt::make_format_args(fill_str));
        std::string r_bar = fmt::vformat("{}/{} [{}/{}] {}", fmt::make_format_args(n, total, used_time_str, remaining_str, rate_fmt));

        std::string msg = fmt::vformat("{}{}{}", fmt::make_format_args(l_bar, bar, r_bar));

        if (n == total)
        {
            msg += !s.leave ? "\n" : "\r";
        }

        if (p.last_len > msg.length())
        {
            msg += std::string(p.last_len - msg.length(), ' ');
        }

        return msg;
    }



formatMeter方法的实现是在python版本中的format_meter方法的实现的基础上的简化。
这里我们使用了fmt库来格式化字符串,这个库的是C++20中的std::format的实现,这个库的使用方法可以参考fmtlib
formatMeter方法只处理了部分参数,实现的功能比较简单,我们在后续的版本中会对这个方法进行优化。
主要逻辑如下:
首先判断是否需要单位缩放,如果需要,则对参数进行缩放。
然后计算进度条的百分比,时间等参数。
格式为l_bar + bar + r_bar
l_bar是进度条的前缀+进度条的百分比+[
bar是进度条
r_bar是当前迭代次数+总迭代次数+已用时间+剩余时间+速率
最后将这些参数拼接起来,就是最终的结果。

以上就是一个简易版的tqdm的实现,我们可以通过这个类来实现进度条的显示,这个类的使用方法如下:

    
int main()
{
    std::vector<int> v(3000,1);
    for(auto i:tqdm::tqdm(v))
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        i++;
    }
    
}

运行结果如下:

    12.7% [#6                   ]38000/300000 [0:3/0:26] 10.0Kit/s

之后我们会对这个类进行优化,实现更多的功能。

  • 13
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值