# PyTorch variable-length text recognition (不定长文本识别): ResNet18 backbone + BiLSTM, trained with CTC loss
import torch
from torch import nn
from torch.nn import LSTM, Linear
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T
# Input image size as (height, width); images are resized to this by `transform`.
IMAGE_SHAPE = (28, 135)
# Preprocessing applied to every decoded image:
#   numpy HWC uint8 array -> PIL image -> resized to IMAGE_SHAPE (height, width)
#   -> float CHW tensor with values scaled to [0, 1].
# NOTE(review): OpenCV decodes to BGR and nothing here reorders channels; this is
# consistent between train and test, and the backbone appears not to be pretrained.
transform = T.Compose(
    [
        T.ToPILImage(),
        T.Resize(IMAGE_SHAPE),
        T.ToTensor(),
    ]
)
// 标签'_'代表占位,不定长必要
LABEL_MAP = [i for i in '_0123456789-+=']
Max_label_len = 6
class MyDataset(Dataset):
    """Text-recognition dataset built from a folder of labelled images.

    Every entry of ``data_path`` is one sample. Its label is the part of the
    file name before the first ``'_'`` (e.g. ``"12+3_0001.png"`` -> ``"12+3"``).
    Items are ``(image_tensor, padded_label, raw_len)``.

    :param data_path: folder containing the images
    :param label_map: index -> character table; index 0 is used as padding
    :param max_label_len: fixed length the encoded labels are padded to
    """

    def __init__(self, data_path, label_map, max_label_len):
        super(MyDataset, self).__init__()
        self.data = [(os.path.join(data_path, file), file.split('_')[0])
                     for file in os.listdir(data_path)]
        self.label_map = list(label_map)
        self.label_map_len = len(self.label_map)
        self.max_label_len = max_label_len
        # O(1) char -> index lookup. setdefault keeps the first index of a
        # duplicated character, matching list.index semantics.
        self._char_to_index = {}
        for idx, char in enumerate(self.label_map):
            self._char_to_index.setdefault(char, idx)

    def _encode_label(self, label):
        """Encode ``label`` as an int32 array of length ``max_label_len``, padded with 0.

        :raises ValueError: if the label is longer than ``max_label_len`` or
            contains a character missing from ``label_map``.
        """
        if len(label) > self.max_label_len:
            raise ValueError("label %r has %d characters, more than max_label_len=%d"
                             % (label, len(label), self.max_label_len))
        try:
            codes = [self._char_to_index[char] for char in label]
        except KeyError as err:
            raise ValueError("label %r contains %r, which is not in label_map"
                             % (label, err.args[0])) from err
        codes.extend([0] * (self.max_label_len - len(codes)))
        return np.asarray(codes, dtype='int32')

    def __getitem__(self, index):
        """Return ``(image, padded_label, raw_len)`` for sample ``index``.

        :raises ValueError: if the label cannot be encoded or the file cannot be
            decoded as an image.
        """
        file, label = self.data[index]
        encoded = self._encode_label(label)
        # Raw bytes are read with numpy and decoded by OpenCV instead of using
        # cv2.imread -- presumably so non-ASCII paths work too; confirm.
        im = cv2.imdecode(np.fromfile(file, dtype=np.uint8), cv2.IMREAD_COLOR)
        if im is None:
            raise ValueError("could not decode image file %r" % file)
        return transform(im), encoded, len(label)

    def __len__(self):
        """Number of files found in ``data_path``."""
        return len(self.data)
class Net(nn.Module):
    """CRNN-style recognizer: truncated ResNet18 features -> two BiLSTM stages -> per-step scores.

    ``forward`` returns raw (un-normalized) scores of shape ``(T, B, len(LABEL_MAP))``,
    where ``T`` is the width of the backbone feature map.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Keep ResNet18 up to layer3 (drop layer4, avgpool and fc) so the feature
        # map keeps enough horizontal resolution to be used as time steps.
        self.resnet18 = nn.Sequential(*list(models.resnet18().children())[0:-3])
        bone_output_shape = self._cal_shape()
        self.lstm = LSTM(bone_output_shape, bone_output_shape, num_layers=1, bidirectional=True)
        self.linear = Linear(bone_output_shape * 2, 256)
        self.lstm1 = LSTM(256, bone_output_shape, num_layers=1, bidirectional=True)
        self.linear1 = Linear(bone_output_shape * 2, len(LABEL_MAP))

    def _cal_shape(self):
        """Return channels * height of the backbone output for an IMAGE_SHAPE input.

        The probe runs in eval mode under ``no_grad``. That way it builds no
        autograd graph and does not update BatchNorm running statistics with the
        dummy all-zero batch. The previous training flag is restored afterwards.
        """
        was_training = self.resnet18.training
        self.resnet18.eval()
        try:
            with torch.no_grad():
                shape = self.resnet18(torch.zeros((1, 3) + IMAGE_SHAPE)).shape
        finally:
            self.resnet18.train(was_training)
        return shape[1] * shape[2]

    def forward(self, x):
        """Map an image batch to per-time-step class scores.

        :param x: image batch of shape (B, 3, H, W); presumably H, W == IMAGE_SHAPE
            as produced by ``transform`` -- the LSTM input size is fixed for that size.
        :return: tensor of shape (T, B, len(LABEL_MAP)), T = feature-map width
        """
        features = self.resnet18(x)  # (B, C, H', W')
        # Use the width axis as the sequence axis: (W', B, C, H') -> (W', B, C*H').
        features = features.permute(3, 0, 1, 2)
        time_steps, batch_size, channels, height = features.shape
        seq = features.reshape(time_steps, batch_size, channels * height)
        seq, _ = self.lstm(seq)  # (T, B, 2 * hidden)
        seq = self.linear(seq)  # nn.Linear acts on the last dim; no flattening needed
        seq, _ = self.lstm1(seq)
        return self.linear1(seq)  # (T, B, num_classes)
def tranfromlabel(label, label_map=None):
    """Convert a sequence of class indices into text.

    Index 0 is not stripped; callers slice off padding themselves.

    :param label: iterable of integer indices (list, numpy array or 1-D tensor)
    :param label_map: index -> character table; defaults to ``LABEL_MAP``
    :return: the concatenated characters
    """
    if label_map is None:
        label_map = LABEL_MAP
    if hasattr(label, 'tolist'):
        # Convert arrays/tensors to plain ints up front instead of indexing with
        # one 0-d tensor per element (a device sync each time on GPU).
        label = label.tolist()
    return ''.join(label_map[i] for i in label)


def ctc_to_str(data, label_map=None):
    """
    Greedy CTC decoding: collapse runs of repeated indices, then drop blanks (index 0).

    :param data: best-path index sequence, one index per time step
        (list, numpy array or 1-D tensor)
    :param label_map: index -> character table; defaults to ``LABEL_MAP``
    :return: decoded text
    """
    if hasattr(data, 'tolist'):
        data = data.tolist()
    result = []
    last = -1
    for i in data:
        if i == 0:
            # A blank separates genuine repeats: "1 _ 1" decodes to "11".
            last = -1
        elif i != last:
            result.append(i)
            last = i
    return tranfromlabel(result, label_map)
def _make_loader(folder, batch_size, num_workers):
    """Build a shuffled DataLoader over the labelled images in ``folder``."""
    dataset = MyDataset(folder, label_map=LABEL_MAP, max_label_len=Max_label_len)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)


# Module-level loaders, built at import time: training batches of 32 read by
# 3 worker processes, and validation batches of 4 read in the main process.
train = _make_loader(r'./train', batch_size=32, num_workers=3)
test = _make_loader(r'./test', batch_size=4, num_workers=0)
def _ctc_loss(loss_func, logits, labels, target_lengths):
    """Return the CTC loss of a batch of raw ``Net`` outputs.

    :param loss_func: an ``nn.CTCLoss`` instance (default blank index 0)
    :param logits: (T, B, C) un-normalized scores from ``Net.forward``
    :param labels: (B, S) padded target indices as produced by ``MyDataset``
    :param target_lengths: (B,) true label lengths
    """
    # nn.CTCLoss expects log-probabilities, not raw scores.
    log_probs = logits.log_softmax(2)
    time_steps, batch_size = log_probs.shape[0], log_probs.shape[1]
    # Every sample uses the full output sequence; input_lengths must be 1-D (B,).
    input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)
    # Keep targets on the same device as the log-probs. With padded 2-D targets,
    # PyTorch's native (non-cuDNN) CTC kernel is used, which expects this.
    targets = labels.to(log_probs.device, dtype=torch.long)
    return loss_func(log_probs, targets, input_lengths, target_lengths)


def main():
    """Train ``Net`` with CTC loss for 100 epochs, validating and checkpointing each epoch."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CTCLoss()
    # The schedule is driven by validation accuracy, hence mode 'max'.
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3)
    # torch.save does not create missing directories.
    os.makedirs("models", exist_ok=True)
    for epoch in range(0, 100):
        model.train()
        bar = tqdm(train, 'Training')
        for images, labels, target_lengths in bar:
            predict = model(images.to(device))
            loss = _ctc_loss(loss_func, predict, labels, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            bar.set_description("Train epoch %d, loss %.4f, lr %.6f" % (
                epoch, loss.item(), lr
            ))

        # Validate with BatchNorm running statistics and without building graphs.
        model.eval()
        bar = tqdm(test, 'Validating')
        correct = count = 0
        with torch.no_grad():
            for images, labels, target_lengths in bar:
                predicts = model(images.to(device))
                # Greedy best path per sample: argmax over classes, (T, B) -> (B, T).
                best_paths = predicts.argmax(2).t().tolist()
                for label, length, path in zip(labels.tolist(), target_lengths.tolist(), best_paths):
                    count += 1
                    label_text = tranfromlabel(label)[:length]
                    if label_text == ctc_to_str(path):
                        correct += 1
                loss = _ctc_loss(loss_func, predicts, labels, target_lengths)
                lr = optimizer.param_groups[0]['lr']
                bar.set_description("Valid epoch %d, acc %.4f, loss %.4f, lr %.6f" % (
                    epoch, correct / count, loss.item(), lr
                ))
        # Guard against an empty validation set.
        accuracy = correct / count if count else 0.0
        scheduler.step(accuracy)
        torch.save(model.state_dict(), "models/save_%d.model" % epoch)


if __name__ == '__main__':
    main()