PyTorch Quickstart in Practice: the FashionMNIST Case Study

Reference: the official PyTorch website (see Section 6)

1. Working with data

PyTorch has two primitives to work with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.
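
To make that division of labor concrete, here is a minimal sketch of a custom Dataset (illustrative only; the class name and its in-memory storage are not part of the tutorial). Any class implementing __len__ and __getitem__ can be wrapped by a DataLoader.

from torch.utils.data import Dataset

class InMemoryDataset(Dataset):
    # Illustrative custom Dataset: stores samples/labels in plain lists.
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return one (sample, label) pair; DataLoader handles batching.
        return self.samples[idx], self.labels[idx]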

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. For this tutorial, we will be using a TorchVision dataset.

The torchvision.datasets module contains Dataset objects for many real-world vision datasets like CIFAR and COCO (see the torchvision documentation for the full list). In this tutorial, we use the FashionMNIST dataset. Every TorchVision Dataset includes two arguments, transform and target_transform, to modify the samples and labels respectively.
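
For example, target_transform can convert the integer label into a one-hot vector. A small sketch (illustrative only; the plain integer labels used below work fine for this tutorial):

from torchvision.transforms import Lambda

# Illustrative: one-hot encode the integer label y over 10 classes.
one_hot = Lambda(lambda y: torch.zeros(10, dtype=torch.float)
                 .scatter_(0, torch.tensor(y), value=1))
# Passing target_transform=one_hot to datasets.FashionMNIST would yield
# float vectors instead of integer class indices.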

1.1 Download training data from open datasets.

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

1.2 Download test data from open datasets.

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Out:


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

We pass the Dataset as an argument to DataLoader. This wraps an iterable over our dataset, and supports automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, i.e. each element in the dataloader iterable will return a batch of 64 features and labels.

batch_size = 64

1.3 Create data loaders.

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
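
These loaders use DataLoader's defaults (no shuffling, a single loading process). A common variant, shown here as an optional sketch, shuffles the training set each epoch and loads batches with worker processes:

# Optional: shuffle training batches each epoch; use 2 loader processes.
train_dataloader = DataLoader(training_data, batch_size=batch_size,
                              shuffle=True, num_workers=2)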

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
    

Out:

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

Read more about loading data in PyTorch.

2. Creating Models

To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the layers of the network in the init function and specify how data will pass through the network in the forward function. To accelerate operations in the neural network, we move it to the GPU if available.

2.1 Get cpu or gpu device for training.

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

2.2 Define model

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

Out:

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

Read more about building neural networks in PyTorch.
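
As a quick sanity check (not part of the original tutorial), a random 28×28 input pushed through the model above should produce one logit per class:

# One fake image through the network; expect torch.Size([1, 10]).
X = torch.rand(1, 28, 28, device=device)
print(model(X).shape)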

2.3 Optimizing the Model Parameters

To train a model, we need a loss function and an optimizer.

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
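
Note that nn.CrossEntropyLoss expects raw logits and applies log-softmax internally. The small check below (illustrative, with fake data) shows it agrees with nn.NLLLoss applied to log-softmaxed outputs:

# CrossEntropyLoss(logits, y) == NLLLoss(log_softmax(logits), y)
logits = torch.randn(4, 10)      # fake batch of 4 predictions
y = torch.tensor([0, 3, 9, 1])   # fake integer labels
a = nn.CrossEntropyLoss()(logits, y)
b = nn.NLLLoss()(torch.log_softmax(logits, dim=1), y)
print(torch.allclose(a, b))      # True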

In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and backpropagates the prediction error to adjust the model’s parameters.

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

We also check the model’s performance against the test dataset to ensure it is learning.

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

The training process is conducted over several iterations (epochs). During each epoch, the model learns parameters to make better predictions. We print the model’s accuracy and loss at each epoch; we’d like to see the accuracy increase and the loss decrease with every epoch.

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Out:

Epoch 1
-------------------------------
loss: 2.299968  [    0/60000]
loss: 2.291713  [ 6400/60000]
loss: 2.274773  [12800/60000]
loss: 2.272840  [19200/60000]
loss: 2.265114  [25600/60000]
loss: 2.223690  [32000/60000]
loss: 2.246659  [38400/60000]
loss: 2.205027  [44800/60000]
loss: 2.203506  [51200/60000]
loss: 2.183183  [57600/60000]
Test Error:
 Accuracy: 40.4%, Avg loss: 2.176064

Epoch 2
-------------------------------
loss: 2.179508  [    0/60000]
loss: 2.171398  [ 6400/60000]
loss: 2.122458  [12800/60000]
loss: 2.136343  [19200/60000]
loss: 2.099183  [25600/60000]
loss: 2.031877  [32000/60000]
loss: 2.070053  [38400/60000]
loss: 1.988699  [44800/60000]
loss: 1.989894  [51200/60000]
loss: 1.931872  [57600/60000]
Test Error:
 Accuracy: 58.0%, Avg loss: 1.927841

Epoch 3
-------------------------------
loss: 1.951685  [    0/60000]
loss: 1.920792  [ 6400/60000]
loss: 1.815998  [12800/60000]
loss: 1.850795  [19200/60000]
loss: 1.757668  [25600/60000]
loss: 1.698289  [32000/60000]
loss: 1.727373  [38400/60000]
loss: 1.622224  [44800/60000]
loss: 1.637520  [51200/60000]
loss: 1.541725  [57600/60000]
Test Error:
 Accuracy: 62.5%, Avg loss: 1.558743

Epoch 4
-------------------------------
loss: 1.615998  [    0/60000]
loss: 1.577118  [ 6400/60000]
loss: 1.438216  [12800/60000]
loss: 1.503292  [19200/60000]
loss: 1.393710  [25600/60000]
loss: 1.377801  [32000/60000]
loss: 1.397386  [38400/60000]
loss: 1.316668  [44800/60000]
loss: 1.345000  [51200/60000]
loss: 1.248275  [57600/60000]
Test Error:
 Accuracy: 63.4%, Avg loss: 1.277212

Epoch 5
-------------------------------
loss: 1.348338  [    0/60000]
loss: 1.324071  [ 6400/60000]
loss: 1.170114  [12800/60000]
loss: 1.266121  [19200/60000]
loss: 1.148989  [25600/60000]
loss: 1.163925  [32000/60000]
loss: 1.187885  [38400/60000]
loss: 1.124403  [44800/60000]
loss: 1.159600  [51200/60000]
loss: 1.073207  [57600/60000]
Test Error:
 Accuracy: 64.3%, Avg loss: 1.099506

Done!

Read more about Training your model.

3. Saving Models

A common way to save a model is to serialize the internal state dictionary (containing the model parameters).

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

Out:

Saved PyTorch Model State to model.pth

4. Loading Models

The process for loading a model includes re-creating the model structure and loading the state dictionary into it.

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))
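
An alternative supported by torch.save is to serialize the entire model object rather than just its state dict (the class definition must still be importable at load time); the state-dict approach above is the one the tutorial uses. A sketch, with a hypothetical filename:

# Alternative: pickle the whole model object ("model_full.pth" is a
# hypothetical filename; newer PyTorch versions may require
# torch.load(..., weights_only=False) for full objects).
torch.save(model, "model_full.pth")
model = torch.load("model_full.pth")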

This model can now be used to make predictions.

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

Out:

Predicted: "Ankle boot", Actual: "Ankle boot"

Read more about Saving & Loading your model.
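
The model outputs raw logits; to read them as class probabilities, apply softmax along the class dimension — a small sketch reusing the pred computed above:

# Convert the [1, 10] logits to probabilities (the row sums to 1).
probs = nn.Softmax(dim=1)(pred)
print(f"Top-class probability: {probs[0].max().item():.3f}")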

Total running time of the script: (0 minutes 42.173 seconds)

5. Complete Code in Python

5.1 Get Data

If downloading the dataset from the command line as above is very slow, you can download it directly from the network drive below and put it in the corresponding folder.

Link: https://pan.baidu.com/s/194eBjl2W98VfcSCNe0x8Ow
Extraction code: e11r

After extracting, place the files under the corresponding folder (the data/FashionMNIST/raw paths shown in the download output above).

5.2 Train



import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=False,  # dataset already downloaded; set download=True if not
    transform=ToTensor(),  # preprocessing: convert images to tensors
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=False,  # dataset already downloaded; set download=True if not
    transform=ToTensor(),  # preprocessing: convert images to tensors
)



# for X, y in test_dataloader:
#     print(f"Shape of X [N, C, H, W]: {X.shape}")
#     print(f"Shape of y: {y.shape} {y.dtype}")
#     break



# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# print(model)



def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


if __name__ == '__main__':
    batch_size = 64
    loss_fn = nn.CrossEntropyLoss()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    model = NeuralNetwork().to(device)  # 实例化模型
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # Get cpu or gpu device for training.

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        # test(test_dataloader, model, loss_fn)  # add the test() function from Section 2.3 to evaluate each epoch
    print("Done!")

    # Saving models
    torch.save(model.state_dict(), "model.pth")
    print("Saved PyTorch Model State to model.pth")



5.3 Inference

test.py

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits



classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

if __name__ == '__main__':
    model = NeuralNetwork()
    model.load_state_dict(torch.load("model.pth"))
    model.eval()
    #x, y = test_data[0][0], test_data[0][1]  # y=9
    x=[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0118, 0.0039, 0.0000, 0.0000, 0.0275,
          0.0000, 0.1451, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0078, 0.0000,
          0.1059, 0.3294, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.4667, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,
          0.3451, 0.5608, 0.4314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0863,
          0.3647, 0.4157, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0157, 0.0000, 0.2078,
          0.5059, 0.4706, 0.5765, 0.6863, 0.6157, 0.6510, 0.5294, 0.6039,
          0.6588, 0.5490, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0431, 0.5373,
          0.5098, 0.5020, 0.6275, 0.6902, 0.6235, 0.6549, 0.6980, 0.5843,
          0.5922, 0.5647, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000,
          0.0078, 0.0039, 0.0000, 0.0118, 0.0000, 0.0000, 0.4510, 0.4471,
          0.4157, 0.5373, 0.6588, 0.6000, 0.6118, 0.6471, 0.6549, 0.5608,
          0.6157, 0.6196, 0.0431, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0118, 0.0000, 0.0000, 0.3490, 0.5451, 0.3529,
          0.3686, 0.6000, 0.5843, 0.5137, 0.5922, 0.6627, 0.6745, 0.5608,
          0.6235, 0.6627, 0.1882, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0157,
          0.0039, 0.0000, 0.0000, 0.0000, 0.3843, 0.5333, 0.4314, 0.4275,
          0.4314, 0.6353, 0.5294, 0.5647, 0.5843, 0.6235, 0.6549, 0.5647,
          0.6196, 0.6627, 0.4667, 0.0000],
         [0.0000, 0.0000, 0.0078, 0.0078, 0.0039, 0.0078, 0.0000, 0.0000,
          0.0000, 0.0000, 0.1020, 0.4235, 0.4588, 0.3882, 0.4353, 0.4588,
          0.5333, 0.6118, 0.5255, 0.6039, 0.6039, 0.6118, 0.6275, 0.5529,
          0.5765, 0.6118, 0.6980, 0.0000],
         [0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0824,
          0.2078, 0.3608, 0.4588, 0.4353, 0.4039, 0.4510, 0.5059, 0.5255,
          0.5608, 0.6039, 0.6471, 0.6667, 0.6039, 0.5922, 0.6039, 0.5608,
          0.5412, 0.5882, 0.6471, 0.1686],
         [0.0000, 0.0000, 0.0902, 0.2118, 0.2549, 0.2980, 0.3333, 0.4627,
          0.5020, 0.4824, 0.4353, 0.4431, 0.4627, 0.4980, 0.4902, 0.5451,
          0.5216, 0.5333, 0.6275, 0.5490, 0.6078, 0.6314, 0.5647, 0.6078,
          0.6745, 0.6314, 0.7412, 0.2431],
         [0.0000, 0.2667, 0.3686, 0.3529, 0.4353, 0.4471, 0.4353, 0.4471,
          0.4510, 0.4980, 0.5294, 0.5333, 0.5608, 0.4941, 0.4980, 0.5922,
          0.6039, 0.5608, 0.5804, 0.4902, 0.6353, 0.6353, 0.5647, 0.5412,
          0.6000, 0.6353, 0.7686, 0.2275],
         [0.2745, 0.6627, 0.5059, 0.4078, 0.3843, 0.3922, 0.3686, 0.3804,
          0.3843, 0.4000, 0.4235, 0.4157, 0.4667, 0.4706, 0.5059, 0.5843,
          0.6118, 0.6549, 0.7451, 0.7451, 0.7686, 0.7765, 0.7765, 0.7333,
          0.7725, 0.7412, 0.7216, 0.1412],
         [0.0627, 0.4941, 0.6706, 0.7373, 0.7373, 0.7216, 0.6706, 0.6000,
          0.5294, 0.4706, 0.4941, 0.4980, 0.5725, 0.7255, 0.7647, 0.8196,
          0.8157, 1.0000, 0.8196, 0.6941, 0.9608, 0.9882, 0.9843, 0.9843,
          0.9686, 0.8627, 0.8078, 0.1922],
         [0.0000, 0.0000, 0.0000, 0.0471, 0.2627, 0.4157, 0.6431, 0.7255,
          0.7804, 0.8235, 0.8275, 0.8235, 0.8157, 0.7451, 0.5882, 0.3216,
          0.0314, 0.0000, 0.0000, 0.0000, 0.6980, 0.8157, 0.7373, 0.6863,
          0.6353, 0.6196, 0.5922, 0.0431],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]]
    x = torch.tensor(x)  # convert the nested list to a [1, 28, 28] tensor
    #y=9
    with torch.no_grad():
        pred = model(x)
        #predicted, actual = classes[pred[0].argmax(0)], classes[y]
        predicted= classes[pred[0].argmax(0)]
        print(f'Predicted: "{predicted}"')

Out:

Predicted: "Ankle boot"

6. Reference

  1. Official tutorial: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

7. Others

The complete dataset and code can be downloaded here: https://download.csdn.net/download/weixin_39107270/85213879
