动手学机器学习笔记 LeNet

LeNet 结构:

# 步幅s 上/下/左/右填充为p 数据为n 窗口大小为k
# 输出形状 (n-k+2*p+s)/s
net = nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5, padding=2),nn.Sigmoid(), # (28-5+2*2+1)/1=28 输入数据为1x1x28x28 因此输出张量形状为1x6x30x30
    nn.AvgPool2d(kernel_size=2,stride=2),     #(28-2+2)/2=14 输入数据为1x6x28x28 因此输出张量形状为1x6x14x14,kernel_size和stride都为2的化输入输出减半
    nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),# (14-5+1)/1=10 Input:1x6x14x14 output 1x16x10x10
    nn.AvgPool2d(kernel_size=2, stride=2), #10/2=5 Input:1x16x10x10 output: 1x16x5x5
    nn.Flatten(),       #扁平化 1x16*5*5
    nn.Linear(16*5*5,120),nn.Sigmoid(),
    nn.Linear(120,84), nn.Sigmoid(),
    nn.Linear(84,10))

isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。

isinstance() 与 type() 区别:

type() 不会认为子类是一种父类类型,不考虑继承关系。

isinstance() 会认为子类是一种父类类型,考虑继承关系。

如果要判断两个类型是否相同推荐使用 isinstance()。

def evaluate_accuracy(net, date_iter,device=None):
    """" 计算数据集上模型的精度 """
    if isinstance(net, nn.Module):  #net 是否是nn.Module 的类型或者子类
        net.eval()   #将模型设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device

    metric = d2l.Accumulator(2)   #正确预测数,预测总数
    with torch.no_grad():
        for X, y in date_iter:
            if isinstance(X, list):
                X=[x.to(device) for x in X]
            else:
                X = X.to(device)
            y= y.to(device)
            metric.add(d2l.accuracy(net(X)), y.numel())
    return metric[0]/metric[1]

注accuracy 的实现源码:

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    """给定预测概率分布y_hat,当我们必须输出硬预测(hard prediction)时, 我们通常选择预测概率最高的类。
	当预测与标签分类y一致时,即是正确的。 分类精度即正确预测数量与总预测数量之比。 
	为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 
	我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 
	由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 
	结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
    

训练函数:

def train_ch6(net, train_iter,test_iter,num_epochs, lr, device):
    """使用GPU训练模型"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print("training on", device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(),lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer ,num_batches =d2l.Timer(), len(train_iter)
    for epoch in range(num_batches):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i ,(X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X,y =X.to(device),y.to(device)
            y_hat = net(X)
            l = loss(y_hat,y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc =metric[1] / metric[2]
            if ( i + 1 )% (num_batches // 5) ==0 or i == num_batches -1:
                animator.add(epoch+ (i+1)/ num_epochs ,(train_l,train_acc,None))
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch+1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

开始训练:

lr, num_epochs = 0.9, 10  #学习率0.9 迭代次数10

train_ch6(net, train_iter, test_iter ,num_epochs ,lr,d2l.try_gpu())
d2l.plt.show()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值