Read Me
是B站刘二大人的 pytorch的P9的作业。前几讲的作业都是仅看懂了别人的代码。这个代码是自己独立写的,结构、数据类型之类的改了不少bug。欢迎大家一起交流哦,我在评论区等你们!
结构
1、prepare dataset
2、design model using class
3、construct loss and optimizer
4、training cycle
5、write in submission.csv
代码
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
import torch.optim as optim
batch_size = 64  # mini-batch size consumed by the DataLoader below
# ---------------------------- 1、prepare dataset ----------------------------#
# 数据预处理
def prepare_train_data(df_data):
    """Convert the raw Otto training frame into (features, labels) arrays.

    Expects the column layout ``id | feat_1 ... feat_n | target`` where
    ``target`` holds the strings 'Class_1' ... 'Class_9'.  The frame's
    target column is overwritten in place with integer class indices.

    Returns:
        norm_features: float32 array of shape (rows, n_features), each
            column min-max scaled to [0, 1].
        label: float32 array of class indices 0..8 (the training loop
            casts them to long for CrossEntropyLoss).
    """
    # Map the class strings straight to integer indices 0..8; astype(int)
    # still raises on any unmapped label (NaN), like the original code.
    df_data['target'] = df_data['target'].map(
        {'Class_{}'.format(k + 1): k for k in range(9)}).astype(int)
    # Every column between the leading id and the trailing target is a
    # feature; positional slicing keeps this independent of the exact
    # feature count (93 for the Otto set).
    features = df_data.iloc[:, 1:-1].to_numpy(dtype=np.float64)
    label = df_data.iloc[:, -1].to_numpy()
    # Per-column min-max scaling to [0, 1] — numpy equivalent of
    # sklearn.preprocessing.minmax_scale; constant columns map to 0,
    # matching sklearn's zero-range handling.
    col_min = features.min(axis=0)
    col_range = features.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0
    norm_features = ((features - col_min) / col_range).astype(np.float32)  # network input must be float32
    label = label.astype(np.float32)
    return norm_features, label
def prepare_test_data(df_data):
    """Convert the raw Otto test frame (``id | feat_1 ... feat_n``) into a
    float32 feature matrix with each column min-max scaled to [0, 1].

    NOTE(review): the test set is scaled with its *own* min/max rather than
    the statistics of the training set, so the two feature spaces only match
    approximately — confirm this is acceptable for the submission.
    """
    features = df_data.iloc[:, 1:].to_numpy(dtype=np.float64)
    # numpy equivalent of sklearn.preprocessing.minmax_scale; constant
    # columns scale to 0, matching sklearn's zero-range handling.
    col_min = features.min(axis=0)
    col_range = features.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0
    return ((features - col_min) / col_range).astype(np.float32)
# 定义数据集格式
class OttoDataset(Dataset):
    """Map-style dataset wrapping the Otto Group training CSV.

    Each item is a (features, label) pair exactly as produced by
    prepare_train_data.
    """

    def __init__(self, data_path):
        frame = pd.read_csv(data_path)
        self.len = frame.shape[0]
        self.train_data_x, self.train_data_y = prepare_train_data(frame)

    def __getitem__(self, index):
        # Return the row's feature vector together with its class label.
        sample = self.train_data_x[index]
        target = self.train_data_y[index]
        return sample, target

    def __len__(self):
        return self.len
# Build the dataset and its shuffling loader.
# NOTE(review): with num_workers=2 the DataLoader spawns worker processes;
# on platforms using the "spawn" start method this module-level code should
# sit under an `if __name__ == '__main__':` guard — confirm target platform.
dataset = OttoDataset('/home/llz/PycharmProjects/PytorchClass/L9_MultiClassify/L9_hw_otto/train.csv')
train_loader = DataLoader(dataset=dataset,
                          shuffle=True,  # reshuffle every epoch
                          batch_size=batch_size,
                          num_workers=2)
# ---------------------------2、design model using class----------------------------#
class Model(torch.nn.Module):
    """Four-layer fully connected classifier for the Otto data.

    Args:
        in_features: width of the input feature vector (93 for Otto).
        num_classes: number of output classes (9 for Otto).

    forward() returns raw logits — no softmax — because
    torch.nn.CrossEntropyLoss applies log-softmax internally.
    """

    def __init__(self, in_features=93, num_classes=9):
        super(Model, self).__init__()
        # Hidden widths 64 -> 32 -> 16 taper toward the output layer.
        self.linear1 = torch.nn.Linear(in_features, 64)
        self.linear2 = torch.nn.Linear(64, 32)
        self.linear3 = torch.nn.Linear(32, 16)
        self.linear4 = torch.nn.Linear(16, num_classes)
        self.activate1 = torch.nn.ReLU()
        self.activate2 = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.activate1(self.linear1(x))
        x = self.activate1(self.linear2(x))
        # NOTE(review): sigmoid on the third layer while the others use
        # ReLU is unusual; it works but may slow training — confirm intent.
        x = self.activate2(self.linear3(x))
        return self.linear4(x)
model = Model()
# ----------------------------3、construct loss and optimizer----------------------------#
# CrossEntropyLoss combines log-softmax and NLL loss, so the model's raw
# logits are fed to it directly.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# ----------------------------4、train cycle ------------------------------------------#
# NOTE(review): the loss is never logged and the model is never saved —
# add progress reporting/checkpointing if training visibility matters.
for epoch in range(100):
    for i, (inputs, labels) in enumerate(train_loader, 0):
        # 1 Forward
        y_pred = model(inputs)
        #y_pred = torch.as_tensor(y_pred, dtype=torch.float32)
        loss = criterion(y_pred, labels.long())  # CrossEntropyLoss requires integer (long) class targets
        # 2 Backward
        optimizer.zero_grad()
        loss.backward()
        # 3 Update
        optimizer.step()
# ---------------------------5、write in submission.csv----------------------------#
# Predict on the test set and emit a one-hot Kaggle submission file.
test_data_path = "/home/llz/PycharmProjects/PytorchClass/L9_MultiClassify/L9_hw_otto/test.csv"
test_data = pd.read_csv(test_data_path)
test_data_x = torch.from_numpy(prepare_test_data(test_data))

# Inference only: disable autograd so no graph is built for the full test set.
with torch.no_grad():
    logits = model(test_data_x)

# Softmax is monotonically increasing, so the argmax of the raw logits
# selects the same class as the argmax of the probabilities.
predicted = logits.argmax(dim=1).numpy()

# One-hot encode the predicted classes with advanced indexing instead of a
# Python-level loop over rows.
output_y = np.zeros(logits.shape, dtype=int)
output_y[np.arange(output_y.shape[0]), predicted] = 1

# Assemble the submission frame: id column followed by Class_1..Class_9.
columns = {'id': test_data['id']}
for k in range(output_y.shape[1]):
    columns['Class_{}'.format(k + 1)] = output_y[:, k]
test_df = pd.DataFrame(columns)
test_df.to_csv('/home/llz/PycharmProjects/PytorchClass/L9_MultiClassify/L9_hw_otto/submission.csv', index=False)