RNN
Mr_Hello_World
菜鸟的进阶之路
展开
-
nn.Embedding参数说明
CLASStorch.nn.Embedding(num_embeddings: int,embedding_dim: int,padding_idx: Optional[int] = None,max_norm: Optional[float] = None,norm_type: float = 2.0,scale_grad_by_freq: bool = False,sparse: bool = False,_weight: Optional[torch.Tensor] = None)[SO...原创 2020-10-23 10:15:56 · 4903 阅读 · 0 评论 -
PyTorch 使用RNN实现MNIST手写字体识别
此处使用普通的RNN 推荐一个RNN入门资料:https://zhuanlan.zhihu.com/p/28054589 28*28的图片,每个输入序列长度(seq_len)为28,每个输入的元素维度(input_size)为28,将一张图片的分为28列,为长度28的序列,序列中每个元素为28个元素(即每一列的像素)。 注意,如果batch_first设置为1,则输出维度out: bat...原创 2019-01-24 23:50:36 · 3976 阅读 · 9 评论 -
PyTorch 使用CNN实现MNIST手写字体识别
一个epoch下来,Test ACC: 0.9892 import torch import torch.nn as nn import torchvision from torchvision import datasets,transforms from torch.autograd import Variable from matplotlib import pyplot as plt...原创 2019-01-25 09:08:32 · 1281 阅读 · 0 评论 -
PyTorch 使用 RNN 输入sin(x)值序列,预测最后一个sin(x)对应的cos(x)
每一个sin(x)可能对应两个cos(x)的数值,不是一一对应关系,这边借助于RNN,输入sin(x)序列,预测sin(x)序列中最后一个值,对应的cos(x),借助的是RNN的记忆性。 这边要注意的是各个Tensor的维度的处理,熟悉torch.stack,torch.cat,x.squeeze(index), x.unsqueeze(index) 函数的使用 import ...原创 2019-01-25 11:51:04 · 1253 阅读 · 3 评论 -
pytorch RNN层api的几个参数说明(转载)
classtorch.nn.RNN(*args, **kwargs) input_size – The number of expected features in the input x hidden_size – The number of features in the hidden state h num_layers – Number of recurrent layers. E....转载 2019-01-23 11:38:44 · 1638 阅读 · 0 评论