Pytorch|YOWO原理及代码详解(二)

Pytorch|YOWO原理及代码详解(二)

本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看。

1.正式训练

    if opt.evaluate:
        logging('evaluating ...')
        test(0)
    else:
        for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
            # Train the model for 1 epoch
            train(epoch)

            # Validate the model
            fscore = test(epoch)

            is_best = fscore > best_fscore
            if is_best:
                print("New best fscore is achieved: ", fscore)
                print("Previous fscore was: ", best_fscore)
                best_fscore = fscore

            # Save the model to backup directory
            state = {
   
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'fscore': fscore
            }
            save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
            logging('Weights are saved to backup directory: %s' % (backupdir))

为了训练,设置opt.evaluate = False,根据论文(YOWO翻译)(可知ucf24训练5个epoch就可以了。

2. train

查看整个train函数。

    def train(epoch):
        global processed_batches
        t0 = time.time()
        cur_model = model.module
        region_loss.l_x.reset()
        region_loss.l_y.reset()
        region_loss.l_w.reset()
        region_loss.l_h.reset()
        region_loss.l_conf.reset()
        region_loss.l_cls.reset()
        region_loss.l_total.reset()
        train_loader = torch.utils.data.DataLoader(
            dataset.listDataset(basepath, trainlist, dataset_use=dataset_use, shape=(init_width, init_height),
                                shuffle=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                ]),
                                train=True,
                                seen=cur_model.seen,
                                batch_size=batch_size,
                                clip_duration=clip_duration,
                                num_workers=num_workers),
            batch_size=batch_size, shuffle=False, **kwargs)

        lr = adjust_learning_rate(optimizer, processed_batches)
        logging('training at epoch %d, lr %f' % (epoch, lr))

        model.train()

        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(optimizer, processed_batches)
            processed_batches = processed_batches + 1

            if use_cuda:
                data = data.cuda()

            optimizer.zero_grad()
            output = model(data)
            region_loss.seen = region_loss.seen + data.data.size(0)
            loss = region_loss(output, target)
            loss.backward()
            optimizer.step()

            # save result every 1000 batches
            if processed_batches % 500 == 0:  # From time to time, reset averagemeters to see improvements
                region_loss.l_x.reset()
                region_loss.l_y.reset()
                region_loss.l_w.reset()
                region_loss.l_h.reset()
                region_loss.l_conf.reset()
                region_loss.l_cls.reset()
                region_loss.l_total.reset()

        t1 = time.time()
        logging('trained with %f samples/s' % (len(train_loader.dataset) / (t1 - t0)))
        print('')

processed_batches是全局变量,存储已经处理的batch数,方便断点继续训练。t0 = time.time()记录当前的时间。region_loss初始化。

2.1 加载训练数据集

训练数据集是放在listDataset类中。listDataset是在dataset.py中,完整代码如下:

class listDataset(Dataset):
    # clip duration = 8, i.e, for each time 8 frames are considered together
    def __init__(self, base, root, dataset_use='ucf101-24', shape=None, shuffle=True,
                 transform=None, target_transform=None, 
                 train=False, seen=0, batch_size=64,
                 clip_duration=16, num_workers=4):
        with open(root, 'r') as file:
            self.lines = file.readlines()
        if shuffle:
            random.shuffle(self.lines)
        self.base_path = base
        self.dataset_use = dataset_use
        self.nSamples  = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.clip_duration = clip_duration
        self.num_workers = num_workers
    def __len__(self):
        return self.nSamples
    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imgpath = self.lines[index].rstrip()
        self.shape = (224, 224)
        if self.train: # For Training
            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5
            clip, label = load_data_detection(self.base_path, imgpath,  self.train, self.clip_duration, self.shape, self.dataset_use, jitter, hue, saturation, exposure)
        else: # For Testing
            frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration, self.shape, self.dataset_use)
            clip = [img.resize(self.shape) for img in clip]
        if self.transform is not None:
            clip = [self.transform(img) for img in clip]
        # (self.duration, -1) + self.shape = (8, -1, 224, 224)
        clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3)
        if self.target_transform is not None:
            label = self.target_transform(label)
        self.seen = self.seen + self.num_workers
        if self.train:
            return (clip, label)
        else:
            return (frame_idx, clip, label)

self.lines存储读去trainlist.txt的文本内容:
在这里插入图片描述
random.shuffle(self.lines)是对其进行打乱。剩下的就是一顿初始化:
在这里插入图片描述

2.2 学习率调整

lr = adjust_learning_rate(optimizer, processed_batches)

完整代码:

    def adjust_learning_rate(optimizer, batch):
        lr = learning_rate
        for i in range(len(steps)):
            scale = scales[i] if i < len(scales) else 1
            if batch >= steps[i]:
                lr = lr * scale
                if batch == steps[i]:
                    break
            else:
                break
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr / batch_size
        return lr

学习率是根据steps进行不断调整的,如下:

        ......
        lr = adjust_learning_rate(optimizer, processed_batches)
        logging('training at epoch %d, lr %f' % (epoch, lr))
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(optimizer, processed_batches)
            processed_batches = processed_batches + 1
        ......

scalesstepucf24.cfg中进行设置的衰减策略,如下:

在这里插入图片描述

2.3 获取训练数据

这段for batch_idx, (data, target) in enumerate(train_loader):中的data, target是通过listDataset在的def __getitem__(self, index):进行获取的。
在这里插入图片描述
其中:

            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5

是yolov2中用于数据增强的数据,参考YOLOv2 参数详解,其中各项参数意义如下:

  • jitter:利用数据抖动产生更多数据
  • hue:色调变化范围
  • saturation & exposure: 饱和度与曝光变化大小

这部分最关键的是:load_data_detection(self.base_path, imgpath, self.train, self.clip_duration, self.shape, self.dataset_use, jitter, hue, saturation, exposure)
完整代码如下:

def load_data_detection(base_path, imgpath, train, train_dur, shape, dataset_use='ucf101-24', jitter=0.2, hue=0.1, saturation=1.5, exposure=1.5):
    # clip loading and  data augmentation
    # if dataset_use == 'ucf101-24':
    #     base_path = "/usr/home/sut/datasets/ucf24"
    # else:
    #     base_path = "/usr/home/sut/Tim-Documents/jhmdb/data/jhmdb"
    im_split = imgpath.split('/')
    num_parts = len(im_split)
    im_ind = int(im_split[num_parts-1][0:5])
    labpath = os.path.join(base_path, 'labels', im_split[0], im_split[1] ,'{:05d}.txt'.format(im_ind))
    img_folder = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1])
    if dataset_use == 'ucf101-24':
        max_num = len(os.listdir(img_folder))
    else:
        max_num = len(os.listdir(img_folder)) - 1
    clip = []
    ### We change downsampling rate throughout training as a ###
    ### temporal augmentation, which brings around 1-2 frame ###
    ### mAP. During test time it is set to 1.                ###
    d = 1 
    if train:
        d = random.randint(1, 2)
    for i in reversed(range(train_dur)):
        # make it as a loop
        i_temp = im_ind - i * d
        while i_temp < 1:
            i_temp = max_num + i_temp
        while i_temp > max_num:
            i_temp = i_temp - max_num
        if dataset_use == 'ucf101-24':
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.jpg'.format(i_temp))
        else:
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.png'.format(i_temp))
        clip.append(Image.open(path_tmp).convert('RGB'))
    if train: # Apply augmentation
        clip,flip,dx,dy,sx,sy = data_augmentation(clip, shape, jitter, hue, saturation, exposure)
        label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy)
        label = torch.from_numpy(label)
    else: # No augmentation
        label = torch.zeros(50*5)
        try:
            tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32'))
        except Exception:
            tmp = torch.zeros(1,5)
        tmp = tmp.view(-1)
        tsz = tmp.numel()
        if tsz > 50*5:
            label = tmp[0:50*5]
        elif tsz > 0:
            label[0:tsz] = tmp
    if train:
        return clip, label
    else:
        return im_split[0] + '_' +im_split[1] + '_' + im_split[2], clip, label

通过把路径进行分割:im_split = imgpath.split('/'),来找到标注labpath和对应的文件夹img_folder
在这里插入图片描述
“在整个训练过程中,改变下采样率作为一个时间增量,得到1-2帧左右的图像。在测试期间,它被设置为1。”这个下采样率是指帧与帧之间的采样距离,如果d=2,则没隔两帧读取数据,依此类推。

    d = 1 
    if train:
        d = random.randint(1, 2)

这个train_dur对应的参数是self.clip_duration,剪辑持续时间,默认设置的是16。im_ind(在上述图片中可以看到,是56)是标注是整个视频(图像,视频被切割成一张张的图像)序列中的ID。

        i_temp = im_ind - i * d
        while i_temp < 1:
            i_temp = max_num + i_temp
        while i_temp > max_num:
            i_temp = i_temp - max_num

上述代码是为了现在i_temp有效,如果溢出,则使用循环序列。
path_tmp则是获取对应帧的图像,如下:
在这里插入图片描述
clip.append(Image.open(path_tmp).convert('RGB'))将其转换成RGB模式,添加到序列clip中。
在训练的过程中还会使用数据增强来扩充数据集:

    if train: # Apply augmentation
        clip,flip,dx,dy,sx,sy = data_augmentation(clip, shape, jitter, hue, saturation, exposure)
        label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy)
        label = torch.from_numpy(label)

data_augmentation完整代码如下:

def data_augmentation(clip, shape, jitter, hue, saturation, exposure):
    # Initialize Random Variables
    oh = clip[0].height  
    ow = clip[0].width
    dw =int(ow*jitter)
    dh =int(oh*jitter)
    pleft  = random.randint(-dw, dw)
    pright = random.randint(-dw, dw)
    ptop   = random.randint(-dh, dh)
    pbot   = random.randint(-dh, dh)
    swidth =  ow - pleft - pright
    sheight = oh - ptop - pbot
    sx = float(swidth)  / ow
    sy = float(sheight) / oh 
    dx = (float(pleft)/ow)/sx
    dy = (float(ptop) /oh)/sy
    flip = random.randint(1,10000)%2
    dhue = random.uniform(-hue, hue)
    dsat = rand_scale(saturation)
    dexp = rand_scale(exposure)
    # Augment
    cropped = [img.crop((pleft, ptop, pleft + swidth - 1, ptop + sheight - 1)) for img in clip]
    sized = [img.resize(shape) for img in cropped]
    if flip: 
        sized = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in sized]
    clip = [random_distort_image(img, dhue, dsat, dexp) for img in sized]
    return clip, flip, dx, dy, sx, sy 

关于数据增强,这里有几个参数是yolov2中的,上述已经讲过,即是对图像进行抖动,改变色调、饱和度以及曝光度,并进行尺度归一化,缩放为 224 × 224 224 \times 224 224×224
pleftptop是靠左(向右),靠上(向下)的偏移量,swidthsheight是数据抖动后的宽和高,通过这些参数进行裁剪,cropped = [img.crop((pleft, ptop, pleft + swidth - 1, ptop + sheight - 1)) for img in clip]
并进一步尺度归一化:sized = [img.resize(shape) for img in cropped]
如果标志位flip为true,则还会进行水平翻转:sized = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in sized]
由于图像以及增强了,那么对应的标注label也需要进行对应的改变:label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy)。由于图像增强,主要是图像的偏移以及方式,所以对应标签的变化需要知道图像是否水平翻转,以及图像的偏移量和放缩量,即flip, dx, dy, 1./sx, 1./sy。查看完整代码:

def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
    max_boxes = 50
    label = np.zeros((max_boxes,5))
    if os.path.getsize(labpath):
        bs = np.loadtxt(labpath)
        if bs is None:
            return label
        bs = np.reshape(bs, (-1, 5))
        for i in range(bs.shape[0]):
            cx = (bs[i][1] + bs[i][3]) / (2 * 320)
            cy = (bs[i][2] + bs[i][4]) / (2 * 240)
            imgw = (bs[i][3] - bs[i][1]) / 320
            imgh = (bs[i][4] - bs[i][2]) / 240
            bs[i][0] = bs[
评论 27
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值