基于CNN+LSTM+CTC的不定长数字仪表识别(二)

上篇文章介绍了搭建模型的过程,本期把剩下的训练与dataset的制作更新一下。
首先是制作我们的dataset,我们的图片与标签是分开的,图片的命名是批量按顺序命名,标签全部放在一个txt文档中,如下图所示数据标签
图片大小是不一致的,这里我们可以采用两种思想去处理,便于网络的输入,一是直截了当的裁剪方法,本文也是将图像统一裁剪为(32,100)的形式,第二种是采用池化金字塔的方法,核心思想就是利用不同大小的卷积核得到最后统一特征维数的输出,代码可以参见本篇博客:池化金字塔SPP-NET。随后就是代码实现啦,框架用的pytorch。

from PIL import Image
import torch
from torch.utils.data import Dataset
import os

def img_read(im_path):
    """Load the image at *im_path* and resize it to the fixed network
    input size of width 100, height 32.

    Uses LANCZOS resampling: ``Image.ANTIALIAS`` was merely an alias of
    LANCZOS and was removed in Pillow 10, so the original call raises
    AttributeError on modern Pillow.
    """
    img = Image.open(im_path)
    return img.resize((100, 32), Image.LANCZOS)

class mydatasets(Dataset):
    """Dataset of (image path, integer label) pairs listed in a text file.

    Each non-empty line of *txt_path* is ``<relative_image_path> <label>``.
    The image path is joined onto the fixed data directory; the label column
    is a digit string stored as a single int (the training loop uses a CTC
    target length of 1 per sample).

    Parameters:
        txt_path:     path to the label file.
        transform:    optional callable applied to the loaded PIL image.
        target_trans: optional callable applied to the integer label.
    """

    def __init__(self, txt_path, transform=None, target_trans=None):
        samples = []
        fixpath = "../electricity_project/num_data/"
        # `with` guarantees the handle is closed (the original leaked it).
        with open(txt_path, "r") as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    continue  # tolerate blank lines (e.g. trailing newline)
                sample = line.split()
                path = os.path.join(fixpath, sample[0])
                # Bug fix: the original built a list of characters and then
                # called int() on that list, which raises TypeError.  Parse
                # the label column directly as one integer instead.
                samples.append((path, int(sample[1])))
        self.txt_path = txt_path
        self.transform = transform
        self.target_transform = target_trans
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        im_path, label = self.samples[index]
        img = img_read(im_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

由于我们用的pytorch的Dataloader去封装我们的dataset,它要求每个batch里面的数据长度一致,而我们的标签长度是不一致的,因此在dataloader里的collate_fn可以自己定义,用padding的方式补偿,最后再压缩回原来的长度,训练代码如下:

from mymodel import CRNN
from datasets import mydatasets
import torch
from torch.utils.data import DataLoader
import os
import torchvision.transforms
from torchvision.transforms import Compose,ToTensor,Resize
import matplotlib.pyplot as plt
# Where the trained weights are saved / restored from.
model_path = "../electricity_project/model/cnn_lstm_ctc.pth"
# makedirs(..., exist_ok=True) also creates missing parents and avoids the
# exists()+mkdir() race of the original.
os.makedirs("../electricity_project/model", exist_ok=True)
def cal_acc(output, target):
    """Return the fraction of rows whose predicted class matches *target*.

    Both tensors are reduced with argmax over dim 1 (the class dimension);
    *target* is therefore expected to be one-hot along that dimension.
    """
    probs = torch.softmax(output, dim=1)
    preds = torch.argmax(probs, dim=1)
    truth = torch.argmax(target, dim=1)
    hits = [1 if torch.equal(t, p) else 0 for t, p in zip(truth, preds)]
    return sum(hits) / len(hits)


# Training hyper-parameters.
batch_size = 64  # samples per mini-batch (drop_last=True in the loaders)
lr = 0.01  # Adam learning rate
epoch = 10  # number of full passes over the training set
restor = False  # when True, resume from the checkpoint at model_path

def train():
    """Train the CRNN with CTC loss on the digit dataset, evaluate each
    epoch, checkpoint the weights, and plot loss/accuracy curves.

    Fixes over the original:
      * the log-probabilities are no longer ``.detach()``-ed before
        ``backward()`` — the detach cut the graph, so gradients never
        reached the network and the model could not learn;
      * the evaluation loop runs under ``torch.no_grad()`` and performs no
        optimizer steps (the original was training on the test set);
      * ``"{.4}"`` was an invalid format spec (attribute lookup) that raised
        at the first epoch print — replaced with ``"{:.4f}"``;
      * swapped/incorrect plot legend labels corrected;
      * debug prints and redundant ``torch.tensor(...)`` re-wrapping removed.
    """
    transforms = Compose([ToTensor()])
    train_dataset = mydatasets(txt_path="../electricity_project/num_data/train.txt",
                               transform=transforms)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  num_workers=0, shuffle=True, drop_last=True)
    test_dataset = mydatasets(txt_path="../electricity_project/num_data/test.txt",
                              transform=transforms)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
                                 num_workers=0, shuffle=True, drop_last=True)

    net = CRNN(nclass=10, nhidden=128)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    if restor:
        net.load_state_dict(torch.load(model_path, map_location=device))
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # The CTC blank symbol is the last class index (9).
    ctc_loss = torch.nn.CTCLoss(blank=10 - 1)

    train_loss_mean, train_acc_mean = [], []
    test_loss_mean, test_acc_mean = [], []

    for epochs in range(epoch):
        print("epoch:{}".format(epochs))

        # ---- training pass ----
        net.train()
        loss_history, acc_history = [], []
        for img, target in train_dataloader:
            img, target = img.to(device), target.to(device)
            # assumes net outputs (T, N, nclass) as CTCLoss expects — TODO confirm
            output = net(img)
            log_probs = output.log_softmax(2)  # keep the graph: no detach()
            input_lengths = torch.full((log_probs.shape[1],), log_probs.shape[0],
                                       dtype=torch.long)
            # Every sample carries exactly one target symbol.
            target_lengths = torch.full((target.shape[0],), 1, dtype=torch.long)
            loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): cal_acc argmaxes dim 1 of a (T, N, nclass) tensor,
            # which is not a per-sample CTC decode — kept for parity; verify.
            acc_history.append(float(cal_acc(output, target)))
            loss_history.append(float(loss))
        print("train_loss:{:.4f} | train_acc:{:.4f}".format(
            torch.mean(torch.tensor(loss_history)),
            torch.mean(torch.tensor(acc_history))))
        train_loss_mean.append(torch.mean(torch.tensor(loss_history)))
        train_acc_mean.append(torch.mean(torch.tensor(acc_history)))

        # ---- evaluation pass: no gradients, no parameter updates ----
        net.eval()
        loss_history, acc_history = [], []
        with torch.no_grad():
            for img, target in test_dataloader:
                img, target = img.to(device), target.to(device)
                output = net(img)
                log_probs = output.log_softmax(2)
                input_lengths = torch.full((log_probs.shape[1],), log_probs.shape[0],
                                           dtype=torch.long)
                target_lengths = torch.full((target.shape[0],), 1, dtype=torch.long)
                loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
                acc_history.append(float(cal_acc(output, target)))
                loss_history.append(float(loss))
        print("test_loss:{:.4f} | test_acc:{:.4f}".format(
            torch.mean(torch.tensor(loss_history)),
            torch.mean(torch.tensor(acc_history))))
        test_loss_mean.append(torch.mean(torch.tensor(loss_history)))
        test_acc_mean.append(torch.mean(torch.tensor(acc_history)))
        torch.save(net.state_dict(), model_path)

    # Loss curves per epoch.
    plt.plot(train_loss_mean, "b", label="train_loss")
    plt.plot(test_loss_mean, "g", label="test_loss")
    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.show()
    # Accuracy curves per epoch.
    plt.plot(train_acc_mean, "b", label="train_acc")
    plt.plot(test_acc_mean, "g", label="test_acc")
    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("Accuracy")
    plt.show()

if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    train()



CTCloss,pytorch里面已有相应的接口,用起来也十分方便,但需要注意它的输入格式,详细说明可以参考pytorch官方文档中对于CTCLOSS的说明:nn.CTCLOSS
本期更新就到这里,大家一起学习进步。

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 12
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值