# Classification problems typically have multiple outputs.
# ![](https://img-blog.csdnimg.cn/direct/661f649107dd44379e0f542cb620ad74.png)
# Reference on the softmax function: https://blog.csdn.net/bitcarmanlee/article/details/82320853
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()  # render matplotlib figures as SVG
trans = transforms.ToTensor()  # convert PIL images to float tensors scaled to [0, 1]
train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
test = torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
# print(len(train))
# print(len(test))
def get_label(label):
    """Map numeric Fashion-MNIST class indices to their text labels.

    Args:
        label: iterable of ints (or 0-d tensors) in [0, 9].

    Returns:
        List of label strings, one per input index.
    """
    # Bug fix: the original list had a missing comma ('sandal' 'sneaker'
    # concatenated into one string), a duplicated 'sandal' where 'shirt'
    # belongs, and 'ankle boot' split into two separate entries.
    text_labels = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in label]
def zhanshitupian(imgs, rows, cols, titles=None, scale=1.5):
    """Show a grid of images on a rows x cols matplotlib figure.

    Tensors are converted via .numpy() before plotting; other image
    objects (e.g. PIL) are passed to imshow directly.
    """
    figsize = (cols * scale, rows * scale)
    _, axes = d2l.plt.subplots(rows, cols, figsize=figsize)
    # Collapse the 2-D grid of axes into one flat sequence.
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
            ax.axis('off')  # hide axis lines, ticks and labels
            ax.set_title(titles[i])
        else:
            # NOTE(review): non-tensor images skip axis('off') and the
            # title, mirroring the original — possibly unintended; confirm.
            ax.imshow(img)
    return axes
# Grab one batch of 18 samples and display them in a 2 x 9 grid with labels.
x,y = next(iter(data.DataLoader(train,batch_size=18)))
zhanshitupian(x.reshape(18,28,28),2,9,titles=get_label(y))
#d2l.plt.show()  # uncomment to display the figure
## Read data in minibatches.
batch = 256

def duojinchen():
    """Number of worker processes used by DataLoader for data fetching."""
    return 4
trainshuju = data.DataLoader(train,batch,shuffle=True,num_workers=duojinchen())
timer = d2l.Timer()
# Time one full pass over the training DataLoader.
for x,y in trainshuju:
    continue
print(f'{timer.stop():.2f} sec')
# f-string formatting: timer.stop() returns the elapsed time in seconds as
# a float, and the ':.2f' format spec renders it with two decimal places,
# e.g. "3.14 sec".
## Load the dataset and wrap it in DataLoaders.
def load_mnist(batch, resize=None):
    """Load Fashion-MNIST and return (train_loader, test_loader).

    Args:
        batch: batch size for both loaders.
        resize: optional size; when given, images are resized before ToTensor.

    Returns:
        Tuple of (train DataLoader, test DataLoader).
    """
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=False)
    test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=False)
    # Bug fixes vs. the original:
    #  * the trailing comma made `return` produce a 1-tuple holding only the
    #    train loader, leaving the test-loader expression dead code;
    #  * the test loader had no batch size (DataLoader defaults to 1);
    #  * the test set is not shuffled — evaluation order doesn't matter and
    #    shuffle=True there served no purpose.
    return (data.DataLoader(train, batch, shuffle=True,
                            num_workers=duojinchen()),
            data.DataLoader(test, batch, shuffle=False,
                            num_workers=duojinchen()))
## Softmax regression implemented from scratch
#### Softmax regression implemented from scratch
import torch
from IPython import display
from d2l import torch as d2l

batch = 256
train, test = d2l.load_data_fashion_mnist(batch)
# Each Fashion-MNIST sample is a 28*28 image; we flatten it into a
# length-784 vector. The dataset has 10 classes, so the output size is 10.
# Renamed from `input`/`output`: those names shadow Python builtins.
num_inputs = 784
num_outputs = 10
w = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
# Matrix-summation demo (fixed comments: dim 0 sums down each column,
# dim 1 sums across each row):
# x = torch.tensor([[1.0, 2.0, 3.0],
#                   [4.0, 5.0, 6.0]])
# print(x.sum(0, keepdim=True))  # column sums -> shape (1, 3)
# print(x.sum(1, keepdim=True))  # row sums    -> shape (2, 1)
def softmax(x):
    """Row-wise softmax: exp(x) normalized so each row sums to 1.

    Args:
        x: 2-D tensor of logits, shape (batch, classes).

    Returns:
        Tensor of the same shape with non-negative rows summing to 1.

    Subtracting the per-row maximum before exponentiating leaves the result
    mathematically unchanged but prevents overflow: the original
    torch.exp(x) returns inf (and then nan after division) for large logits.
    """
    shifted = x - x.max(dim=1, keepdim=True).values
    x_exp = torch.exp(shifted)
    return x_exp / x_exp.sum(dim=1, keepdim=True)
def net(x):
    """Softmax-regression model: flatten input, apply affine map, then softmax."""
    # w.shape[0] is the flattened image length (784), so each sample
    # becomes one row of the (batch, 784) matrix fed into x @ w + b.
    logits = torch.matmul(x.reshape((-1, w.shape[0])), w) + b
    return softmax(logits)
## Cross-entropy loss: negative log-likelihood of the true-class probability.
def jiaocha(y_hat, y):
    """Per-sample cross entropy: -log of the probability predicted for label y."""
    rows = torch.arange(len(y_hat))
    true_class_prob = y_hat[rows, y]
    return -torch.log(true_class_prob)
## Compare predicted classes against the true labels y.
def accurate(y_hat, y):
    """Return the number of correct predictions as a float."""
    if y_hat.ndim > 1 and y_hat.shape[1] > 1:
        # Reduce per-class scores to the highest-scoring class index per row.
        y_hat = y_hat.argmax(axis=1)
    hits = (y_hat.type(y.dtype) == y)
    return float(hits.type(y.dtype).sum())
## Evaluate the accuracy of `net` on a data iterator.
def evaluate(net, data_iter):
    """Return the fraction of samples in `data_iter` that `net` classifies correctly."""
    if isinstance(net, torch.nn.Module):
        net.eval()  # evaluation mode (affects dropout/batchnorm layers)
    tally = zzhenque(2)  # [number correct, number of samples]
    for x, y in data_iter:
        tally.add(accurate(net(x), y), y.numel())
    return tally[0] / tally[1]
# zzhenque ("accumulator"): __init__(self, n) stores n float zeros in
# self.data. add(self, *args) adds each argument onto the matching running
# total using zip() and a list comprehension (*args collects a variable
# number of positional arguments). reset(self) restores the list of zeros.
# __getitem__(self, idx) is the special method that lets instances be
# indexed like a list, returning self.data[idx].
class zzhenque:
    """Accumulator that maintains n running float totals in parallel."""

    def __init__(self, n):
        # One running total per tracked quantity.
        self.data = [0.0] * n

    def add(self, *args):
        # Element-wise: add each new value onto its running total.
        merged = []
        for total, delta in zip(self.data, args):
            merged.append(total + float(delta))
        self.data = merged

    def reset(self):
        # Zero every total, preserving the number of slots.
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        # Allow list-style indexed access: acc[0], acc[1], ...
        return self.data[idx]
## One training epoch for softmax regression.
def tarin_zhouqi(net, train_data, loss, updater):
    """Train `net` for one epoch; return (mean loss, accuracy) over the epoch."""
    if isinstance(net, torch.nn.Module):
        net.train()  # training mode (enables gradient-dependent layers)
    stats = zzhenque(3)  # [summed loss, correct predictions, sample count]
    for x, y in train_data:
        y_hat = net(x)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # Built-in optimizer: standard zero-grad / backward / step cycle.
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # Custom updater (e.g. d2l.sgd): sum the loss and pass the batch
            # size so the update can normalize manually.
            l.sum().backward()
            updater(x.shape[0])
        stats.add(float(l.sum()), accurate(y_hat, y), y.numel())
    # Per-sample average loss and per-sample accuracy.
    return stats[0] / stats[2], stats[1] / stats[2]
## Plot data incrementally in an animation.
class Animator:  #@save
    """Incrementally plot one or more data series during training."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # Incrementally draw multiple lines on one (or more) axes.
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # Normalize to a list so self.axes[0] works in both cases.
            self.axes = [self.axes, ]
        # Capture the axis-configuration arguments in a lambda so they can
        # be re-applied after every redraw.
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per series (scalars are broadcast) and redraw."""
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        # Bug fix: the original called bare `plt.draw()` / `plt.pause()`, but
        # no name `plt` is in scope here (only `d2l.plt` is imported), which
        # raised NameError on the first call to add().
        d2l.plt.draw()
        d2l.plt.pause(0.001)
        display.display(self.fig)
        display.clear_output(wait=True)
def train_3(net, train_data, test_data, loss, saomiaocishu, updater):
    """Train for `saomiaocishu` epochs, plotting loss/accuracy after each epoch."""
    plot = Animator(xlabel='cishu', xlim=[1, saomiaocishu], ylim=[0.3, 0.9],
                    legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(saomiaocishu):
        epoch_stats = tarin_zhouqi(net, train_data, loss, updater)
        test_acc = evaluate(net, test_data)
        # epoch_stats is (train loss, train acc); append test accuracy.
        plot.add(epoch + 1, epoch_stats + (test_acc,))
        train_loss, train_acc = epoch_stats
xuexilv = 0.1  # learning rate

def updater(batch):
    """Minibatch SGD step on (w, b); `batch` is the current batch size."""
    return d2l.sgd([w, b], xuexilv, batch)

saomiaocishu = 10  # number of training epochs
# train_3(net,train,test,jiaocha,saomiaocishu,updater)
#预测标签
# def predict(net,test,n=6):
# for x,y in test:
# break
# trues = d2l.get_fashion_mnist_labels(y)
# preds = d2l.get_fashion_mnist_labels(net(x).argmax(axis=1))
# titles = [true+'\\n'+pred for true,pred in zip(trues,preds)]
# d2l.show_images(x[0:n].reshape((n,28,28)),1,n,titles = titles[0:n])
# predict(net,test)
# d2l.plt.show()
def predict_ch3(net, test_iter, n=6):  #@save
    """Show the first n test images titled with true and predicted labels."""
    for X, y in test_iter:
        break  # only the first batch is needed
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    # Bug fix: the original used the two-character literal '\\n' (backslash
    # followed by n), so titles displayed a literal "\n" instead of putting
    # the predicted label on a second line.
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
## Concise implementation of softmax regression
import torch
from torch import nn
from d2l import torch as d2l
batch =256
train,test = d2l.load_data_fashion_mnist(batch)
# A Flatten layer reshapes each 1x28x28 image into a length-784 vector
# before the single Linear (784 -> 10) output layer.
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))
# Weights are re-initialized below (mean 0, std 0.01) via net.apply and
# nn.init.normal_.
def shenchen_w(m):
    """Weight-init hook for net.apply: draw Linear weights from N(0, 0.01)."""
    # Guard clause: only exact nn.Linear modules are re-initialized.
    if type(m) != nn.Linear:
        return
    nn.init.normal_(m.weight, std=0.01)
net.apply(shenchen_w)  # run the weight-init hook on every submodule
loss = nn.CrossEntropyLoss(reduction='none')
# net.parameters() yields the network's weights W and biases.
train1 = torch.optim.SGD(net.parameters(),lr=0.1)
sunliancishu = 10  # number of training epochs
d2l.train_ch3(net,train,test,loss,sunliancishu,train1)
d2l.plt.show()