import torch.nn as nn
import torch
from sklearn.datasets import load_iris
# Load the iris dataset (features, integer class labels, class names).
data = load_iris()
X1 = torch.tensor(data.data).float()  # (n_samples, n_features) measurements
n_samples = X1.shape[0]
n_classes = len(data.target_names)  # number of iris classes
# Second input of the bilinear model: small random noise, one 2-vector per sample.
X2 = 0.01 * torch.rand(n_samples, 2).float()
y = torch.tensor(data.target).long()  # class index per sample
# Bilinear weights, one (n_features, 2) matrix per class:
# logit[n, k] = X1[n] @ W[k] @ X2[n].
W = torch.randn(n_classes, X1.shape[1], X2.shape[1], requires_grad=True)
# One bias per class, broadcast across samples. A per-sample bias of shape
# (n_samples, 1) would shift every class logit of a sample by the same amount;
# softmax cross-entropy is invariant to such a shift, so that bias would get a
# zero gradient and have no effect.
b = torch.zeros(n_classes, requires_grad=True)
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam([W, b])  # PyTorch default learning rate
# Full-batch training of the bilinear classifier.
n_steps = 50000
log_every = 5000
for step in range(n_steps):
    optimizer.zero_grad()
    # Bilinear logits, shape (n_samples, n_classes):
    # Y[n, k] = sum_{i, j} X1[n, i] * W[k, i, j] * X2[n, j]  (= X1[n] @ W[k] @ X2[n]),
    # plus a bias broadcast across rows.
    Y = torch.einsum("ni,kij,nj->nk", X1, W, X2) + b
    loss = cost(Y, y)
    loss.backward()
    optimizer.step()
    # Log periodically instead of printing the loss tensor on every step.
    if step % log_every == 0 or step == n_steps - 1:
        print(f"step {step}: loss {loss.item():.6f}")
# Training-set accuracy. Y holds the logits from the last forward pass of the
# training loop, which was computed just before the final optimizer step.
predicted = Y.argmax(dim=1)  # predicted class index per sample
correct = int((predicted == y).sum())
print(correct / len(y))