PyTorch-7:The Training Process 和 Confusion Matrix

网络、数据集的准备

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per line when tensors are printed.
torch.set_printoptions(linewidth=120)

# FashionMNIST with train=True, stored under ./data/FashionMNIST
# (download=True lets torchvision fetch it when it is missing).
# Every sample is passed through ToTensor() on access.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

class Network(nn.Module):
    """Small convolutional classifier: two conv layers, then three linear layers.

    The flatten size of ``12 * 4 * 4`` used before ``fc1`` corresponds to
    single-channel 28x28 inputs (28 -> conv5 -> 24 -> pool -> 12 -> conv5
    -> 8 -> pool -> 4). The network returns 10 raw scores (logits) per
    sample; no softmax is applied here because ``F.cross_entropy`` applies
    log-softmax internally.
    """

    def __init__(self):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        # in_features must equal the out_features of fc2.
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Run a batch through the network.

        Args:
            t: input tensor of shape ``[batch, 1, height, width]``.

        Returns:
            Tensor of shape ``[batch, 10]`` holding unnormalised class scores.
        """
        # First hidden conv layer: conv -> ReLU -> 2x2 max-pool.
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Second hidden conv layer: conv -> ReLU -> 2x2 max-pool.
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # First hidden linear layer; flatten the feature maps per sample.
        t = t.reshape(-1, 12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)

        # Second hidden linear layer.
        t = self.fc2(t)
        t = F.relu(t)

        # Output layer: raw logits, softmax is left to the loss function.
        t = self.out(t)
        return t

# Create the model instance used by the training and prediction code below.
network = Network()
以 batch_size 为100创建 data_loader 并获得一个 batch 的数据

data_loader 产生的每个 batch 中,images 张量的 shape 是 [batch size, input channels, height, width]

# Serve the training set in batches of 100 samples, then pull the first
# batch out of a fresh iterator for inspection.
data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=100)
batch = next(iter(data_loader))

def get_num_correct(pred, labels):
    """Count how many predictions match their labels.

    Args:
        pred: tensor of per-class scores, one row per sample; the predicted
            class is the index of the largest score along ``dim=1``.
        labels: tensor of true class indices, one per sample.

    Returns:
        The number of correctly predicted samples as a Python int.
    """
    # Use the parameter itself; the previous body read an undefined/global
    # name ``preds`` and therefore ignored the argument.
    return pred.argmax(dim=1).eq(labels).sum().item()

The Training Process 训练网络

  • 将一个 batch 送入网络中
  • 计算 loss
  • 计算 loss 相对 learnable weights 的梯度
  • 根据梯度更新 weights
  • 重复以上步骤以完成一个epoch
  • 完成多个epoch以获得目标精确度
整个训练过程(多个epoch)
# Build the optimizer once, outside the loops, so Adam's internal state
# persists across steps instead of being discarded on every batch.
optimizer = optim.Adam(network.parameters(), lr=0.001)

# All epochs.
for epoch in range(5):
    total_loss = 0
    total_correct = 0

    # One epoch: every batch produced by the loader.
    for batch in data_loader:
        images, labels = batch

        # Forward pass: feed one batch to the network and get its output.
        preds = network(images)

        # Compute the loss.
        loss = F.cross_entropy(preds, labels)

        # backward() adds into .grad, so clear the previous batch's
        # gradients first, then compute the new ones.
        optimizer.zero_grad()
        loss.backward()

        # Update the weights.
        optimizer.step()

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    print("epoch:",epoch,"total_loss:",total_loss,"total_correct:",total_correct)

Confusion Matrix, 误差矩阵/混淆矩阵

Confusion Matrix,混淆矩阵也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。具体评价指标有总体精度、制图精度、用户精度等,这些精度指标从不同的侧面反映了图像分类的精度。

关闭计算图 graph

使用训练完成的网络时,不需要计算图来跟踪梯度,此时关闭计算图,能够加快运算,节省内存。

全局关闭:torch.set_grad_enabled(False)
局部关闭:with torch.no_grad():或者使用函数装饰器的方法@torch.no_grad()

Concatenate 和 Stack(torch.stack 和 torch.cat)

Concatenate 的作用是在一个已存在的维度上连接张量,dim 参数决定在第几维度上拼接;Stack 的作用是在一个新的维度上连接张量,dim 参数表示这个新维度插入在第几维。
Stack 可以由组合 unsqueeze 和 concatenate 实现。

### dim=0
t1 = torch.tensor([1,1,1])
t2 = torch.tensor([2,2,2])
t3 = torch.tensor([3,3,3])

# cat joins along an existing dimension; stack creates a new one.
torch.cat((t1,t2,t3), dim=0) #-> tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])
torch.stack((t1,t2,t3), dim=0) #->  tensor([[1, 1, 1],
                                #           [2, 2, 2],
                                #           [3, 3, 3]])

# stack rebuilt from unsqueeze + cat
torch.cat((t1.unsqueeze(0),t2.unsqueeze(0),t3.unsqueeze(0)), dim=0)
#-> tensor([[1, 1, 1],
#           [2, 2, 2],
#           [3, 3, 3]])

### dim=1
# The tensors must be passed as ONE sequence argument; dim is a keyword.
torch.stack((t1,t2,t3), dim=1) #-> tensor([[1, 2, 3],
                                #          [1, 2, 3],
                                #          [1, 2, 3]])

# stack rebuilt from unsqueeze + cat
torch.cat((t1.unsqueeze(1),t2.unsqueeze(1),t3.unsqueeze(1)), dim=1)
#-> tensor([[1, 2, 3],
#           [1, 2, 3],
#           [1, 2, 3]])

获取Confusion Matrix
def get_a_preds(model, loader):
    """Run ``model`` over every batch of ``loader`` and gather the outputs.

    Args:
        model: callable applied to the images of each batch.
        loader: iterable yielding ``(images, labels)`` pairs; the labels
            are ignored here.

    Returns:
        A single tensor with the per-batch outputs concatenated along
        dim 0, or an empty tensor when ``loader`` yields no batches.
    """
    # Use the ``model`` argument (the previous body silently used the
    # global ``network``), and concatenate once at the end rather than
    # re-copying the growing result on every batch.
    batch_outputs = [model(imgs) for imgs, _labels in loader]
    if not batch_outputs:
        return torch.tensor([])
    return torch.cat(batch_outputs, dim=0)

# Inference only: no gradients are needed, so switch tracking off for
# the duration of the prediction pass.
with torch.no_grad():
    # The class is spelled DataLoader (capital L).
    prediction_loader = torch.utils.data.DataLoader(train_set, batch_size=10000)
    # get_a_preds is the helper defined in this file.
    train_preds = get_a_preds(network, prediction_loader)

# Build the confusion matrix.
# Pair every true label with its predicted label -> torch.Size([60000, 2]).
stacked = torch.stack((train_set.targets, train_preds.argmax(dim=1)), dim=1)
# Row index = true label, column index = predicted label.
cmt = torch.zeros(10, 10, dtype=torch.int32)

for p in stacked:
    # tolist() turns the 2-element tensor into plain Python ints usable
    # as indices.
    tl, pl = p.tolist()
    cmt[tl, pl] = cmt[tl, pl] + 1

得到混淆矩阵为:

tensor([
    [5637,    3,   96,   75,   20,   10,   86,    0,   73,    0],
    [  40, 5843,    3,   75,   16,    8,    5,    0,   10,    0],
    [  87,    4, 4500,   70, 1069,    8,  156,    0,  106,    0],
    [ 339,   61,   19, 5269,  203,   10,   72,    2,   25,    0],
    [  23,    9,  263,  209, 5217,    2,  238,    0,   39,    0],
    [   0,    0,    0,    1,    0, 5604,    0,  333,   13,   49],
    [1827,    7,  716,  104,  792,    3, 2370,    0,  181,    0],
    [   0,    0,    0,    0,    0,   22,    0, 5867,    4,  107],
    [  32,    1,   13,   15,   19,    5,   17,   11, 5887,    0],
    [   0,    0,    0,    0,    0,   28,    0,  234,    6, 5732]
])

也可以用 sklearn 来计算混淆矩阵:

from sklearn.metrics import confusion_matrix

# Same matrix via scikit-learn: true labels first, predicted labels second.
cmt = confusion_matrix(
    y_true=train_set.targets,
    y_pred=train_preds.argmax(dim=1),
)
画出 Confusion Matrix
# helper function
import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it as an annotated heat map.

    Args:
        cm: square confusion matrix; any array-like is accepted (NumPy
            array, nested list, CPU torch tensor). It is indexed as
            ``cm[true_label, predicted_label]``.
        classes: sequence of class names used as tick labels on both axes.
        normalize: when True, each row is divided by its row sum and the
            cells are shown with two decimals; otherwise cells are shown
            as integers.
        title: title of the plot.
        cmap: matplotlib colormap used for the heat map.

    Returns:
        None. The matrix is printed and drawn on the current pyplot figure.
    """
    # Coerce to ndarray so .astype / .sum(axis=...) below work for every
    # array-like input, not just NumPy arrays.
    cm = np.asarray(cm)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells above half of the maximum get white text, the rest black text.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Class names in label-index order, used as tick labels on both axes.
names = (
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
)

# Draw the confusion matrix on a 10x10 figure.
plt.figure(figsize=(10, 10))
plot_confusion_matrix(cmt, names)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值