pytorch不定长文本识别resnet18+LSTM

pytorch不定长文本识别resnet18+LSTM

import torch
from torch import nn
from torch.nn import LSTM, Linear
from torchvision import models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm
import torchvision.transforms as T
# Target image size as (height, width): passed to T.Resize below and used to
# build the dummy (1, 3, H, W) input when Net probes its backbone shape.
IMAGE_SHAPE = (28, 135)

# Per-image preprocessing applied to the decoded cv2 image:
# numpy array -> PIL image -> resize to IMAGE_SHAPE -> float tensor (C, H, W).
transform = T.Compose([
    T.ToPILImage(),
    T.Resize(IMAGE_SHAPE),
    T.ToTensor()
])

# Character set; a character's position is its class index. Index 0 ('_') is
# reserved: it is the CTC blank (nn.CTCLoss defaults to blank=0) and the
# padding value for short labels, which is what makes variable-length labels
# possible.
LABEL_MAP = list('_0123456789-+=')
# Maximum number of characters in a label; encoded labels are padded to this.
Max_label_len = 6


class MyDataset(Dataset):
    """Dataset of text images whose labels are encoded in the file names.

    Every entry of ``data_path`` is one sample; its label is the part of the
    file name before the first ``'_'`` (e.g. ``'12+3_0001.png'`` -> ``'12+3'``).

    ``__getitem__`` returns ``(image, label, raw_len)``:

    * ``image``   -- the decoded image passed through the module-level
      ``transform``;
    * ``label``   -- int32 array of length ``max_label_len`` holding indices
      into ``label_map``, right-padded with 0;
    * ``raw_len`` -- number of label characters before padding.
    """

    def __init__(self, data_path, label_map, max_label_len):
        """
        :param data_path: directory containing the sample images.
        :param label_map: sequence of characters; position = class index.
        :param max_label_len: length every encoded label is padded to.
        """
        super(MyDataset, self).__init__()
        # (full path, label text) for every file in the directory.
        self.data = [(os.path.join(data_path, file), file.split('_')[0]) for file in os.listdir(data_path)]
        self.label_map = [char for char in label_map]
        self.label_map_len = len(self.label_map)
        self.max_label_len = max_label_len
        # O(1) char -> index lookup. setdefault keeps the FIRST occurrence so
        # the mapping matches list.index() even if label_map has duplicates.
        self._char_to_index = {}
        for position, char in enumerate(self.label_map):
            self._char_to_index.setdefault(char, position)

    def _encode_label(self, text):
        """Return *text* as an int32 index array padded with 0 to max_label_len.

        :raises ValueError: if *text* is longer than ``max_label_len`` or
            contains a character that is not in ``label_map``.
        """
        if len(text) > self.max_label_len:
            raise ValueError('label %r has %d characters; max_label_len is %d'
                             % (text, len(text), self.max_label_len))
        indices = []
        for char in text:
            if char not in self._char_to_index:
                raise ValueError('label %r contains %r, which is not in label_map' % (text, char))
            indices.append(self._char_to_index[char])
        # Index 0 is the padding value; the real length is returned separately.
        indices.extend([0] * (self.max_label_len - len(indices)))
        return np.asarray(indices, dtype='int32')

    def __getitem__(self, index):
        file, text = self.data[index]
        # Raw bytes are decoded in memory instead of using cv2.imread --
        # presumably so that non-ASCII paths work (TODO confirm intent).
        im = cv2.imdecode(np.fromfile(file, dtype=np.uint8), cv2.IMREAD_COLOR)
        if im is None:
            # imdecode returns None for unreadable/corrupt data; fail with the
            # offending path instead of an opaque error inside ``transform``.
            raise ValueError('could not decode image %r' % file)
        im = transform(im)
        return im, self._encode_label(text), len(text)

    def __len__(self):
        return len(self.data)


class Net(nn.Module):
    """CRNN-style recognizer: truncated ResNet-18 backbone + two BiLSTMs.

    ``forward`` maps a batch of images ``(batch, 3, H, W)`` to per-time-step
    class scores of shape ``(time_steps, batch, len(LABEL_MAP))``; each time
    step corresponds to one column of the backbone feature map. The scores
    are unnormalized (no softmax / log_softmax is applied here).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Keep ResNet-18 up to layer3: the slice drops the last three
        # children (layer4, AdaptiveAvgPool2d and the fc layer) so the output
        # is still a spatial feature map. Weights are randomly initialised.
        self.resnet18 = nn.Sequential(*list(models.resnet18().children())[0:-3])
        # Feature size of one column = channels * feature-map height.
        bone_output_shape = self._cal_shape()
        self.lstm = LSTM(bone_output_shape, bone_output_shape, num_layers=1, bidirectional=True)
        self.linear = Linear(bone_output_shape * 2, 256)
        self.lstm1 = LSTM(256, bone_output_shape, num_layers=1, bidirectional=True)
        self.linear1 = Linear(bone_output_shape * 2, len(LABEL_MAP))

    def _cal_shape(self):
        """Return channels * height of the backbone output for an IMAGE_SHAPE input.

        The probe runs in eval mode under no_grad so the dummy all-zero batch
        does not update BatchNorm running statistics; the backbone's previous
        train/eval mode is restored afterwards.
        """
        was_training = self.resnet18.training
        self.resnet18.eval()
        try:
            with torch.no_grad():
                x = torch.zeros((1, 3) + IMAGE_SHAPE)
                shape = self.resnet18(x).shape  # (batch, channels, height, width)
        finally:
            self.resnet18.train(was_training)
        return shape[1] * shape[2]

    def forward(self, x):
        x = self.resnet18(x)          # (batch, channels, height, width)
        # The width axis becomes the sequence axis (LSTMs are not batch_first).
        x = x.permute(3, 0, 1, 2)     # (width, batch, channels, height)
        x = x.flatten(2)              # (width, batch, channels * height)
        x, _ = self.lstm(x)           # (T, batch, 2 * hidden)
        # nn.Linear acts on the last dimension, so no reshaping is needed.
        x = self.linear(x)            # (T, batch, 256)
        x, _ = self.lstm1(x)          # (T, batch, 2 * hidden)
        x = self.linear1(x)           # (T, batch, len(LABEL_MAP))
        return x


def tranfromlabel(label, label_map=None):
    """Convert a sequence of class indices into a string.

    :param label: iterable of integer indices (anything usable as a list
        index into *label_map*).
    :param label_map: index -> character table; defaults to the module-level
        LABEL_MAP.
    :return: the mapped characters joined into one string. Index 0 is
        mapped like any other index (it is not stripped here).
    """
    if label_map is None:
        label_map = LABEL_MAP
    return ''.join(label_map[i] for i in label)


def ctc_to_str(data, label_map=None):
    """Greedy CTC decoding of a per-time-step class-index sequence.

    Consecutive repeats of the same index are collapsed and the blank index
    0 is dropped; a blank between two equal indices separates them, so
    ``[1, 1, 0, 1]`` decodes to two characters.

    :param data: iterable of class indices, one per time step (e.g. the
        argmax of the network output for one sample).
    :param label_map: index -> character table; when omitted the module-level
        LABEL_MAP is used via ``tranfromlabel``.
    :return: the decoded text.
    """
    result = []
    last = -1
    for value in data:
        index = int(value)  # tensors / numpy scalars -> plain int for cheap compares
        if index == 0:
            last = -1
        elif index != last:
            result.append(index)
            last = index
    if label_map is None:
        return tranfromlabel(result)
    return ''.join(label_map[i] for i in result)


# Both loaders are created at import time; each dataset takes its labels
# from the file names inside its directory.
train = DataLoader(
    MyDataset(r'./train', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=32,
    shuffle=True,
    num_workers=3,
)
# Validation loader: smaller batches, loaded in the main process.
test = DataLoader(
    MyDataset(r'./test', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

def compute_ctc_loss(loss_func, predicts, labels, target_lengths):
    """Apply a CTC loss to the raw network output.

    :param loss_func: an ``nn.CTCLoss`` instance (blank index 0).
    :param predicts: unnormalized scores, shape (time_steps, batch, classes),
        as returned by ``Net.forward``.
    :param labels: padded targets, shape (batch, max_label_len).
    :param target_lengths: unpadded length of each target, shape (batch,).
    :return: the scalar loss tensor.
    """
    # nn.CTCLoss expects log-probabilities, so normalise over the class axis.
    log_probs = predicts.log_softmax(2)
    time_steps, batch_size = predicts.shape[0], predicts.shape[1]
    # Every sample uses all time steps; CTCLoss wants a 1-D (batch,) tensor.
    input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)
    return loss_func(log_probs, labels, input_lengths, target_lengths)


if __name__ == '__main__':
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net()
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CTCLoss()
    # The monitored metric is validation accuracy, hence mode='max'.
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3)
    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)
    for epoch in range(0, 100):
        # ---- training ----
        model.train()
        bar = tqdm(train, 'Training')
        for images, labels, target_lengths in bar:
            images = images.to(DEVICE)
            predict = model(images)
            loss = compute_ctc_loss(loss_func, predict, labels, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            bar.set_description("Train epoch %d, loss %.4f, lr %.6f" % (
                epoch, loss.item(), lr
            ))

        # ---- validation ----
        # eval(): BatchNorm uses (and stops updating) its running statistics;
        # no_grad(): no autograd graph is built for evaluation.
        model.eval()
        bar = tqdm(test, 'Validating')
        correct = count = 0
        with torch.no_grad():
            for images, labels, target_lengths in bar:
                images = images.to(DEVICE)
                predicts = model(images)
                # Greedy path: most likely class per time step, per sample.
                best_paths = predicts.argmax(2).cpu()  # (time_steps, batch)

                for i in range(best_paths.shape[1]):
                    count += 1
                    label_text = tranfromlabel(labels[i])[:target_lengths[i]]
                    predict_text = ctc_to_str(best_paths[:, i])
                    if label_text == predict_text:
                        correct += 1

                loss = compute_ctc_loss(loss_func, predicts, labels, target_lengths)

                lr = optimizer.param_groups[0]['lr']
                bar.set_description("Valid epoch %d, acc %.4f, loss %.4f, lr %.6f" % (
                    epoch, correct / count, loss.item(), lr
                ))

        # Guard against an empty validation set.
        accuracy = correct / count if count else 0.0
        scheduler.step(accuracy)
        torch.save(model.state_dict(), "models/save_%d.model" % epoch)
  • 0
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
抱歉,我不太确定你想要的是什么。ResNet和BiLSTM是两种不同的神经网络结构,它们在图像处理和自然语言处理方面都有广泛的应用。如果你想要生成ResNet和BiLSTM的代码,你需要先确定你想要用这些网络来解决什么问题,然后选择相应的框架(如TensorFlow或PyTorch)来实现它们。以下是一个简单的示例代码,演示如何在PyTorch中实现ResNet和BiLSTM。注意,这只是一个示例,具体实现可能会因问题和需求而异。

```python
import torch
import torch.nn as nn

# ResNet
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(identity)
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64  # input channels of the next layer built by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def _make_layer(self, out_channels, blocks, stride=1):
        layers = []
        # The first block must accept the previous layer's channel count;
        # a hard-coded 64 breaks layer3 and layer4.
        layers.append(ResBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(ResBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# BiLSTM
class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)

    def forward(self, x):
        # Create the initial states on the input's device.
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out
```

上述代码实现了一个ResNet和一个BiLSTM,分别用于图像分类和序列分类任务(BiLSTM只使用最后一个时间步的输出进行分类)。ResNet包含多个ResBlock,每个ResBlock由两个卷积层和一个shortcut连接组成。BiLSTM包含多个LSTM层和一个全连接层,用于从序列中提取特征并进行分类。这只是一个简单的示例,实际应用中可能需要对这些模型进行更多的调整和改进。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值