-
Pytorch安装
推荐使用Anaconda来管理pytorch等python包
由于电脑配置比较老了,原来装的是windows系统,安装的Ubuntu虚拟机比较慢,所以只好在windows下来安装环境了。使用默认的anaconda源经常会出现CondaHTTPError以及无法创建虚拟环境的问题,需要把anaconda源改为国内的清华大学镜像就可以了。可以参考[Anaconda 镜像使用帮助](https://mirror.tuna.tsinghua.edu.cn/help/anaconda/),修改.condarc文件内容为:
channels:
- defaults
show_channel_urls: true
channel_alias: https://mirrors.tuna.tsinghua.edu.cn/anaconda
default_channels:
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
custom_channels:
conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
然后打开Anaconda Prompt,创建虚拟环境:
conda create -n pytorch python=3.7
激活虚拟环境:
conda activate pytorch
通过pip安装pytorch包,可以参考pytorch官方网站根据自己的系统和软硬件环境选择合适的版本(由于我的显卡不支持CUDA,所以安装的是CPU only的版本):
pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
然后可以退出虚拟环境:
deactivate
至此,已经成功创建了另一个conda虚拟环境pytorch,并安装了pytroch。pycharm的安装与使用不在此赘述。
可以用以下简单的代码来打印出安装的pytorch的版本:
import torch
print("hello pytorch {}".format(torch.__version__))
Pytorch的基础数据结构Tensor和Variable
Tensor张量是什么? 可以把Tensor看作是一个多维数组,是标量(0维张量),向量(1维张量),矩阵(2维张量)的高维拓展。
Variable是torch.autograd中的数据类型,主要用于封装Tensor,进行自动求导。
data:被包装的Tensor
grad:data的梯度
grad_fn:创建Tensor的Function,是自动求导的关键
requires_grad:指示是否需要梯度
is_leaf:指示是否是叶子结点(张量)
张量的创建
-
直接创建
- 通过torch.tensor()创建
data:可以是list, numpy
dtype:数据类型,默认与data的类型一致
device:所在设备,cuda/cpu
requires_grad:是否需要梯度
import torch
import numpy as np
arr = np.ones((3, 3))
print("ndarray type: ", arr.dtype)
t = torch.tensor(arr)
print(t)
结果如下:
ndarray type: float64
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=torch.float64)
- 通过torch.from_numpy创建
arr = np.array([[1, 2, 3], [4, 5, 6]])
t = torch.from_numpy(arr)
print(t)
print("array address:", id(arr)) ## show the memory address
print("tensor data address: ", id(t.data))
结果如下:
tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.int32)
array address: 2796001828096
tensor data address: 2796001869640
可以看出,通过torch.from_numpy创建的tensor和原ndarray共享内存,当修改其中一个的数据,另一个也同时被改动。
-
依据数值创建
t = torch.zeros((3, 3)) # 创建全0张量
"""
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
"""
t = torch.ones((3, 3)) # 创建全1张量
"""
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
"""
t = torch.full((3, 4), 5) # 创建值全为5的3X3的张
"""
tensor([[5., 5., 5., 5.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
"""
# 根据提供的size创建全0,全1,全为fill_value的张量
arr = np.array([[1, 2, 3], [4, 5, 6]])
test_t = torch.from_numpy(arr)
t = torch.zeros_like(test_t)
"""
tensor([[0, 0, 0],
[0, 0, 0]], dtype=torch.int32)
"""
t = torch.ones_like(test_t)
"""
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.int32)
"""
t = torch.full_like(test_t, 3)
"""
tensor([[3, 3, 3],
[3, 3, 3]], dtype=torch.int32)
"""
# 创建等差的1维张量
t = torch.arange(2, 10, 2)
### output
## tensor([2, 4, 6, 8])
# 创建均分的1维张量
t = torch.linspace(1, 3, steps=10)
"""
tensor([1.0000, 1.2222, 1.4444, 1.6667, 1.8889, 2.1111, 2.3333, 2.5556, 2.7778,
3.0000])
"""
# 创建对数均分的1维张量
t = torch.logspace(1, 2, steps=8)
"""
tensor([ 10.0000, 13.8950, 19.3070, 26.8270, 37.2759, 51.7948, 71.9686,
100.0000])
"""
# 创建单位对角矩阵
t = torch.eye(3, 3)
"""
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
"""
-
依据概率分布创建
# 生成正态分布(高斯分布)
t_normal = torch.normal(0., 1., size=(4,))
## tensor([ 0.0206, -0.4805, -0.3185, -1.0287])
# 生成标准正态分布
t = torch.randn((2, 2))
"""
tensor([[ 0.4790, 0.1725],
[-2.1356, -0.2725]])
"""
# 在区间[0, 1)上生成均匀分布
t = torch.rand((2, 2))
"""
tensor([[0.8578, 0.9063],
[0.9365, 0.7199]])
"""
# 在区间[low, high)生成整数均匀分布
t = torch.randint(1, 5, (2, 2))
"""
tensor([[3, 2],
[1, 4]])
"""
# 生成从0到n-1的随机排列
t = torch.randperm(10)
# tensor([1, 9, 3, 2, 6, 4, 5, 0, 7, 8])