主要改动:
(1) 将传统卷积改为深度可分离卷积(depthwise 卷积 + 1x1 pointwise 卷积);
(2) 使用 PyTorch 自带的 CTC 损失(torch.nn.CTCLoss),不再使用百度开源的 warp-ctc。主要原因是本人在 Windows 上开发调试,而在 Windows 上编译 warp-ctc 比较麻烦。
crnn网络实现代码:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    The LSTM uses PyTorch's default layout (batch_first=False), so the input is
    (seq_len, batch, nInput_size) and the output is (seq_len, batch, nOut).
    """

    def __init__(self, nInput_size, nHidden, nOut):
        super().__init__()
        self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated -> 2 * nHidden features.
        self.linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        sequence, _ = self.lstm(input)
        seq_len, batch_size, features = sequence.size()
        # Flatten time and batch so the projection is applied to every step at once.
        flat = sequence.view(seq_len * batch_size, features)
        projected = self.linear(flat)
        # Back to (seq_len, batch, nOut).
        return projected.view(seq_len, batch_size, -1)
class CNN(nn.Module):
    """Depthwise-separable convolutional backbone used by CRNN.

    Seven stages, each made of a depthwise conv (groups == channels), a 1x1
    pointwise conv, an optional BatchNorm, a ReLU and an optional max-pool.
    For an input of shape (N, nChannel, 32, W) the output is
    (N, 512, 1, W // 4 + 1), e.g. W = 100 -> (N, 512, 1, 26).

    NOTE(review): the assertion only checks that imageHeight is a multiple of
    32, but only a height of 32 gives an output height of 1 (64 gives 3).

    Submodules keep the names depth_convK / point_convK / batchNormK / reluK /
    poolK and are registered in stage order, because those names are the
    state_dict keys of saved checkpoints.
    """

    # One row per stage:
    # (out_channels, depthwise kernel, depthwise padding, use BatchNorm, MaxPool2d kwargs or None).
    # Stage 0 reads nChannel input channels; each later stage reads the previous out_channels.
    _STAGES = (
        (64, 3, 1, False, {'kernel_size': 2, 'stride': 2}),
        (128, 3, 1, False, {'kernel_size': 2, 'stride': 2}),
        (256, 3, 1, True, None),
        # (2, 1) stride with width padding halves the height but keeps (and slightly grows) the width.
        (256, 3, 1, False, {'kernel_size': (2, 2), 'stride': (2, 1), 'padding': (0, 1)}),
        (512, 3, 1, True, None),
        (512, 3, 1, False, {'kernel_size': (2, 2), 'stride': (2, 1), 'padding': (0, 1)}),
        # Final 2x2 depthwise conv without padding collapses the remaining height of 2 to 1.
        (512, 2, 0, True, None),
    )

    def __init__(self, imageHeight, nChannel):
        super(CNN, self).__init__()
        assert imageHeight % 32 == 0,'image Height has to be a multiple of 32'
        channels_in = nChannel
        for stage, (channels_out, kernel, pad, with_norm, pool_args) in enumerate(self._STAGES):
            # Depthwise: one spatial filter per input channel.
            setattr(self, 'depth_conv%d' % stage,
                    nn.Conv2d(in_channels=channels_in, out_channels=channels_in,
                              kernel_size=kernel, stride=1, padding=pad, groups=channels_in))
            # Pointwise: 1x1 conv mixes channels and sets the stage width.
            setattr(self, 'point_conv%d' % stage,
                    nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                              kernel_size=1, stride=1, padding=0, groups=1))
            if with_norm:
                setattr(self, 'batchNorm%d' % stage, nn.BatchNorm2d(channels_out))
            setattr(self, 'relu%d' % stage, nn.ReLU(inplace=True))
            if pool_args is not None:
                setattr(self, 'pool%d' % stage, nn.MaxPool2d(**pool_args))
            channels_in = channels_out

    def forward(self, input):
        features = input
        for stage in range(len(self._STAGES)):
            features = getattr(self, 'depth_conv%d' % stage)(features)
            features = getattr(self, 'point_conv%d' % stage)(features)
            # Only stages 2, 4 and 6 have BatchNorm; only stages 0, 1, 3 and 5 pool.
            norm = getattr(self, 'batchNorm%d' % stage, None)
            if norm is not None:
                features = norm(features)
            features = getattr(self, 'relu%d' % stage)(features)
            pool = getattr(self, 'pool%d' % stage, None)
            if pool is not None:
                features = pool(features)
        return features
class CRNN(nn.Module):
    """CNN feature extractor followed by two stacked bidirectional LSTMs.

    forward() takes an image batch (batch, nChannel, imgHeight, width) and
    returns per-time-step class scores of shape (seq_len, batch, nClass),
    where seq_len is the width of the CNN feature map. No softmax is applied
    here; callers apply log_softmax / argmax as needed.
    """

    def __init__(self, imgHeight, nChannel, nClass, nHidden):
        super(CRNN, self).__init__()
        # Wrapped in nn.Sequential, so checkpoint keys look like 'cnn.0.<layer>'.
        self.cnn = nn.Sequential(CNN(imgHeight, nChannel))
        # The CNN emits 512 channels, which become the LSTM feature size.
        encoder = BidirectionalLSTM(512, nHidden, nHidden)
        classifier = BidirectionalLSTM(nHidden, nHidden, nClass)
        self.lstm = nn.Sequential(encoder, classifier)

    def forward(self, input):
        # PyTorch conv output layout is (B, C, H, W).
        feature_map = self.cnn(input)
        _, _, height, _ = feature_map.size()
        assert height==1,"the output height must be 1."
        # Drop H -> (B, C, W), then reorder to (W, B, C) = (seq, batch, input_size) for nn.LSTM.
        sequence = feature_map.squeeze(dim=2).permute(2, 0, 1)
        return self.lstm(sequence)
训练网络代码:
import os
import torch
import cv2
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from crnn_new import crnn
import time
# 调整图像大小和归一化操作
class resizeAndNormalize:
    """Resize an OpenCV (numpy) image and scale its pixels to [-1, 1].

    `size` is (width, height), the order cv2.resize expects.
    """

    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        self.size = size
        self.interpolation = interpolation
        # transforms.ToTensor converts an ndarray/PIL image to a float tensor in [0, 1].
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        resized = cv2.resize(image, self.size, interpolation=self.interpolation)
        tensor = self.toTensor(resized)
        # In-place (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        return tensor.sub_(0.5).div_(0.5)
class CRNNDataSet(Dataset):
    """(image, label) dataset described by a whitespace-separated label file.

    Each non-empty line of the label file is ``<image name> <idx> <idx> ...``:
    an image path relative to `imageRoot` followed by integer class indices.
    Lines whose image file does not exist are reported and skipped when the
    dataset is built, so training never hits a missing file.

    __getitem__ returns a (1, 32, W) float tensor in [-1, 1] (grayscale image
    resized to height 32 with its aspect ratio kept) and an IntTensor label.

    NOTE(review): W depends on each image's aspect ratio; DataLoader's default
    collate can only stack a batch if all its images end up the same width —
    confirm this holds for the dataset in use.
    """

    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        self.image_dict = self.readfile(labelRoot)
        self.image_name = list(self.image_dict)

    def __getitem__(self, index):
        name = self.image_name[index]
        image_path = os.path.join(self.image_root, name)
        label = [int(x) for x in self.image_dict[name]]
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread returns None (instead of raising) for unreadable/corrupt files.
            raise IOError('failed to read image: %s' % image_path)
        height, width = image.shape
        # The CRNN input height is 32: rescale to that height, keeping the aspect ratio.
        size_height = 32
        ratio = size_height / float(height)
        # Guard against a zero width for extremely narrow images (cv2.resize rejects it).
        size_width = max(1, int(ratio * width))
        transform = resizeAndNormalize((size_width, size_height))
        image = transform(image)
        label = torch.IntTensor(label)
        return image, label

    def __len__(self):
        return len(self.image_name)

    def readfile(self, fileName):
        """Parse the label file into {image name: [label index strings]}.

        Prints the path of every missing image, then the number of missing images.
        """
        dic = {}
        total = 0
        with open(fileName, 'r') as f:
            for line in f:
                # split() with no argument ignores repeated spaces/tabs and trailing newlines;
                # split(' ') would produce empty label fields that later fail int().
                part = line.split()
                if not part:
                    continue  # blank line
                image_path = os.path.join(self.image_root, part[0])
                # Check existence up front so a missing file cannot abort training later.
                if not os.path.exists(image_path):
                    print(image_path)
                    total += 1
                else:
                    dic[part[0]] = part[1:]
        print(total)
        return dic
# Dataset locations. Backslashes are escaped explicitly: in a normal string literal
# sequences such as "\B", "\S" or "\d" are invalid escapes (DeprecationWarning since
# Python 3.6, SyntaxWarning since 3.12). The string values are unchanged.
DATASET_IMAGE_ROOT = "D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\images\\"
# "lables" is the directory name as it exists on disk.
DATASET_LABEL_DIR = "D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\lables\\"
trainData = CRNNDataSet(imageRoot=DATASET_IMAGE_ROOT,
                        labelRoot=DATASET_LABEL_DIR + "data.txt")
trainLoader = DataLoader(dataset=trainData, batch_size=30, shuffle=True, num_workers=0)
valData = CRNNDataSet(imageRoot=DATASET_IMAGE_ROOT,
                      labelRoot=DATASET_LABEL_DIR + "data_t.txt")
# NOTE(review): with num_workers > 0 on Windows (spawn start method) each worker re-imports
# this module, re-running the dataset construction above — consider num_workers=0 or
# moving this setup under the __main__ guard.
valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)
def decode(preds, blank=5989):
    """Greedy CTC decoding of a sequence of per-time-step class indices.

    Consecutive repeats are collapsed into one and `blank` entries are
    dropped, so a label repeated in the target is only recovered when a blank
    separates the two occurrences.

    Args:
        preds: iterable of class indices (list of ints or a 1-D tensor).
        blank: index of the CTC blank class (5989 matches the training setup
            where the blank is the last of the character classes).

    Returns:
        List of decoded class indices as Python ints.
    """
    pred = []
    previous = None
    for value in (int(p) for p in preds):
        # Compare with the previous raw step (blank included); the first step has no
        # predecessor. The old `i == 5989` test compared step 0 with preds[-1] instead.
        if value != blank and value != previous:
            pred.append(value)
        previous = value
    return pred
def val(model, loss_function, max_iteration=1000, use_gpu=True):
    """Evaluate `model` on up to `max_iteration` batches of the module-level `valLoader`.

    Prints the mean CTC loss and a character accuracy: the number of positions
    where the greedy-decoded prediction equals the target, as a percentage of the
    number of decoded characters.

    Leaves the model in eval mode; callers that keep training must call
    model.train() again.

    Args:
        model: network returning raw scores of shape (seq_len, batch, n_class).
        loss_function: a torch.nn.CTCLoss (or compatible) instance.
        max_iteration: maximum number of validation batches to evaluate.
        use_gpu: run on CUDA when it is available.
    """
    # Evaluation mode: BatchNorm uses its running statistics.
    model.eval()
    k = 0
    totalloss = 0.0
    correct_num = 0
    total_num = 0
    val_iter = iter(valLoader)
    max_iter = min(max_iteration, len(valLoader))
    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for _ in range(max_iter):
            k += 1
            data, label = next(val_iter)
            # Flatten the (batch, label_len) targets into CTCLoss's concatenated 1-D format.
            labels = label.reshape(-1)
            if torch.cuda.is_available() and use_gpu:
                data = data.cuda()
            output = model(data)
            # CTCLoss expects log-probabilities; the model returns unnormalised scores.
            log_probs = output.log_softmax(2)
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
            loss = loss_function(log_probs, labels.to(output.device),
                                 input_lengths, target_lengths) / label.size(0)
            totalloss += float(loss)
            pred_label = output.max(2)[1]
            # (seq, batch) -> (batch, seq) -> flat. NOTE(review): with batch size > 1 the
            # sequences would be concatenated before decoding; valLoader uses batch_size=1.
            pred_label = pred_label.transpose(1, 0).contiguous().view(-1)
            pred = decode(pred_label)
            total_num += len(pred)
            for x, y in zip(pred, labels):
                if int(x) == int(y):
                    correct_num += 1
    # Guard against empty predictions / an empty loader instead of dividing by zero.
    accuracy = correct_num / float(total_num) * 100 if total_num else 0.0
    test_loss = totalloss / k if k else 0.0
    print('Test loss : %.3f , accuary : %.3f%%' % (test_loss, accuracy))
def train():
    """Train the CRNN on the module-level `trainLoader` with CTC loss.

    Resumes from `modelpath` when that file exists. Every `printInterval`
    steps it prints the mean loss and saves the weights to `modelpath`;
    every `valinterval` steps it runs `val`.
    """
    use_gpu = True
    learning_rate = 0.0005
    weight_decay = 1e-4
    max_epoch = 10
    val_max_iteration = 1000
    # Raw string: "\c" / "\p" are invalid escape sequences in a normal literal.
    modelpath = r'F:\crnn_model\pytorch-crnn.pth'
    # Call is_available(); the bare function object is always truthy and would force .cuda()
    # on CPU-only machines.
    use_cuda = torch.cuda.is_available() and use_gpu
    with open('../train/char_std_5990.txt', 'r', encoding='utf-8') as f:
        char_set = f.readlines()
    # The first line of the charset file is dropped and '卍' is appended as the last class,
    # which serves as the CTC blank (blank = n_class - 1).
    # NOTE(review): if the label files hold 1-based indices into the full charset file (the
    # inference script decodes with char_set[idx - 1]), the highest character index equals
    # the blank index — confirm against the label files.
    char_set = ''.join([ch.strip('\n') for ch in char_set[1:]] + ['卍'])
    n_class = len(char_set)
    model = crnn.CRNN(imgHeight=32, nChannel=1, nClass=n_class, nHidden=256)
    loss_func = torch.nn.CTCLoss(blank=n_class - 1)
    if use_cuda:
        model.cuda()
        loss_func = loss_func.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    if os.path.exists(modelpath):
        print("load model from %s" % modelpath)
        # map_location lets a GPU-saved checkpoint load anywhere; load_state_dict copies the
        # tensors onto the model's current device.
        model.load_state_dict(torch.load(modelpath, map_location='cpu'))
        print("done!")
    lossTotal = 0.0
    k = 0
    printInterval = 100
    valinterval = 1000
    start_time = time.time()
    for epoch in range(max_epoch):
        for i, (data, label) in enumerate(trainLoader):
            k += 1
            # val() switches the model to eval mode, so re-enable training mode every step.
            model.train()
            # Flatten the (batch, label_len) targets into CTCLoss's concatenated 1-D format.
            labels = label.reshape(-1)
            if use_cuda:
                data = data.cuda()
                labels = labels.cuda()
            output = model(data)
            # Do NOT detach here: `.detach().requires_grad_()` (as in the CTCLoss docs example,
            # which feeds random tensors) cuts the loss off from the model parameters, so the
            # optimizer would never update any weight.
            log_probs = output.log_softmax(2)
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
            # CTCLoss's default reduction='mean' already averages over the batch; the extra
            # division only rescales the loss.
            loss = loss_func(log_probs, labels, input_lengths, target_lengths) / label.size(0)
            lossTotal += float(loss)
            if k % printInterval == 0:
                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                lossTotal = 0.0
                torch.save(model.state_dict(), modelpath)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if k % valinterval == 0:
                val(model, loss_func, val_max_iteration, use_gpu)
    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))


if __name__ == '__main__':
    train()
测试代码:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import torch
from config import opt
from crnn import crnn
from PIL import Image
from torchvision import transforms
class resizeNormalize(object):
    """Resize a PIL image to `size` (width, height) and scale its pixels to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        # ToTensor yields values in [0, 1]; the in-place (x - 0.5) / 0.5 maps them to [-1, 1].
        # sub_/div_ return the same tensor, so the chained result is the converted image.
        return self.toTensor(resized).sub_(0.5).div_(0.5)
def decode(preds, char_set, blank=5989):
    """Greedy CTC decoding of per-time-step class indices into text.

    Consecutive repeats are collapsed and `blank` entries dropped. Class index
    k maps to char_set[k - 1] — presumably because char_set was built without
    the first line of the charset file (labels appear to be 1-based).

    Args:
        preds: iterable of class indices (list of ints or a 1-D tensor).
        char_set: string of characters indexed as described above.
        blank: index of the CTC blank class.

    Returns:
        The decoded string.
    """
    chars = []
    previous = None
    for value in (int(p) for p in preds):
        # Compare with the previous raw step; the first step has no predecessor. The old
        # `i == 5989` test compared step 0 with preds[-1] and could drop the first character.
        if value != blank and value != previous:
            chars.append(char_set[value - 1])
        previous = value
    return ''.join(chars)
# Inference smoke test: recognise the text in a single image with a trained CRNN.
if __name__ == '__main__':
    imagepath = './12.jpg'
    img_h = opt.img_h
    use_gpu = opt.use_gpu
    # Raw string: "\c" / "\p" are invalid escape sequences in a normal literal.
    modelpath = r'F:\crnn_model\pytorch-crnn-Copy68.pth'
    #modelpath = '../train/models/pytorch-crnn.pth'
    # modelpath = opt.modelpath
    # Call is_available(); the bare function object is always truthy.
    use_cuda = torch.cuda.is_available() and use_gpu
    with open('char_std_5990.txt', 'r', encoding='utf-8') as f:
        char_set = f.readlines()
    # Same class layout as training: drop the first line, append '卍' as the blank class.
    char_set = ''.join([ch.strip('\n') for ch in char_set[1:]] + ['卍'])
    n_class = len(char_set)
    print(n_class)
    from crnn_new import crnn
    model = crnn.CRNN(img_h, 1, n_class, 256)
    if os.path.exists(modelpath):
        print('Load model from "%s" ...' % modelpath)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(modelpath, map_location='cpu'))
        print('Done!')
    else:
        print('Model file "%s" not found; using randomly initialised weights.' % modelpath)
    if use_cuda:
        model.cuda()
    image = Image.open(imagepath).convert('L')
    (w, h) = image.size
    # Resize to height 32 while keeping the aspect ratio.
    size_h = 32
    ratio = size_h / float(h)
    size_w = int(w * ratio)
    transform = resizeNormalize((size_w, size_h))
    image = transform(image)
    # Add the batch dimension: (1, 1, 32, W).
    image = image.unsqueeze(0)
    if use_cuda:
        image = image.cuda()
    model.eval()
    with torch.no_grad():
        preds = model(image)
    # (seq_len, 1, n_class) -> best class per step, flattened to (seq_len,). reshape(-1)
    # instead of squeeze() so a single-step sequence stays 1-D.
    preds = preds.max(2)[1].reshape(-1)
    pred_text = decode(preds, char_set)
    print('predict == >', pred_text)