需要的包
import torch
import torch.nn as nn
import torch.nn.functional as F
模拟的输入x变量:4分类问题
batch_size, n_classes = 10, 4
x = torch.randn(batch_size, n_classes)
x.shape
x维度
torch.Size([10, 4])
运行:
x
Out:
tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318],
[ 0.5100, 0.9849, -1.2905, 0.2821],
[ 1.4662, 0.4550, 0.9875, 0.3143],
[-1.2121, 0.1262, 0.0598, -1.6363],
[ 0.3214, -0.8689,