pytorch学习
江南汪
这个作者很懒,什么都没留下…
展开
-
pytorch先建立列表再转化为ModuleList
在代码中我们经常看到先将网络结构添加到列表中,再转化为ModuleList类型代码示例:import torchimport torch.nn as nn#建立列表convs=[]l=nn.Conv2d(3,3,kernel_size=1)convs.append(l)#转化为ModuleList类型module=nn.ModuleList(convs)print(list(module.modules()))输出:[ModuleList( (0): Conv2d(3, 3,原创 2021-05-23 14:35:01 · 875 阅读 · 0 评论 -
torch.exp用法
torch.exp(input):代码示例:import torchimport torch.nn as nnimport matha=torch.tensor([0,math.log(2)])print(torch.exp(a))输出:tensor([1., 2.])原创 2021-05-16 21:16:46 · 10841 阅读 · 0 评论 -
torch.init.normal_和torch.init.constant_用法
torch.init.normal_:给tensor初始化,一般是给网络中参数weight初始化,初始化参数值符合正态分布。torch.init.normal_(tensor,mean=,std=) ,mean:均值,std:正态分布的标准差代码示例:import torchimport torch.nn as nnl=nn.Conv2d(2,2,kernel_size=1)a=l.weightprint("a:",a)b=nn.init.normal_(l.weight,mean=0,st原创 2021-05-16 20:49:09 · 7297 阅读 · 1 评论 -
pytorch中.modules用法
.modules:返回网络中所有网络信息,如卷积、线性等,还有索引。代码示例:import torchimport torch.nn as nnl=nn.Sequential(nn.Conv2d(16,16,kernel_size=1), nn.Conv2d(8,8,kernel_size=1))print(list(l.modules()输出:[Sequential( (0): Conv2d(16, 16, kernel_size=(1, 1), str原创 2021-05-16 20:08:14 · 1263 阅读 · 0 评论 -
torch.nn.GroupNorm用法
torch.nn.GroupNorm:将channel切分成许多组进行归一化torch.nn.GroupNorm(num_groups,num_channels)num_groups:组数num_channels:通道数量代码示例:a=torch.randn(15,256,9,15)#将channel256分为8组,每组32channelm=nn.GroupNorm(8,256)...原创 2021-05-16 17:16:56 · 14669 阅读 · 0 评论 -
torch.nn.functional.interpolate用法
torch.nn.functional.interpolate:将图片上/下采样到指定的大小torch.nn.functional.interpolate(input,size=)(本次示例输入是4维)输入的维度:batch_size×channels×height×widthsize:output size代码示例:import torchimport torch.nn.functional as Fa=torch.randn(2,3,4,5)b=F.interpolate(a,2)p原创 2021-05-15 11:51:50 · 6233 阅读 · 0 评论 -
待定
行:sample x[:,:-1]列:feature x[:,-1]原创 2021-05-13 20:30:28 · 120 阅读 · 0 评论 -
torch.nn.Module.parameters
torch.nn.Module.parameters:计算模型的参数,返回一个关于模型参数的迭代器。代码示例:1.建立一个线性模型2.打印参数import torchimport torch.nn as nnclass mymodule(nn.Module): def __init__(self): super(mymodule,self).__init__() self.linear=nn.Linear(2,3) self.relu=nn原创 2021-05-12 21:39:50 · 3313 阅读 · 0 评论 -
RuntimeError: “rsqrt_cpu“ not implemented for ‘Long‘
在运行torch.rsqrt函数时候报错:import torcha=torch.tensor([1,2,3,4])b=torch.rsqrt(a)print("b:",b)报错:Traceback (most recent call last): File "D:/pytorch学习代码/torch.rsqrt.py", line 8, in <module> b=torch.rsqrt(a)RuntimeError: "rsqrt_cpu" not implem原创 2021-05-12 18:37:04 · 1016 阅读 · 0 评论 -
torch.rsqrt用法
torch.rsqrt:代码示例:import torcha=torch.tensor([1.,2.,3.,4.])print("a:",a)b=torch.rsqrt(a)print("b:",b)输出:a: tensor([1., 2., 3., 4.])b: tensor([1.0000, 0.7071, 0.5774, 0.5000])原创 2021-05-12 16:17:52 · 1409 阅读 · 0 评论 -
torch.nn.ReLU用法
torch.nn.ReLU:调用relu函数代码示例来自pytorch官方文档,正好复习下unsqueeze和cat函数的使用:import torchimport torch.nn as nnm=nn.ReLU()input=torch.randn(2).unsqueeze(0)print("input:",input)print("input的shape:",input.shape)output=torch.cat((m(input),m(-input)),0)print("outpu原创 2021-05-11 12:35:33 · 10321 阅读 · 0 评论 -
torch.randn用法
torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)。torch.randn(size),size可以是一个整数,也可以是一个元组。代码示例:import torcha=torch.randn(3)b=torch.randn(3,4)print("a:",a)print("b:",b)输出:a: tensor([ 0.9405, -0.1068, 0.1712])b: tensor([[-1.0962, -0.1893, 1.2323,原创 2021-05-10 21:42:16 · 74991 阅读 · 0 评论 -
torch.where用法
torch.where:类似于if语句:if condition:xelse:y但在tensor中有它的特殊之处torch.where(condition,x,y)代码示例:搬自pytorch英文文档>>> x = torch.randn(3, 2)>>> y = torch.ones(3, 2)>>> xtensor([[-0.4620, 0.3139], [ 0.3898, -0.7197],原创 2021-05-10 21:31:26 · 2048 阅读 · 0 评论 -
torch.transpose用法
torch.transpose:转置经常在矩阵中使用,交换两个维度。torch.transpose(tensor,dim_0,dim_1)代码示例:import torcha=torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])b=torch.transpose(a,1,2)print("tensor_a",a)print("tensor_b",b)print("a的shape:",a.shape)p原创 2021-05-10 21:16:01 · 11265 阅读 · 5 评论 -
torch.squeeze用法
torch.squeeze:将tensor中大小为1的维度删除torch.squeeze(tensor,dim)代码示例:import torcha=torch.ones(2,2,2,1,1)b=torch.squeeze(a)c=torch.squeeze(a,0)d=torch.squeeze(a,3)print("a的shape:",a.shape)print("b的shape:",b.shape)print("c的shape:",c.shape)print("d的shape:"原创 2021-05-10 16:50:09 · 14077 阅读 · 0 评论 -
torch.split用法
torch.split,用来划分tensor,可以从数量上划分,还有维度上划分。torch.split(tensor,split_szie,dim),split_size有整数,也有列表,dim默认为0,自己也可以修改。代码示例:import torcha=torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])print("a的shape:",a.shape)#在第0维上进行splitb=torch.spl原创 2021-05-10 16:29:07 · 8985 阅读 · 0 评论 -
torch.reshape用法
torch.reshape用来改变tensor的shape。torch.reshape(tensor,shape)import torcha=torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])print("a的shape:",a.shape)b=torch.reshape(a,((4,3,1)))print("b:",b)print("b的shape:",b.shape)输出:a的shape:原创 2021-05-10 15:36:03 · 55629 阅读 · 3 评论 -
torch.index_select用法
torch.index_select:通过选择索引然后去得到想要的tensor,针对比较长的tensortorch.index_select(tensor,维度,选择的index)代码示例:import torch#shape为(2,2,3)a=torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])#选择索引0和索引2的tensor indices=torch.tensor([原创 2021-05-10 09:05:54 · 10181 阅读 · 0 评论 -
torch.chunk用法
torch.chunk:用来将tensor分成很多个块,简而言之我理解的就是切分吧,可以在不同维度上切分。torch.chunk(tensor,chunk数,维度)代码示例:import torcha=torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])b=torch.chunk(a,2,1)print(a)print(b)输出:tensor([[[1, 2], [3, 4]], [[5,原创 2021-05-10 08:28:27 · 16201 阅读 · 0 评论 -
torch.cat()用法
torch.cat:想要去对那哪一个维度进行concat就必须要保证其他维度的大小是一样的。如a的shape为(2,2,3),b的shape为(3,2,3),那么在维度0上是可以进行concat的,concat过后得到(5,2,3)而在维度1和2维度2上是不可以进行concat的。代码如下:import torcha=torch.tensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])print("a的shape为:",a原创 2021-05-09 20:47:43 · 2470 阅读 · 0 评论 -
torch.arrange和torch.range用法
torch.arrange:原创 2021-05-09 16:48:08 · 4789 阅读 · 5 评论 -
torch.zeros和torch.ones
torch.zeros:用来将tensor中元素值全置为0a=torch.zeros(2,2)''tensor([[0., 0.], [0., 0.]])''b=torch.zeros(3)''tensor([0., 0., 0.])''torch.ones:用来将tensor中元素值全置为1c=torch.ones(2,2)''tensor([[1., 1.], [1., 1.]])''...原创 2021-05-09 16:26:06 · 3599 阅读 · 0 评论 -
torch.from_numpy()用法
torch.from_numpy()用来将数组array转换为张量Tensora=np.array([1,2,3,4])print(a)#[1 2 3 4]print(torch.from_numpy(a))#tensor([1, 2, 3, 4], dtype=torch.int32)原创 2021-05-09 15:57:31 · 21859 阅读 · 3 评论 -
torch.numel()用法
torch.numel()用来统计tensor中元素的个数#输入a=torch.tensor([1,2,3,4])print(torch.numel(a))#输出4原创 2021-05-09 15:39:38 · 4465 阅读 · 0 评论 -
pytorch中判断object是否是Tensor
import torcha=torch.tensor([1,2,3,4])if not isinstance(a,torch.Tensor): raise TypeError('wrong')在很多代码都能遇到,使用isinstance函数判断原创 2021-05-09 15:29:53 · 1354 阅读 · 0 评论