转验证码识别概述
这次我将带着大家实现旋转验证码的识别,当然当前时间点上识别教程有很多,我的教程会有自己优化的点(主要在损失函数上),使准确率更加高,需要的数据集更加少。
项目开源地址
https://github.com/2833844911/Rotate-the-verification-code-pytorch/
定义模型
下面我们使用的是resnet50模型进行分类
class CustomResNet50(nn.Module):
def __init__(self, num_classes=num_classes):
super(CustomResNet50, self).__init__()
# 加载ResNet50模型
resnet50 = models.resnet50(pretrained=True)
self.resnet_layers = nn.Sequential(*list(resnet50.children())[:-1])
self.flatten = nn.Flatten()
self.fc = nn.Linear(2048, 360) #把结果输出为360个类别
self.sfm = nn.Sigmoid()
def forward(self, x):
x = self.resnet_layers(x)
x = self.flatten(x)
x = self.fc(x)
x = self.sfm(x)
return x
损失函数实现
看网上教程他们用的损失函数大部分就是直接使用交叉熵损失函数,但这样子的效果其实不是很好,会导致模型偏差太大
使用交叉熵损失函数最优输出【假设目标角度为90度】
当使用交叉熵损失函数时 需要达到效果最好时目标角度输出数据应当为1,其他角度都要为0,这样子才可以让损失最小,但是这样子我们可以发现角度93的损失和角度91的损失所带来的效果是一样的,其实我们应该最大当预测角度为91时其实损失应该更加小,那要怎么设计损失函数呢?
接下来我们开始来实现
首先我们也画一张图来表示我们是最佳输出
使用自定义函数最优输出【假设目标角度为90度】
由图可以看出我们输出的结果应当为
88度为0.8
89度为0.9
90度为1
91度为0.9
92度为0.8
93度为0.7
这样子我们就可以体现出91度和93度的区别,这样子当模型预测角度不为90度,那也会有更大的可能会预测到89度或者91度
下面是实现的代码
class CustomClassificationLoss(nn.Module):
def __init__(self, num_classes, alpha=18):
super(CustomClassificationLoss, self).__init__()
def forward(self, logits, targets):
ce_loss = 0
for i in range(logits.shape[0]):
wz = targets[i].item()
kj = wz
lossOne = (1-logits[i][wz])
for r in range(1, 180):
wz -= 1
if wz <0:
wz = 359
lossOne += torch.abs((0.98 ** r) - logits[i][wz])
for r in range(1, 180):
kj += 1
if kj >359:
kj = 0
lossOne += torch.abs((0.98 ** r) - logits[i][kj])
ce_loss += lossOne
loss = ce_loss/targets.shape[0]
return loss, 0, loss.item()
训练好的模型下载地址
链接:https://pan.baidu.com/s/1mcz58nquhL0DA8tCX5f_ww
提取码:39ev
需要配合我github的代码使用
星球 https://t.zsxq.com/0fLP2ulsk 【有魔改node, 魔改浏览器,等分析】