【扒代码】tqdm有趣的应用


def train():
    print("Training on FSC147 train set data")
    im_ids = data_split['train']
    random.shuffle(im_ids)
    train_mae = 0
    train_rmse = 0
    train_loss = 0
    pbar = tqdm(im_ids)
    cnt = 0
    for im_id in pbar:
        cnt += 1
        anno = annotations[im_id]
        bboxes = anno['box_examples_coordinates']
        dots = np.array(anno['points'])

        rects = list()
        for bbox in bboxes:
            x1 = bbox[0][0]
            y1 = bbox[0][1]
            x2 = bbox[2][0]
            y2 = bbox[2][1]
            rects.append([y1, x1, y2, x2])

        image = Image.open('{}/{}'.format(im_dir, im_id))
        image.load()
        density_path = gt_dir + '/' + im_id.split(".jpg")[0] + ".npy"
        density = np.load(density_path).astype('float32')    
        sample = {'image':image,'lines_boxes':rects,'gt_density':density}
        sample = TransformTrain(sample)
        image, boxes,gt_density = sample['image'], sample['boxes'],sample['gt_density']

        with torch.no_grad():
            features = extract_features(resnet50_conv, image.unsqueeze(0), boxes.unsqueeze(0), MAPS, Scales)
        features.requires_grad = True
        optimizer.zero_grad()
        output = regressor(features)

        '''
        确保无论模型输出的尺寸如何变化,真实标签的物体数量都能保持一致
        使得损失函数能够正确地计算预测和真实标签之间的差异
        常见的后处理步骤,特别是在处理不规则尺寸的图像数据时
        '''

        # 双线性插值 调整图像尺寸

        #if image size isn't divisible by 8, gt size is slightly different from output size
        # 如果图像尺寸不能被8整除,则gt尺寸与输出尺寸略有不同
        if output.shape[2] != gt_density.shape[2] or output.shape[3] != gt_density.shape[3]:
            # 如果gt_density的尺寸与输出尺寸不同,则将gt_density插值到输出尺寸
            # 判断条件是两个张量的高度和宽度

            # 计算gt_density的原始像素总数
            # sum()返回张量中所有元素的和 
            # detach()返回一个新的张量,从当前计算图中分离出来 
            # item()返回张量中的一个元素(转换成标量)
            orig_count = gt_density.sum().detach().item()

            # 将 gt_density 插值到与 output 相同的尺寸。
            # size 参数指定了目标尺寸
            # mode='bilinear' 指定了使用双线性插值方法。
            gt_density = F.interpolate(gt_density, size=(output.shape[2],output.shape[3]),mode='bilinear')

            # 插值后的 gt_density 再次求和,得到插值后的非零元素总数。
            new_count = gt_density.sum().detach().item()

            if new_count > 0: 
                # 如果插值后的 gt_density 非零元素总数不为零,
                # 通过缩放因子 orig_count / new_count 调整 gt_density 的值,
                # 以保持物体数量不变。
                # 这是为了保证插值过程中物体的总数(即密度)保持一致。
                gt_density = gt_density * (orig_count / new_count)
        # 损失函数
        loss = criterion(output, gt_density)
        # 梯度反向传播
        loss.backward()
        # 参数更新
        optimizer.step()
        train_loss += loss.item()
        pred_cnt = torch.sum(output).item()
        gt_cnt = torch.sum(gt_density).item()
        cnt_err = abs(pred_cnt - gt_cnt)
        train_mae += cnt_err
        train_rmse += cnt_err ** 2

    train_loss = train_loss / len(im_ids)
    train_mae = (train_mae / len(im_ids))
    train_rmse = (train_rmse / len(im_ids))**0.5
    return train_loss,train_mae,train_rmse

这个train函数全程没有print,但是能看到运行这段代码需要多久!原因是什么?

这个输出是由 tqdm 进度条库生成的。具体来说,pbar = tqdm(im_ids) 这一行代码创建了一个进度条对象 pbar,并在 for im_id in pbar: 循环中更新进度条的状态

使用 tqdm 进度条库后,进度条会自动在控制台中打印出当前的进度信息。你只需要确保在循环中正确地更新进度条即可

所以关键就在于:

这两句

tqdm一般配合着for循环使用,比如迭代次数,这里是图片id,感觉只要是可以迭代的对象都可以用 tqdm 包起来

再举个例子:

from tqdm import tqdm
import time

# 示例数据
data = range(100)

# 使用 tqdm 创建进度条
for item in tqdm(data):
    # 模拟一些处理时间
    time.sleep(0.1)

所以使用的时候,就把想监控的

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值