接上:Datawhale-天池入门赛街景字符编码识别-Task1:赛题理解、Datawhale-天池入门赛街景字符编码识别-Task2:数据读取与数据增强、Datawhale-天池入门赛街景字符编码识别-Task3:字符识别模型
近期进展
近期开始尝试使用检测模型,首先就要对label进行适当处理,自己尝试通过部分可视化的手段,帮助自己对坐标框进行处理,代码如下:
class Visualization:
    """Visualization helpers for SVHN-style character-detection labels.

    Attributes:
        image_path: list of image file paths.
        label_path: dict mapping image file name -> annotation dict with
            parallel lists under keys 'label', 'left', 'top', 'width',
            'height' (one entry per character box).
        image_name: file names extracted from image_path.
    """

    def __init__(self, image_path, label_path):
        self.image_path = image_path
        self.label_path = label_path
        # Windows-style path separators are assumed; keep only the file name.
        self.image_name = [p.split('\\')[-1] for p in self.image_path]

    def show_box(self, save_path):
        """Draw one red rectangle per annotated character on each image
        and write the result into save_path (created if missing)."""
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        for i, path in enumerate(self.image_path):
            name = self.image_name[i]
            image = cv2.imread(path)
            anno = self.label_path[name]
            for j in range(len(anno['label'])):
                left = anno['left'][j]
                top = anno['top'][j]
                height = anno['height'][j]
                width = anno['width'][j]
                cv2.rectangle(image, (int(left), int(top)),
                              (int(left + width), int(top + height)),
                              (0, 0, 255), 1)
            # BUG FIX: was `train_name[i]` — an undefined global (NameError);
            # the intended name is this image's own file name.
            cv2.imwrite(os.path.join(save_path, name), image)

    def show_max_box(self, save_path):
        """Draw one red rectangle enclosing ALL character boxes of each
        image and write the result into save_path (created if missing).

        NOTE(review): x2 uses the LAST entry of 'left'/'width', which is
        only the right edge if boxes are sorted left-to-right — TODO
        confirm against the dataset ordering.
        """
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        for i, path in enumerate(self.image_path):
            name = self.image_name[i]
            image = cv2.imread(path)
            anno = self.label_path[name]
            x1 = min(anno['left'])
            y1 = min(anno['top'])
            x2 = anno['left'][-1] + anno['width'][-1]
            y2 = max(anno['top']) + max(anno['height'])
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)),
                          (0, 0, 255), 1)
            # BUG FIX: was `train_name[i]` — an undefined global (NameError).
            cv2.imwrite(os.path.join(save_path, name), image)

    def count_label(self):
        """Print count and ratio of images containing 1..6 characters and
        plot a histogram of characters-per-image."""
        label_path = self.label_path
        image_name = self.image_name
        label_len = [len(label_path[image_name[i]]['label'])
                     for i in range(len(label_path))]
        for i in range(6):
            print('{} {} {:<6d} {:.6f}'.format(
                i + 1, ':', label_len.count(i + 1),
                label_len.count(i + 1) / len(label_path)))
        plt.hist(label_len)

    def show_outlier(self):
        """Print the names of images with 5 or 6 characters (rare cases
        that may be treated as outliers)."""
        for i in range(len(self.label_path)):
            name = self.image_name[i]
            if len(self.label_path[name]['label']) in [5, 6]:
                print(name)

    def show_image_size(self):
        """Print the mean width and height over all images and
        scatter-plot the (width, height) pairs."""
        # BUG FIX: the original read every image TWICE and took shape[0]
        # for both width and height. OpenCV's shape is (height, width,
        # channels), so width must come from shape[1]; read once instead.
        shapes = [cv2.imread(p).shape for p in self.image_path]
        width = [s[1] for s in shapes]
        height = [s[0] for s in shapes]
        print(sum(width) / len(width), sum(height) / len(height))
        plt.scatter(width, height)
通过以上代码,可以可视化每个数字的坐标框。
也可以将所有数字的坐标框合并,做序列预测。
此外,通过一定的统计可以发现,数字位数集中在1-4,5和6位的数字很少,甚至可以当作离群点处理。
另外,还可以发现,所有图像的宽高比基本一致,但图像尺寸的分布比较分散。
另外,还对模型的训练代码做了一些补充。
def train(train_loader, model, criterion, optimizer):
    """Run one training epoch and return the mean (unscaled) batch loss.

    Uses gradient accumulation: the gradients of `accumulation_steps`
    consecutive batches are summed before a single optimizer step,
    emulating a batch size `accumulation_steps` times larger.

    Args:
        train_loader: iterable of (input, target) batches; target is
            expected to have 5 character-label columns (fixed-length code).
        model: network returning 5 per-position classification heads.
        criterion: per-position classification loss (e.g. CrossEntropyLoss).
        optimizer: optimizer over model.parameters().

    Returns:
        Mean of the per-batch (unscaled) losses as a numpy float.
    """
    # Switch the model to training mode (enables dropout / BN updates).
    model.train()
    # Loop invariant hoisted: the original re-assigned this every batch.
    accumulation_steps = 10
    train_loss = []
    # Start from clean gradients in case a previous call left some behind.
    optimizer.zero_grad()
    pending = False  # True while gradients are accumulated but unapplied
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda()
        target = target.cuda().long()
        c1, c2, c3, c4, c5 = model(input)
        # One cross-entropy term per character position.
        loss = criterion(c1, target[:, 0]) + \
               criterion(c2, target[:, 1]) + \
               criterion(c3, target[:, 2]) + \
               criterion(c4, target[:, 3]) + \
               criterion(c5, target[:, 4])
        # BUG FIX: scale the backward pass so the accumulated gradient is
        # the MEAN over the window (the original left this commented out,
        # making gradients accumulation_steps times too large).
        (loss / accumulation_steps).backward()
        pending = True
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending = False
        # Record the unscaled loss so the reported value stays comparable.
        train_loss.append(loss.item())
    # BUG FIX: apply gradients from a final partial accumulation window
    # (the original silently dropped them at the end of the epoch).
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    return np.mean(train_loss)
上述代码用到了梯度累加的技巧,可以变相地提升batchsize。
def set_random_seed(seed=0, deterministic=False, benchmark=True):
    """Seed numpy and torch RNGs to improve reproducibility.

    Args:
        seed: base seed applied to numpy, torch CPU, and all CUDA devices.
        deterministic: if True, force cuDNN deterministic algorithms and
            disable benchmark mode (which would otherwise reintroduce
            non-determinism via runtime algorithm auto-selection).
        benchmark: if True (and deterministic is False), enable cuDNN
            benchmark mode to auto-tune the fastest convolution kernels.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        # BUG FIX: benchmark mode selects algorithms non-deterministically
        # at runtime; the original could switch it ON together with
        # deterministic mode, defeating the purpose. Deterministic wins.
        torch.backends.cudnn.benchmark = False
    elif benchmark:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
上述代码可以固定一些随机种子,提升代码的可复现性。
# Build the model and optimizer, optionally resuming from a checkpoint.
RESUME = False  # set to True to resume training from a saved checkpoint
if RESUME:
    path_checkpoint = r"result\0528_135935\checkpoints\ckpt_epoch150.pth"  # checkpoint file path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint dict
    model = SVHN_Model1().cuda()
    model.load_state_dict(checkpoint['net'])  # restore learnable parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer state
    start_epoch = checkpoint['epoch']  # epoch to resume from
    # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
else:
    model = SVHN_Model1().cuda()
    # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=5e-4)
    # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
上述代码可以从断点恢复模型训练,当然,还要配合以下代码。
# 保存断点
# Save a training checkpoint (model weights + optimizer state + epoch)
# every 5 epochs; runs inside the epoch loop (`epoch`, `time_path`,
# `model`, `optimizer` come from the surrounding scope).
checkpoint = {
    "net": model.state_dict(),
    'optimizer': optimizer.state_dict(),
    "epoch": epoch,
    # 'lr_schedule': lr_scheduler.state_dict()
}
checkpoint_path = time_path + '/checkpoints'
if (epoch+1) %5 ==0:
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path)
    torch.save(checkpoint, checkpoint_path + '/ckpt_epoch%s.pth' % (str(epoch+1)))
    print("Save checkpoint at epoch:", epoch+1)
下一步计划
目前还是没有成功尝试新的模型,不过打算以这次机会,熟悉全套字符识别流程,并掌握常用的Pytorch方法和实用的Pytorch技巧,希望近期内能实现一个检测加识别的新baseline,将分数刷到0.8+。