【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)

《PyTorch深度学习实践》-刘二大人 Otto Group Product Classification作业
将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集里商品是哪个类别
在这里插入图片描述

数据集可以在https://www.kaggle.com/c/otto-group-product-classification-challenge下载

代码及注释如下

import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 函数将字符型标签转换为数值标签,方便计算交叉熵
def lables2id(lables):
    target_id = []
    target_lables = ['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
    for lable in lables:
        target_id.append(target_lables.index(lable))
    return target_id


# 定义数据集类
class TrainDataset(Dataset):
    def __init__(self, filepath):
        data = pd.read_csv(filepath)
        lables = data['target']
        self.len = data.shape[0]  # shape(多少行,多少列)
        self.x_data = torch.tensor(np.array(data)[:, 1:-1].astype(float))
        self.y_data = lables2id(lables)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


train_dataset = TrainDataset('D:/Research/Deep learning/pytorch刘二大人/otto-group-product-classification-challenge/train.csv')
# 数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=0)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(93, 64)
        self.linear2 = torch.nn.Linear(64, 32)
        self.linear3 = torch.nn.Linear(32, 16)
        self.linear4 = torch.nn.Linear(16, 9)
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.activate(self.linear3(x))
        x = self.linear4(x)
        return x


model = Model()

#定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.6)

loss_list = []


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader):
        inputs, target = data
        inputs = inputs.float()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)

        loss_list.append(loss.item())

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:  # 每300轮打印一次结果
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


# 开始训练
if __name__ == '__main__':
    for epoch in range(50):
        train(epoch)


# 预测保存函数,用于保存预测结果
def predict_save():
    with torch.no_grad():
      test_data = pd.read_csv('D:/Research/Deep learning/pytorch刘二大人/otto-group-product-classification-challenge/test.csv')
      x_text = torch.tensor(np.array(test_data)[:, 1:].astype(float))
      y_pred = model(x_text.float())
      _, predicted = torch.max(y_pred, dim=1)  # 这里先取出最大概率的索引,即是所预测的类别。
      out = pd.get_dummies(predicted)  # get_dummies 利用pandas实现one hot encode,方便保存为预测文件。

      lables = ['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
      # 添加列标签
      out.columns = lables
      # 插入id行
      out.insert(0, 'id', test_data['id'])
      result = pd.DataFrame(out)
      result.to_csv('my_predict.csv', index=False)


    #画损失函数曲线
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.show()

predict_save()



输出结果
[1, 300] loss: 1.424
[1, 600] loss: 0.863
[1, 900] loss: 0.742

[25, 300] loss: 0.467
[25, 600] loss: 0.461
[25, 900] loss: 0.477

[50, 300] loss: 0.421
[50, 600] loss: 0.423
[50, 900] loss: 0.426

损失函数曲线如下
在这里插入图片描述

可以尝试不同的optimizer,参数,进一步处理数据等等再优化。

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值