深度学习
chenjiale5
这个作者很懒,什么都没留下…
展开
-
train函数
每次训练都测试def get_acc(output,label): total = output.shape[0] _,pred_label = output.max(1) return (pred_label == label).sum().data.item()/totaldef train(net,train_data,valid_data,num_epochs...原创 2019-08-05 16:58:33 · 4320 阅读 · 0 评论 -
优化算法1:随机梯度下降法
梯度下降公式很简单import numpy as npimport torchfrom torchvision.datasets import MNISTfrom torch.utils.data import DataLoaderfrom torch import nnfrom torch.autograd import Variableimport timeimport ma...原创 2019-08-01 16:48:54 · 407 阅读 · 0 评论 -
查看下载数据集中的图片
import matplotlib.pyplot as plt下载数据集,并将数据集变成DataLoader类型train_set = CIFAR10('./data', train=True, transform=data_tf,download=True)train_data = torch.utils.data.DataLoader(train_set, batch_size=64,...原创 2019-08-05 20:42:08 · 306 阅读 · 0 评论 -
构建一个简单的VGG网络
import torchfrom torch.autograd import Variablefrom torch import nndef vgg_block(num_convs, input_channels, output_channels): net = [ nn.Conv2d(input_channels, output_channels, kernel_...原创 2019-08-05 21:31:04 · 664 阅读 · 0 评论 -
优化算法2--动量法
起始点a到b点的梯度下降记为n1b点到c点的实际梯度由两方面组成v (t) =γv(t-1) +α∇b+前面是n1的γ倍,通常取0.9(为什么<1?因为要减小早期的梯度影响)+后面α是学习率,∇b是b点理论上的梯度下降有很多点之后,v(t-1)指的是之前所有步骤累加的动量和from torch.utils.data import DataLoaderfrom torch imp...原创 2019-08-01 19:44:55 · 1072 阅读 · 0 评论 -
优化算法3,4--Adagrad算法与RMSProp算法
思想:如果一个参数的梯度一直都非常大,就让它的学习率变小一点,防止震动,反之,则让其学习率变大,使其能更快更新做法:学习率由下列式子所得+后的参数是为了防止分母等于0,一般取10的-10次方对每个参数,初始化一个变量s=0,每次参数更新时,将梯度平方求和累加到s上所以梯度越大,累加得s越大,学习率越小缺点:到后期,分母越来越大,学习率会变得较小,无法较好的收敛from torch.u...原创 2019-08-01 20:09:51 · 1351 阅读 · 0 评论 -
优化算法5:--Adadelta算法
Adadelta算法是Adagrad算法的延伸,与RMSProp算法一样,是为了解决Adagrad中学习率不断减小的问题,RMSProp是通过移动加权平均的方式,Adadelta也一样,并且Adadelta不需要学习率这个参数RMSProp算法Adadelta的分母和RMSProp的分母一致需要更新参数的变化量为分子表示的是每次更新梯度变化量的累加量最后的参数更新如下实现opt...原创 2019-08-01 21:18:59 · 4653 阅读 · 0 评论 -
简单GoogLeNet实现
import numpy as npimport torchfrom torch import nnfrom torch.autograd import Variabledef conv_relu(in_channel, out_channel, kernel, stride=1, padding=0): layer = nn.Sequential( nn.Con...原创 2019-08-06 10:45:32 · 386 阅读 · 0 评论 -
读取并预处理自己数据集的一种方式
引入必要的包from torchvision.datasets import ImageFolderimport matplotlib.pyplot as pltfrom torchvision import transforms as tfsfrom torch.utils.data import DataLoaderimport numpy as npimport torch加...原创 2019-08-07 09:32:25 · 852 阅读 · 0 评论