PyTorch demo——基于MLP的鸢尾花分类

系统框架

在这里插入图片描述

1. 数据集加载

  继承 torch.utils.data.Dataset 类，重写 __getitem__ 和 __len__ 方法：在 __init__ 中逐行读取并解析数据文件（特征转为浮点数，类别名映射为整数标签），在 __getitem__ 中将样本转换为张量（未做数据增强）。

# load.py
import torch


class IrisDataset(torch.utils.data.Dataset):
    """Iris samples parsed from a comma-separated text file.

    Every non-blank line holds numeric feature columns followed by a class
    name in the last column, e.g. ``5.1,3.5,1.4,0.2,Iris-setosa``.

    Args:
        data_file: path of the text file to read.
        iris_class: mapping from the class name in the last column to an
            integer label.

    Raises:
        KeyError: if a line's class name is missing from ``iris_class``.
        ValueError: if a feature column is not a valid float.
    """

    def __init__(self, data_file, iris_class):
        super(IrisDataset, self).__init__()

        self.iris_class = iris_class

        # List of [feature list, integer label] pairs, in file order.
        self.all_data = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                # Blank lines (typically a trailing empty line at the end of
                # the file) carry no sample; parsing them would look up the
                # empty class name '' and raise KeyError.
                if not line:
                    continue
                fields = [field.strip() for field in line.split(',')]
                vec = [float(value) for value in fields[:-1]]
                label = self.iris_class[fields[-1]]
                self.all_data.append([vec, label])

    def __getitem__(self, item):
        """Return sample ``item`` as (float32 feature tensor, int64 label tensor)."""
        fea, label = self.all_data[item]
        fea = torch.tensor(fea, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        # No data augmentation: samples are returned exactly as stored.
        return fea, label

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.all_data)


if __name__ == "__main__":
    # Smoke test: parse the training split and show its first sample.
    import config

    demo_set = IrisDataset("iris/train", config.iris_class)
    first_sample = demo_set.__getitem__(0)
    print(first_sample)

2. 网络模型——MLP

在这里插入图片描述

# net.py
import torch
import torch.nn as nn


class Net(torch.nn.Module):
    """Multi-layer perceptron that outputs per-class probabilities.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU. The last layer is a
    Linear projection to ``num_class`` scores followed by a softmax over the
    class axis.

    Args:
        input_dim: number of input features per sample.
        num_class: number of output classes.
        hidden_dims: widths of the hidden layers. The default (512, 128, 64)
            reproduces the original architecture and its state_dict keys
            (fc.0, fc.1, fc.3, fc.4, fc.6, fc.7, fc.9), so existing
            checkpoints still load.
    """

    def __init__(self, input_dim=4, num_class=3, hidden_dims=(512, 128, 64)):
        super(Net, self).__init__()

        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(inplace=True),
            ]
            in_features = width
        layers += [
            nn.Linear(in_features, num_class),
            # Explicit dim=1 normalises over the class axis of the 2-D
            # (batch, num_class) output. This is the axis nn.Softmax() picks
            # implicitly for 2-D input, but stating it avoids the implicit-dim
            # deprecation warning.
            nn.Softmax(dim=1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities for ``x``; each output row sums to 1.

        ``x`` is presumably (batch, input_dim), given the Linear/BatchNorm1d
        stack — confirm against callers.
        """
        return self.fc(x)


if __name__ == "__main__":
    # Smoke test: print the architecture and the output shape for a
    # random batch of two 4-feature samples.
    demo_net = Net()
    print(demo_net)

    dummy_batch = torch.randn(2, 4)
    probabilities = demo_net(dummy_batch)
    print(probabilities.shape)

3. 配置文件——网络参数、训练参数整理

# config.py
import warnings
# Suppress all warnings in every process that imports this module.
warnings.filterwarnings("ignore")

# ---- dataset ---------------------------------------------------------------
# Class name (last CSV column) -> integer label, in label order.
iris_class = {
    name: index
    for index, name in enumerate(("Iris-setosa", "Iris-versicolor", "Iris-virginica"))
}

# ---- net args --------------------------------------------------------------
input_dim = 4  # feature columns per sample
num_class = 3  # output classes

# ---- train & valid ---------------------------------------------------------
train_data = "iris/train"
valid_data = "iris/valid"
batch_size = 10
nworks = 1        # DataLoader num_workers
max_epoch = 200
lr = 1e-3         # base learning rate
factor = 0.9      # exponent of the polynomial lr decay applied per epoch in train.py

# ---- test ------------------------------------------------------------------
test_data = "iris/test"
pre_model = "pth/model_100.pth"  # checkpoint written by train.py after epoch 100

4. 训练

# train.py
import torch, os, tqdm
from torch.utils.data import DataLoader

import load, net, config
import matplotlib.pyplot as plt


def _select_device():
    """Return cuda:0 (enabling cuDNN autotuning) when CUDA is available, else the CPU."""
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        return torch.device("cuda:" + str(0))
    return torch.device("cpu")


def _prob_cross_entropy(probs, target):
    """Cross-entropy loss for a model whose forward() already ends in a softmax.

    net.Net applies nn.Softmax as its last layer. torch.nn.CrossEntropyLoss
    expects unnormalised logits and applies log-softmax itself, so feeding it
    probabilities softmaxes twice and squashes the gradients. Taking the log
    of the probabilities and using NLL loss gives the intended cross-entropy.
    The clamp keeps log() finite if a probability underflows to 0.

    Args:
        probs: (batch, num_class) class probabilities.
        target: (batch,) integer class labels.
    """
    log_probs = torch.log(probs.clamp_min(1e-12))
    return torch.nn.functional.nll_loss(log_probs, target)


def _train_one_epoch(model, batches, optimizer, device):
    """Run one optimisation pass over ``batches`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for fea, target in batches:
        fea, target = fea.to(device), target.to(device)
        loss = _prob_cross_entropy(model(fea), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Divide by the number of batches, not len(dataset)/batch_size, so a
    # partial final batch does not skew the average.
    return total_loss / max(len(batches), 1)


def _evaluate(model, batches, device):
    """Return (mean batch loss, accuracy) of ``model`` over ``batches``."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for fea, labels in batches:
            fea, labels = fea.to(device), labels.to(device)
            probs = model(fea)
            total_loss += _prob_cross_entropy(probs, labels).item()
            correct += (probs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / max(len(batches), 1), correct / max(total, 1)


def _plot_curves(train_loss, valid_loss, valid_acc):
    """Redraw the loss/accuracy curves and overwrite train.jpg."""
    plt.clf()
    plt.plot(train_loss, color='black', label="train loss")
    plt.plot(valid_loss, color='red', label="valid loss")
    plt.plot(valid_acc, color='green', label="valid acc")
    plt.grid()
    plt.legend()
    plt.savefig("train.jpg")


def train():
    """Train net.Net on config.train_data, validating on config.valid_data each epoch.

    Side effects: writes pth/model_<epoch>.pth after every epoch and refreshes
    the training-curve plot train.jpg.
    """
    device = _select_device()
    print("current deveice:", device)

    # Pinned host memory only speeds up host-to-GPU copies; on CPU it just warns.
    pin = device.type == "cuda"

    # load dataset for train and eval
    train_dataset = load.IrisDataset(config.train_data, config.iris_class)
    train_batchs = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                              num_workers=config.nworks, pin_memory=pin)
    valid_dataset = load.IrisDataset(config.valid_data, config.iris_class)
    valid_batchs = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False,
                              num_workers=config.nworks, pin_memory=pin)

    model = net.Net(config.input_dim, config.num_class).to(device)
    # The lr is overwritten every epoch by the schedule below; start from config.lr.
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    os.makedirs("pth", exist_ok=True)

    plt.ion()
    train_loss, valid_loss, valid_acc = [], [], []

    for epoch in tqdm.tqdm(range(1, config.max_epoch + 1)):
        # Polynomial decay: lr = base_lr * (1 - (epoch - 1) / max_epoch) ** factor.
        epoch_lr = config.lr * ((1 - (epoch - 1) / config.max_epoch) ** config.factor)
        for group in optimizer.param_groups:
            group['lr'] = epoch_lr

        train_loss.append(_train_one_epoch(model, train_batchs, optimizer, device))
        torch.save(model.state_dict(), os.path.join("pth", 'model_' + str(epoch) + '.pth'))

        epoch_valid_loss, epoch_valid_acc = _evaluate(model, valid_batchs, device)
        valid_loss.append(epoch_valid_loss)
        valid_acc.append(epoch_valid_acc)

        _plot_curves(train_loss, valid_loss, valid_acc)

    plt.ioff()
    plt.close()


if __name__ == '__main__':

    train()

训练过程可视化
在这里插入图片描述

5. 测试

# test.py
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report

import net, load, config


def test():
    """Evaluate checkpoint config.pre_model on config.test_data and print a report.

    Raises:
        ValueError: if the test split contains no samples.
        RuntimeError: if the checkpoint's keys do not match net.Net (strict load).
    """
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(0))
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    print("current deveice:", device)

    # load dataset
    test_dataset = load.IrisDataset(config.test_data, config.iris_class)
    test_batchs = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=0, pin_memory=device.type == "cuda")

    # model
    model = net.Net(config.input_dim, config.num_class)
    # Strict loading (the default): a checkpoint whose keys do not match the
    # model now fails loudly. With strict=False, mismatched layers silently
    # kept their random initialisation.
    model.load_state_dict(torch.load(config.pre_model, map_location='cpu'))
    model = model.to(device)

    # test
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for fea, label in test_batchs:
            probs = model(fea.to(device))
            preds.append(probs.argmax(dim=1).cpu())
            labels.append(label)

    if not preds:
        raise ValueError("test set is empty: " + str(config.test_data))

    # torch.cat (not stack) so a smaller final batch is handled.
    preds = torch.cat(preds).numpy()
    labels = torch.cat(labels).numpy()

    # Order class names by their integer label and pass the labels explicitly,
    # so names stay aligned with indices even if a class is absent from the split.
    ordered = sorted(config.iris_class.items(), key=lambda kv: kv[1])
    report = classification_report(
        labels,
        preds,
        labels=[index for _, index in ordered],
        target_names=[name for name, _ in ordered],
    )
    print(report)


if __name__ == '__main__':

    test()

测试集结果
在这里插入图片描述

6. 文件结构

在这里插入图片描述

6.1 requirements.txt

matplotlib==3.7.2
scikit_learn==1.3.2
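torch==2.0.0+cu118
# 注：带 +cu118 后缀的 torch 不在 PyPI 上发布，需使用 pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 安装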
tqdm==4.65.2

6.2 附：已划分的数据集

训练集——iris/train
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
验证集——iris/valid
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
测试集——iris/test
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值