PyTorch | YOWO Principles and Code Explained (Part 2)
This post continues from PyTorch | YOWO Principles and Code Explained (Part 1); it is worth reading that part first.
1. Starting the training
if opt.evaluate:
    logging('evaluating ...')
    test(0)
else:
    for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
        # Train the model for 1 epoch
        train(epoch)

        # Validate the model
        fscore = test(epoch)

        is_best = fscore > best_fscore
        if is_best:
            print("New best fscore is achieved: ", fscore)
            print("Previous fscore was: ", best_fscore)
            best_fscore = fscore

        # Save the model to backup directory
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'fscore': fscore
        }
        save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
        logging('Weights are saved to backup directory: %s' % (backupdir))
To train, set opt.evaluate = False. According to the paper (see the YOWO translation post), training on UCF24 for 5 epochs is enough.
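The checkpointing itself is delegated to save_checkpoint, whose implementation is not shown in this post. A minimal sketch of what such a helper typically does (the file-name pattern below is assumed for illustration, not taken from the repo):

```python
import os
import shutil
import torch

def save_checkpoint(state, is_best, backupdir, dataset_use, clip_duration):
    # Always write the latest checkpoint (file-name pattern assumed for illustration).
    checkpoint_path = os.path.join(backupdir, '%s_checkpoint_%d.pth' % (dataset_use, clip_duration))
    torch.save(state, checkpoint_path)
    # Keep a separate copy of the best-scoring weights so they are never overwritten.
    if is_best:
        best_path = os.path.join(backupdir, '%s_best_%d.pth' % (dataset_use, clip_duration))
        shutil.copyfile(checkpoint_path, best_path)
```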
2. train
Let's look at the entire train function.
def train(epoch):
    global processed_batches
    t0 = time.time()
    cur_model = model.module
    region_loss.l_x.reset()
    region_loss.l_y.reset()
    region_loss.l_w.reset()
    region_loss.l_h.reset()
    region_loss.l_conf.reset()
    region_loss.l_cls.reset()
    region_loss.l_total.reset()
    train_loader = torch.utils.data.DataLoader(
        dataset.listDataset(basepath, trainlist, dataset_use=dataset_use, shape=(init_width, init_height),
                            shuffle=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                            ]),
                            train=True,
                            seen=cur_model.seen,
                            batch_size=batch_size,
                            clip_duration=clip_duration,
                            num_workers=num_workers),
        batch_size=batch_size, shuffle=False, **kwargs)

    lr = adjust_learning_rate(optimizer, processed_batches)
    logging('training at epoch %d, lr %f' % (epoch, lr))

    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        adjust_learning_rate(optimizer, processed_batches)
        processed_batches = processed_batches + 1

        if use_cuda:
            data = data.cuda()

        optimizer.zero_grad()
        output = model(data)
        region_loss.seen = region_loss.seen + data.data.size(0)
        loss = region_loss(output, target)
        loss.backward()
        optimizer.step()

        # save result every 1000 batches
        if processed_batches % 500 == 0:  # From time to time, reset averagemeters to see improvements
            region_loss.l_x.reset()
            region_loss.l_y.reset()
            region_loss.l_w.reset()
            region_loss.l_h.reset()
            region_loss.l_conf.reset()
            region_loss.l_cls.reset()
            region_loss.l_total.reset()

    t1 = time.time()
    logging('trained with %f samples/s' % (len(train_loader.dataset) / (t1 - t0)))
    print('')
processed_batches is a global variable that stores the number of batches processed so far, which makes it easy to resume training from a checkpoint. t0 = time.time() records the current time, and the average meters inside region_loss are then reset.
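The l_x, l_y, l_w, l_h, l_conf, l_cls and l_total members of region_loss are running average meters (the in-code comment calls them "averagemeters"); resetting them clears the statistics accumulated so far. A minimal sketch of such a meter, assuming the usual design, looks like this:

```python
class AverageMeter(object):
    """Tracks the running average of a scalar (e.g. one loss term)."""
    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated statistics; called at the start of each epoch
        # and every 500 processed batches inside train().
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # Add a new observation (optionally weighted by batch size n).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```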
2.1 Loading the training data
The training dataset is wrapped in the listDataset class, defined in dataset.py. The full code is as follows:
class listDataset(Dataset):
    # clip duration = 8, i.e, for each time 8 frames are considered together
    def __init__(self, base, root, dataset_use='ucf101-24', shape=None, shuffle=True,
                 transform=None, target_transform=None,
                 train=False, seen=0, batch_size=64,
                 clip_duration=16, num_workers=4):
        with open(root, 'r') as file:
            self.lines = file.readlines()

        if shuffle:
            random.shuffle(self.lines)

        self.base_path = base
        self.dataset_use = dataset_use
        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.clip_duration = clip_duration
        self.num_workers = num_workers

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        self.shape = (224, 224)

        if self.train:  # For Training
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            clip, label = load_data_detection(self.base_path, imgpath, self.train, self.clip_duration, self.shape, self.dataset_use, jitter, hue, saturation, exposure)
        else:  # For Testing
            frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration, self.shape, self.dataset_use)
            clip = [img.resize(self.shape) for img in clip]

        if self.transform is not None:
            clip = [self.transform(img) for img in clip]

        # (self.duration, -1) + self.shape = (8, -1, 224, 224)
        clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers

        if self.train:
            return (clip, label)
        else:
            return (frame_idx, clip, label)
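One detail worth highlighting in __getitem__ is the final tensor manipulation: the clip_duration frames, each turned into a (3, 224, 224) tensor by ToTensor, are concatenated and rearranged into a single (3, clip_duration, 224, 224) tensor, i.e. channels first and time second, which is the layout the 3D backbone expects. A quick sanity check with dummy data (assuming the default clip_duration of 16):

```python
import torch

clip_duration, shape = 16, (224, 224)
# Stand-in for the list of transformed frames, each of shape (3, 224, 224).
clip = [torch.rand(3, *shape) for _ in range(clip_duration)]

clip = torch.cat(clip, 0).view((clip_duration, -1) + shape).permute(1, 0, 2, 3)
print(clip.shape)  # torch.Size([3, 16, 224, 224]) -> (C, D, H, W)
```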
self.lines stores the text content read from trainlist.txt, and random.shuffle(self.lines) shuffles it. The remaining lines of __init__ simply initialize the member attributes.
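The exact contents of trainlist.txt depend on how the dataset was prepared and are not reproduced here; judging from how imgpath is parsed in load_data_detection below (split on '/', with the first five characters of the last component giving the frame index), each line should look roughly like class/video/frame. A hypothetical example:

```python
# Hypothetical trainlist.txt entry, following the class/video/frame layout
# that load_data_detection assumes; the real entries depend on your dataset.
imgpath = 'Basketball/v_Basketball_g01_c01/00056.txt'

im_split = imgpath.split('/')      # ['Basketball', 'v_Basketball_g01_c01', '00056.txt']
im_ind = int(im_split[-1][0:5])    # 56 -- the index of the key frame within the video
print(im_split[0], im_split[1], im_ind)
```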
2.2 Adjusting the learning rate
lr = adjust_learning_rate(optimizer, processed_batches)
The complete code:
def adjust_learning_rate(optimizer, batch):
    lr = learning_rate
    for i in range(len(steps)):
        scale = scales[i] if i < len(scales) else 1
        if batch >= steps[i]:
            lr = lr * scale
            if batch == steps[i]:
                break
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr / batch_size
    return lr
The learning rate is adjusted continuously according to steps, as shown below:
......
    lr = adjust_learning_rate(optimizer, processed_batches)
    logging('training at epoch %d, lr %f' % (epoch, lr))

    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        adjust_learning_rate(optimizer, processed_batches)
        processed_batches = processed_batches + 1
......
scales and steps are the learning-rate decay schedule configured in ucf24.cfg.
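To see how the schedule behaves, here is a self-contained run of adjust_learning_rate with hypothetical values for learning_rate, steps, scales and batch_size (the real values live in ucf24.cfg and may differ):

```python
import torch

# Hypothetical schedule parameters for illustration; the real ones come from ucf24.cfg.
learning_rate = 0.0001
steps = [20000, 40000]
scales = [0.5, 0.5]
batch_size = 4

def adjust_learning_rate(optimizer, batch):
    lr = learning_rate
    for i in range(len(steps)):
        scale = scales[i] if i < len(scales) else 1
        if batch >= steps[i]:
            lr = lr * scale
            if batch == steps[i]:
                break
        else:
            break
    for param_group in optimizer.param_groups:
        # Note the division by batch_size, exactly as in the original function.
        param_group['lr'] = lr / batch_size
    return lr

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=learning_rate)
for batch in [0, 20000, 40000, 60000]:
    print(batch, adjust_learning_rate(optimizer, batch))
# 0 0.0001 / 20000 5e-05 / 40000 2.5e-05 / 60000 2.5e-05
```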
2.3 Fetching the training data
The data and target in for batch_idx, (data, target) in enumerate(train_loader): are produced by listDataset's __getitem__(self, index) method.
Here, the values
jitter = 0.2
hue = 0.1
saturation = 1.5
exposure = 1.5
are the data-augmentation parameters used by YOLOv2 (see the post YOLOv2 参数详解). Their meanings are:
- jitter: random jitter used to generate additional training data
- hue: range of hue variation
- saturation & exposure: magnitude of the saturation and exposure variation
The key call in this part is load_data_detection(self.base_path, imgpath, self.train, self.clip_duration, self.shape, self.dataset_use, jitter, hue, saturation, exposure).
The complete code is as follows:
def load_data_detection(base_path, imgpath, train, train_dur, shape, dataset_use='ucf101-24', jitter=0.2, hue=0.1, saturation=1.5, exposure=1.5):
    # clip loading and data augmentation
    # if dataset_use == 'ucf101-24':
    #     base_path = "/usr/home/sut/datasets/ucf24"
    # else:
    #     base_path = "/usr/home/sut/Tim-Documents/jhmdb/data/jhmdb"

    im_split = imgpath.split('/')
    num_parts = len(im_split)
    im_ind = int(im_split[num_parts-1][0:5])
    labpath = os.path.join(base_path, 'labels', im_split[0], im_split[1], '{:05d}.txt'.format(im_ind))
    img_folder = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1])

    if dataset_use == 'ucf101-24':
        max_num = len(os.listdir(img_folder))
    else:
        max_num = len(os.listdir(img_folder)) - 1

    clip = []

    ### We change downsampling rate throughout training as a ###
    ### temporal augmentation, which brings around 1-2 frame ###
    ### mAP. During test time it is set to 1.                ###
    d = 1
    if train:
        d = random.randint(1, 2)

    for i in reversed(range(train_dur)):
        # make it as a loop
        i_temp = im_ind - i * d
        while i_temp < 1:
            i_temp = max_num + i_temp
        while i_temp > max_num:
            i_temp = i_temp - max_num

        if dataset_use == 'ucf101-24':
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1], '{:05d}.jpg'.format(i_temp))
        else:
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1], '{:05d}.png'.format(i_temp))

        clip.append(Image.open(path_tmp).convert('RGB'))

    if train:  # Apply augmentation
        clip, flip, dx, dy, sx, sy = data_augmentation(clip, shape, jitter, hue, saturation, exposure)
        label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy)
        label = torch.from_numpy(label)
    else:  # No augmentation
        label = torch.zeros(50*5)
        try:
            tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32'))
        except Exception:
            tmp = torch.zeros(1, 5)

        tmp = tmp.view(-1)
        tsz = tmp.numel()

        if tsz > 50*5:
            label = tmp[0:50*5]
        elif tsz > 0:
            label[0:tsz] = tmp

    if train:
        return clip, label
    else:
        return im_split[0] + '_' + im_split[1] + '_' + im_split[2], clip, label
The path is split with im_split = imgpath.split('/') in order to locate the annotation file labpath and the corresponding image folder img_folder.
As the comment in the code says: "We change downsampling rate throughout training as a temporal augmentation, which brings around 1-2 frame mAP. During test time it is set to 1." This downsampling rate is the sampling stride between frames: with d = 2, every second frame is read, and so on.
d = 1
if train:
    d = random.randint(1, 2)
The train_dur argument corresponds to self.clip_duration, the clip length, which defaults to 16. im_ind (56 in the example shown above) is the index of the annotated key frame within the whole video sequence (the video having been cut into individual frame images).
i_temp = im_ind - i * d
while i_temp < 1:
    i_temp = max_num + i_temp
while i_temp > max_num:
    i_temp = i_temp - max_num
The code above ensures that i_temp is a valid frame index; if it falls outside the range, the index wraps around so the clip is taken from the sequence cyclically.
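For a concrete illustration of this wrap-around (a toy example with made-up numbers, not from the dataset): suppose the key frame is im_ind = 5, the video has max_num = 40 frames, the clip length is train_dur = 8 and the sampling stride is d = 2. The indices are then computed as:

```python
im_ind, max_num, train_dur, d = 5, 40, 8, 2  # made-up numbers for illustration

indices = []
for i in reversed(range(train_dur)):   # i = 7, 6, ..., 0
    i_temp = im_ind - i * d            # raw values: -9, -7, ..., 5
    while i_temp < 1:                  # wrap negative indices to the end of the video
        i_temp = max_num + i_temp
    while i_temp > max_num:            # (and overly large ones back to the start)
        i_temp = i_temp - max_num
    indices.append(i_temp)
print(indices)  # [31, 33, 35, 37, 39, 1, 3, 5] -- the clip wraps around, ending at frame 5
```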
path_tmp is then the path of the image for the corresponding frame, which is loaded as follows:
clip.append(Image.open(path_tmp).convert('RGB'))
Each frame is converted to RGB mode and appended to the clip list.
During training, data augmentation is also applied to enlarge the dataset:
if train:  # Apply augmentation
    clip, flip, dx, dy, sx, sy = data_augmentation(clip, shape, jitter, hue, saturation, exposure)
    label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy)
    label = torch.from_numpy(label)
The complete code of data_augmentation is:
def data_augmentation(clip, shape, jitter, hue, saturation, exposure):
    # Initialize Random Variables
    oh = clip[0].height
    ow = clip[0].width

    dw = int(ow*jitter)
    dh = int(oh*jitter)

    pleft = random.randint(-dw, dw)
    pright = random.randint(-dw, dw)
    ptop = random.randint(-dh, dh)
    pbot = random.randint(-dh, dh)

    swidth = ow - pleft - pright
    sheight = oh - ptop - pbot

    sx = float(swidth) / ow
    sy = float(sheight) / oh

    dx = (float(pleft)/ow)/sx
    dy = (float(ptop)/oh)/sy

    flip = random.randint(1, 10000) % 2

    dhue = random.uniform(-hue, hue)
    dsat = rand_scale(saturation)
    dexp = rand_scale(exposure)

    # Augment
    cropped = [img.crop((pleft, ptop, pleft + swidth - 1, ptop + sheight - 1)) for img in clip]

    sized = [img.resize(shape) for img in cropped]

    if flip:
        sized = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in sized]

    clip = [random_distort_image(img, dhue, dsat, dexp) for img in sized]

    return clip, flip, dx, dy, sx, sy
As for the augmentation itself, several of the parameters come from YOLOv2, as explained above: the image is jittered, its hue, saturation and exposure are perturbed, and it is rescaled to 224 × 224.
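Among these, dsat = rand_scale(saturation) and dexp = rand_scale(exposure) draw random scaling factors in the darknet style. The helper itself is defined elsewhere in the repo's image utilities; a sketch of the usual implementation (assumed here, check the repo's image.py for the exact version):

```python
import random

def rand_scale(s):
    # Draw a factor in [1, s]; with probability 0.5 return its reciprocal instead,
    # so saturation/exposure can be either increased or decreased.
    scale = random.uniform(1, s)
    if random.randint(1, 10000) % 2:
        return scale
    return 1.0 / scale
```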
pleft and ptop are the offsets from the left (shifting right) and from the top (shifting down), and swidth and sheight are the width and height after jittering. The clip is cropped with these parameters: cropped = [img.crop((pleft, ptop, pleft + swidth - 1, ptop + sheight - 1)) for img in clip].
It is then rescaled to the target size: sized = [img.resize(shape) for img in cropped].
If the flag flip is true, a horizontal flip is also applied: sized = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in sized].
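To make the relationship between the jitter offsets and the returned dx, dy, sx, sy concrete, here is a toy computation with made-up numbers (a 320 × 240 frame and fixed offsets instead of random ones):

```python
ow, oh = 320, 240                            # original frame size (made-up example)
pleft, pright, ptop, pbot = 20, -10, 5, 15   # fixed offsets instead of random jitter

swidth = ow - pleft - pright    # 310: width of the cropped region
sheight = oh - ptop - pbot      # 220: height of the cropped region
sx = float(swidth) / ow         # 0.96875: fraction of the original width kept
sy = float(sheight) / oh        # ~0.9167: fraction of the original height kept
dx = (float(pleft) / ow) / sx   # ~0.0645: left offset relative to the cropped width
dy = (float(ptop) / oh) / sy    # ~0.0227: top offset relative to the cropped height
print(swidth, sheight, sx, sy, dx, dy)
```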
Since the images have been augmented, the corresponding labels also need to be transformed accordingly: label = fill_truth_detection(labpath, clip[0].width, clip[0].height, flip, dx, dy, 1./sx, 1./sy). Because the augmentation mainly shifts and rescales the image, updating the labels requires knowing whether the image was flipped horizontally, as well as the offsets and scaling factors, i.e. flip, dx, dy, 1./sx, 1./sy. The complete code:
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
    max_boxes = 50
    label = np.zeros((max_boxes, 5))
    if os.path.getsize(labpath):
        bs = np.loadtxt(labpath)
        if bs is None:
            return label
        bs = np.reshape(bs, (-1, 5))

        for i in range(bs.shape[0]):
            cx = (bs[i][1] + bs[i][3]) / (2 * 320)
            cy = (bs[i][2] + bs[i][4]) / (2 * 240)
            imgw = (bs[i][3] - bs[i][1]) / 320
            imgh = (bs[i][4] - bs[i][2]) / 240
            bs[i][0] = bs[