PyTorch 实现三分类（使用随机生成的数据演示完整的训练与测试流程）
代码：
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
# Generate random data.
# Fixed seeds make the demo reproducible: without them the data, the weight
# initialisation and the DataLoader shuffling all change on every run.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

num_samples = 300
input_size = 28 * 28  # 784 features per sample
num_classes = 3
num_train = 200  # first num_train samples form the training split, the rest the test split

# Generate float32 directly: it matches the model's parameter dtype, so the
# per-batch `.float()` cast later becomes a no-op instead of a copy.
data = np.random.randn(num_samples, input_size).astype(np.float32)
# Labels are drawn independently of the features, so there is no real signal
# to learn; this script only demonstrates the training/evaluation pipeline.
labels = np.random.randint(0, num_classes, num_samples)

# Split into training and test sets.
train_data, test_data = data[:num_train], data[num_train:]
train_labels, test_labels = labels[:num_train], labels[num_train:]
class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample with its label by position."""

    def __init__(self, data, labels):
        # Both containers are stored as given; indexing is delegated to them.
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position *idx*."""
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target
# Wrap each split in a dataset and a mini-batch loader; only the training
# loader reshuffles its samples each epoch.
train_dataset, test_dataset = (
    CustomDataset(train_data, train_labels),
    CustomDataset(test_data, test_labels),
)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)
class FullyConnectedNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, returning raw class scores.

    No softmax is applied; the scores are meant to be fed to a loss such as
    ``nn.CrossEntropyLoss`` that works on unnormalised logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialisation draws from the RNG identically and the
        # state_dict keys stay fc1.* / fc2.*.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map ``x`` (last dim = input_size) to scores (last dim = num_classes)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Model, loss function and optimiser.
hidden_size = 128
model = FullyConnectedNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)
# CrossEntropyLoss takes raw scores plus integer class indices.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
# Training.
num_epochs = 20
loss_values = []  # mean training loss of each epoch, in epoch order
model.train()  # explicit training mode (matters once dropout/batch-norm are added)
for epoch in range(num_epochs):
    running_loss = 0.0
    # `targets` (not `labels`) so the module-level label array is not clobbered.
    for inputs, targets in train_loader:
        inputs, targets = inputs.float(), targets.long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # The loss is the per-batch mean; weight it by the batch size so a
        # smaller final batch does not count as much as a full one.
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    loss_values.append(epoch_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss}")
# Evaluation on the held-out split.
# With the randomly generated labels used in this script, expect accuracy
# near chance level (about 100 / num_classes percent).
model.eval()  # explicit inference mode (matters once dropout/batch-norm are added)
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.float(), targets.long()
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the highest score per row
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
if total:
    print(f"Accuracy: {100 * correct / total}%")
else:
    # Avoid ZeroDivisionError when the test split is empty.
    print("Accuracy: n/a (empty test set)")
# Plot the loss curve.
# x positions are 1-based so the axis matches the "Epoch [k/N]" numbering
# printed during training; plt.plot(y) alone would start the axis at 0.
epochs = range(1, len(loss_values) + 1)
plt.plot(epochs, loss_values, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Curve")
plt.show()
运行截图：