CNN 分类网络之：LeNet

      用 PyTorch 写的一个 LeNet 分类网络，并非百分之百还原原始论文，但网络结构是一样的，用来简单训练一下自己的数据集。数据集格式如下：data 目录内存放自己的数据，每个类别放到一个文件夹中，文件夹名称即为类别标签，如下图所示：

                                                             

1、网络搭建

import torch
import torch.nn as nn


class Lenet(nn.Module):
    """LeNet-style CNN classifier (ReLU activations, max pooling).

    Layer layout: conv5x5(6) -> pool2 -> conv5x5(16) -> pool2 -> conv5x5(120)
    -> fc(84) -> fc(num_classes). All convolutions use stride 1 and no
    padding, so a 32x32 input shrinks to 28 -> 14 -> 10 -> 5 -> 1, leaving
    exactly 120 features for ``fc6``. Inputs must therefore be 32x32.

    Args:
        num_classes: number of logits produced by the final layer.
        in_channels: number of channels of the input images (3 for colour).
    """

    def __init__(self, num_classes=1000, in_channels=3):
        super(Lenet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5, 1, 0)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(6, 16, 5, 1, 0)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = nn.Conv2d(16, 120, 5, 1, 0)
        self.fc6 = nn.Linear(120, 84)
        # Honour num_classes (it used to be ignored, always giving 10 logits).
        self.fc7 = nn.Linear(84, num_classes)
        # Stateless, so one ReLU module can be shared by every layer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a (N, in_channels, 32, 32) batch to (N, num_classes) logits."""
        x = self.pool2(self.relu(self.conv1(x)))
        x = self.pool4(self.relu(self.conv3(x)))
        x = self.relu(self.conv5(x))

        # (N, 120, 1, 1) -> (N, 120)
        x = x.flatten(1)

        x = self.relu(self.fc6(x))
        return self.fc7(x)


2、读取自己的图像及标签数据

from torch.utils.data import Dataset
import cv2
import os
from torchvision import transforms as tvtsf
import torch
import numpy as np

class cnndata(Dataset):
    """Image-classification dataset read from a one-folder-per-class tree.

    Expected layout::

        path/
            class_a/  img1.jpg, img2.jpg, ...
            class_b/  ...

    Every sub-directory of ``path`` is one class and its name is the class
    label. Class indices follow sorted folder-name order, so the index->name
    mapping is the same on every machine. Plain files directly under
    ``path`` are ignored.

    Each item is ``(image, label)``:
      * ``image``: float64 tensor of shape (3, 32, 32), values in [0, 1],
        BGR channel order (as loaded by OpenCV);
      * ``label``: 0-d int64 tensor holding the class index.
    """

    def __init__(self, path):
        self.path = path
        # img: image file paths; cls: parallel class indices;
        # class_label: {class index: folder name}
        self.img, self.cls, self.class_label = self.get_filname_and_cls()
        print("图像共有:%s 张" % (len(self.img)))

    def __getitem__(self, item):
        path = self.img[item]
        # IMREAD_COLOR always returns a 3-channel BGR image (grayscale files
        # are expanded), so no extra gray->BGR conversion is needed.
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # imread reports unreadable / non-image files by returning None.
            raise IOError("cannot read image: %s" % path)
        # The interpolation flag must be passed by keyword: the third
        # positional parameter of cv2.resize is ``dst``.
        img_src = cv2.resize(img, (32, 32), interpolation=cv2.INTER_AREA)

        # HWC uint8 -> CHW float in [0, 1]
        image = torch.from_numpy(img_src / 255.).permute(2, 0, 1).contiguous()
        label = torch.tensor(self.cls[item], dtype=torch.int64)
        return image, label

    def get_filname_and_cls(self):
        """Scan ``self.path`` and return ``(imgs, cls, class_label)``.

        Raises:
            FileNotFoundError: if ``self.path`` is not an existing directory.
        """
        if not os.path.isdir(self.path):
            raise FileNotFoundError("dataset folder not found: %s" % self.path)

        # Sort so class indices do not depend on os.listdir() order, and skip
        # stray files (e.g. .DS_Store) that would break the inner listdir.
        class_name = sorted(
            d for d in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, d))
        )

        cls = []
        imgs = []
        class_label = {}
        for c, cl in enumerate(class_name):
            class_label[c] = cl
            class_dir = os.path.join(self.path, cl)
            for name in sorted(os.listdir(class_dir)):
                imgs.append(os.path.join(class_dir, name))
                cls.append(c)

        return imgs, cls, class_label

    def get_classlabel(self):
        """Return the ``{class index: folder name}`` mapping."""
        return self.class_label

    def __len__(self):
        return len(self.img)

3、训练及测试

from get_data import cnndata
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
import torch
import time
from Lenet import Lenet


def train(batch=16, traindata=None, model=None, epochs=20,
          save_path='lenet.pth', save_threshold=0.8, num_workers=2):
    """Train ``model`` on ``traindata`` with Adam and cross-entropy loss.

    The model is trained on whatever device its parameters already live on.

    Args:
        batch: mini-batch size.
        traindata: Dataset yielding ``(image, label)`` pairs.
        model: network to train (updated in place).
        epochs: number of passes over ``traindata``.
        save_path: file the ``state_dict`` checkpoint is written to.
        save_threshold: a checkpoint is written only when an epoch's training
            accuracy exceeds this value and beats every earlier epoch.
        num_workers: DataLoader worker processes.

    Raises:
        ValueError: if ``traindata``/``model`` is missing or the dataset is
            empty.
    """
    if traindata is None or model is None:
        raise ValueError("train() needs both traindata and model")
    n_samples = len(traindata)
    if n_samples == 0:
        raise ValueError("traindata is empty")

    device = next(model.parameters()).device
    dataloader = DataLoader(traindata, batch_size=int(batch), shuffle=True,
                            num_workers=num_workers)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    best_acc = -1.0

    # Re-enable training mode in case the model was left in eval() mode.
    model.train()
    for epoch in range(int(epochs)):
        start = time.time()
        correct_num = 0
        epoch_loss = 0.0
        for img, cls in dataloader:
            img = img.to(device, dtype=torch.float32)
            cls = cls.to(device, dtype=torch.long)

            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, cls)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the *mean* over the batch; weight it by
            # the batch size so dividing by n_samples gives a per-sample mean.
            epoch_loss += loss.item() * img.size(0)
            correct_num += (out.argmax(dim=1) == cls).sum().item()

        acc = correct_num / n_samples
        print("Epoch :%d , loss%.4f, acc:%.3f, time:%.3f"
              % (epoch, epoch_loss / n_samples, acc, time.time() - start))
        # Never overwrite a better checkpoint with a worse one.
        if acc > save_threshold and acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)


def test(batch=16, traindata=None, model=None, class_label=None,
         weights_path='lenet.pth', num_workers=2):
    """Load saved weights into ``model`` and print its predictions.

    The model is evaluated on whatever device its parameters already live on.

    Args:
        batch: mini-batch size.
        traindata: Dataset yielding ``(image, label)`` pairs to evaluate.
        model: network whose ``state_dict`` matches ``weights_path``.
        class_label: optional ``{class index: name}`` mapping; when given,
            class names are printed alongside the predicted/true indices.
        weights_path: checkpoint file to load (as written by ``train``).
        num_workers: DataLoader worker processes.

    Returns:
        Accuracy over ``traindata`` as a float in [0, 1].

    Raises:
        ValueError: if ``traindata``/``model`` is missing or the dataset is
            empty.
    """
    if traindata is None or model is None:
        raise ValueError("test() needs both traindata and model")
    n_samples = len(traindata)
    if n_samples == 0:
        raise ValueError("traindata is empty")

    device = next(model.parameters()).device
    # Sample order is irrelevant for evaluation; keep the output deterministic.
    dataloader = DataLoader(traindata, batch_size=int(batch), shuffle=False,
                            num_workers=num_workers)
    # map_location lets a checkpoint saved on GPU be loaded onto a CPU model.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    correct = 0
    with torch.no_grad():
        for img, cls in dataloader:
            img = img.to(device, dtype=torch.float32)
            cls = cls.to(device, dtype=torch.long)
            pre = model(img).argmax(dim=1)
            correct += (pre == cls).sum().item()

            pre_np = pre.cpu().numpy()
            cls_np = cls.cpu().numpy()
            print('pre is:', pre_np, 'label is:', cls_np)
            if class_label is not None:
                print('pre names:', [class_label.get(int(i), int(i)) for i in pre_np],
                      'label names:', [class_label.get(int(i), int(i)) for i in cls_np])
    return correct / n_samples


# This is an evaluation helper, not a unit test: keep pytest from collecting
# it when the function is imported into a test module.
test.__test__ = False


def weights_init(m):
    """Xavier-uniform initialise Conv2d weights and zero their biases.

    Meant to be used as ``model.apply(weights_init)``. Modules that are not
    ``nn.Conv2d`` (e.g. the Linear layers) are left at PyTorch's defaults.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        # Conv2d(..., bias=False) has bias=None; zeroing it would crash.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        
        
if __name__ == '__main__':
    batch = 2
    traindata = cnndata(path="./data")
    class_label = traindata.get_classlabel()
    print('index and labels:', class_label)

    # Size the classifier head from the dataset instead of hard-coding 10:
    # with more than 10 class folders, a 10-logit head breaks the loss.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Lenet(num_classes=len(class_label)).to(device)
    model.apply(weights_init)
    train(batch=batch, traindata=traindata, model=model)
    # Uncomment to print predictions with the saved weights:
    # test(batch=batch, traindata=traindata, model=model, class_label=class_label)

  训练结果:

 

预测结果:

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值