5 Essential Functions to Get Started with PyTorch


Function 1 — torch.device()

PyTorch is an open-source library developed by Facebook that is hugely popular among data scientists. One of the main reasons for its rise is its built-in support for GPUs.

torch.device() lets you specify the device responsible for loading and holding tensors in memory. The function expects a string argument specifying the device type.

You can even pass an ordinal, i.e., the device index, or leave it out entirely, in which case PyTorch uses the currently available device.

# Example 1 - working 
torch.device('cuda:1')
device(type='cuda', index=1)

We choose the device type on which the tensor will be stored at runtime. Note that we specified the device type as cuda and appended the ordinal to the same string, separated by a ':'.

# Example 2 - working 
device = torch.device('cuda', 1)
device
device(type='cuda', index=1)

As Example 2 shows, the same result can be achieved by passing the device type and the index as separate arguments.

t1 = torch.tensor(2.0, device=torch.device('cpu'))

The device types accepted by torch.device() are cpu, cuda, mkldnn, opengl, opencl, ideep, hip, and msnpu. To use this method correctly, the device type must be one of this expected list.

Let's see what happens when we try to specify gpu as the device type.

# Example 3 - breaking
torch.device('gpu', 1)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-7b7286f1daee> in <module>
      1 # Example 3 - breaking
----> 2 torch.device('gpu', 1)


RuntimeError: Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu device type at start of device string: gpu


The specified device type must also be available on the machine running the notebook; otherwise you get an error like the one above.

In the code below, we define a tensor with the cuda device type, which requires an NVIDIA GPU. Since our machine has no GPU available, the kernel throws an error.

# Example 3 - breaking (to illustrate when it breaks)
torch.tensor(2.0, device=torch.device('cuda', 1))
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-8-1fabe5391fef> in <module>
      1 # Example 3 - breaking (to illustrate when it breaks)
----> 2 torch.tensor(2.0, device=torch.device('cuda', 1))


/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/cuda/__init__.py in _lazy_init()
    147             raise RuntimeError(
    148                 "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 149         _check_driver()
    150         if _cudart is None:
    151             raise AssertionError(


/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/cuda/__init__.py in _check_driver()
     52 Found no NVIDIA driver on your system. Please check that you
     53 have an NVIDIA GPU and installed a driver from
---> 54 http://www.nvidia.com/Download/index.aspx""")
     55         else:
     56             # TODO: directly link to the alternative bin that needs install


AssertionError: 
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx
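
In practice, a common way to avoid this class of error is to pick the device at runtime based on what is actually available. A minimal sketch:

import torch

# Fall back to the CPU when no CUDA-capable GPU is present
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t = torch.tensor(2.0, device=device)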

Function 2 — torch.view()

The torch.view() method changes the view of a tensor (whether vector, matrix, or scalar) to a desired shape. The reshaped tensor changes its dimensional representation but keeps the same underlying data and data type.

Let's look at an example:

We define a sample tensor x and reshape it into another tensor z.

# Example 1 - working
x = torch.randn(4, 5)
x
tensor([[-1.3324,  0.4927, -0.7445,  0.0132, -1.0947],
        [ 1.4211, -1.6057,  0.9154,  1.1111,  0.8435],
        [ 0.9480, -0.2812,  1.2160,  1.1731,  1.0137],
        [ 1.8901, -1.3294,  0.5607,  0.9897, -1.3568]])
z = x.view(1, 20)  # reshape the 4 x 5 matrix into a single row
z.size()
torch.Size([1, 20])

In the example above, we reshaped the 4 × 5 matrix into a single-row vector. Elements are read off row by row by default, so every element in the transformed view appears in the same order as in the original tensor.

In addition, the new shape must hold the same number of elements as the original tensor: you cannot store a 4 x 5 tensor in a 5 x 3 view.

# Example 2 - working
x = torch.randn(4, 8)
z = x.view(-1, 16)
z.size()
torch.Size([2, 16])

As the example above shows, we can pass -1 for one dimension, and PyTorch automatically infers the unknown dimension from the others.

# Example 3 - breaking
x.view(2, 8) # This will throw an error as the new shape can only store 2 x 8 = 16 elements.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-42-55503410d9d7> in <module>
      1 # Example 3 - breaking
----> 2 x.view(2, 8) # This will throw an error as the new shape can only store 2 x 8 = 16 elements.


RuntimeError: shape '[2, 8]' is invalid for input of size 32

Note that only one dimension can be inferred at a time: PyTorch allows at most one -1 per view() call, and passing -1 for more than one dimension raises a runtime error.
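
A related caveat: view() only works when the tensor's memory layout is contiguous. A minimal sketch of the usual workarounds, using a transpose as the non-contiguous case:

y = torch.randn(4, 8).t()      # logically 8 x 4, but non-contiguous in memory
# y.view(32)                   # would raise a RuntimeError here
z = y.contiguous().view(32)    # make a contiguous copy first, then view it
w = y.reshape(32)              # or let reshape() copy only when it has to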

Function 3 — torch.set_printoptions()

Often you'll want to print a tensor's contents while working on some task, and for that you may need to change how the tensor is displayed.

With set_printoptions() you can adjust attributes such as the precision level, the line width, and the summarization threshold of the result.

In our example we use a tensor representing a 20 x 30 matrix. Such a huge matrix rarely needs to be shown in full; a common use case when printing a tensor variable is to inspect just the first few and last few rows.

t2 = torch.rand(20, 30)
t2
tensor([[0.5125, 0.5751, 0.1321, 0.0023, 0.5471, 0.0895, 0.9011, 0.9573, 0.6723,
         0.8132, 0.7488, 0.1860, 0.5075, 0.3734, 0.6644, 0.3321, 0.2505, 0.2124,
         0.3036, 0.9942, 0.7319, 0.6374, 0.1483, 0.3731, 0.8582, 0.9378, 0.7472,
         0.9595, 0.9211, 0.3213],
        [0.2082, 0.2768, 0.7095, 0.3739, 0.6506, 0.6024, 0.6385, 0.1580, 0.8068,
         0.8255, 0.2024, 0.8948, 0.4900, 0.2746, 0.1933, 0.8375, 0.9401, 0.2070,
         0.2478, 0.3216, 0.9916, 0.8356, 0.4052, 0.6393, 0.6402, 0.7690, 0.2805,
         0.6679, 0.8710, 0.0959],
        [0.3106, 0.2850, 0.7273, 0.8001, 0.4156, 0.5507, 0.1308, 0.5954, 0.2465,
         0.5703, 0.1347, 0.8934, 0.5148, 0.5169, 0.8892, 0.8679, 0.9969, 0.0386,
         0.3858, 0.0766, 0.1805, 0.2141, 0.0704, 0.6004, 0.2985, 0.6005, 0.0040,
         0.5385, 0.8494, 0.2355],
        [0.8993, 0.7823, 0.8907, 0.8764, 0.9194, 0.0349, 0.3185, 0.0500, 0.6937,
         0.5656, 0.1844, 0.1586, 0.7471, 0.8469, 0.8463, 0.3646, 0.0494, 0.7467,
         0.4865, 0.6169, 0.9768, 0.5953, 0.7254, 0.6142, 0.5560, 0.5831, 0.4925,
         0.6164, 0.1118, 0.0587],
        [0.1739, 0.0149, 0.0556, 0.1247, 0.5763, 0.4129, 0.4825, 0.6093, 0.5256,
         0.3989, 0.5695, 0.5199, 0.7978, 0.6284, 0.7769, 0.1155, 0.9871, 0.8168,
         0.3798, 0.7590, 0.0056, 0.6959, 0.7343, 0.7611, 0.5858, 0.3682, 0.1216,
         0.8639, 0.6732, 0.4727],
        [0.2009, 0.6464, 0.3013, 0.3380, 0.2433, 0.2280, 0.3503, 0.4780, 0.7530,
         0.6716, 0.3355, 0.5648, 0.3473, 0.8290, 0.4116, 0.0442, 0.2717, 0.3790,
         0.1375, 0.4266, 0.3965, 0.0650, 0.1819, 0.0316, 0.4644, 0.4656, 0.7566,
         0.9178, 0.2597, 0.3103],
        [0.1869, 0.7015, 0.7595, 0.9563, 0.0684, 0.8513, 0.6640, 0.7727, 0.6334,
         0.7532, 0.8356, 0.5410, 0.2169, 0.8555, 0.4102, 0.4396, 0.8390, 0.1399,
         0.3835, 0.5007, 0.7783, 0.9780, 0.6950, 0.4501, 0.6164, 0.5525, 0.6139,
         0.7294, 0.2223, 0.6853],
        [0.7795, 0.7935, 0.7800, 0.6335, 0.4005, 0.5343, 0.0046, 0.7646, 0.0651,
         0.4529, 0.3207, 0.7586, 0.1579, 0.6287, 0.1613, 0.7730, 0.6849, 0.5714,
         0.9671, 0.0821, 0.7478, 0.9568, 0.8714, 0.3489, 0.1895, 0.5012, 0.0751,
         0.4864, 0.7313, 0.2190],
        [0.2558, 0.1372, 0.6245, 0.7994, 0.4376, 0.2629, 0.7305, 0.3682, 0.5189,
         0.3257, 0.9147, 0.1795, 0.9863, 0.5621, 0.7966, 0.6758, 0.0860, 0.7501,
         0.5412, 0.2903, 0.7496, 0.7252, 0.1888, 0.5759, 0.8577, 0.3962, 0.8956,
         0.9218, 0.5143, 0.8753],
        [0.1794, 0.0842, 0.6057, 0.9599, 0.8613, 0.0703, 0.3962, 0.6144, 0.0132,
         0.5081, 0.4494, 0.1555, 0.8192, 0.3031, 0.0482, 0.0165, 0.3774, 0.2113,
         0.0023, 0.5747, 0.8283, 0.6828, 0.0568, 0.7040, 0.4879, 0.4357, 0.9925,
         0.0231, 0.8678, 0.8258],
        [0.1071, 0.9391, 0.0361, 0.1474, 0.9148, 0.3629, 0.6864, 0.6754, 0.5313,
         0.4310, 0.9060, 0.9169, 0.6318, 0.7356, 0.4184, 0.9241, 0.5656, 0.2232,
         0.3448, 0.5907, 0.0555, 0.4838, 0.3880, 0.7759, 0.0140, 0.6659, 0.9488,
         0.7146, 0.4457, 0.7280],
        [0.0537, 0.2343, 0.3810, 0.1041, 0.2076, 0.3484, 0.5307, 0.5350, 0.1485,
         0.4263, 0.3602, 0.6709, 0.1884, 0.1323, 0.0592, 0.3663, 0.5269, 0.3464,
         0.7624, 0.2155, 0.2747, 0.8688, 0.2804, 0.1845, 0.8708, 0.4161, 0.1741,
         0.5426, 0.4196, 0.0278],
        [0.9674, 0.5630, 0.4510, 0.3879, 0.2646, 0.8017, 0.4210, 0.7576, 0.8178,
         0.2889, 0.4281, 0.0607, 0.6562, 0.6429, 0.7797, 0.3468, 0.0320, 0.6623,
         0.8383, 0.4069, 0.2485, 0.8856, 0.0212, 0.3830, 0.7633, 0.0271, 0.0627,
         0.0907, 0.5032, 0.6868],
        [0.9786, 0.8833, 0.2771, 0.8503, 0.9643, 0.2665, 0.9458, 0.1944, 0.4046,
         0.3345, 0.3423, 0.3892, 0.9186, 0.1303, 0.1543, 0.3074, 0.9204, 0.6381,
         0.3779, 0.9764, 0.2026, 0.4306, 0.0127, 0.8000, 0.1476, 0.4572, 0.8260,
         0.1681, 0.0999, 0.0838],
        [0.5961, 0.4434, 0.2229, 0.0074, 0.0565, 0.0156, 0.2243, 0.4342, 0.0322,
         0.3817, 0.8383, 0.9220, 0.5081, 0.6766, 0.3270, 0.7406, 0.1145, 0.1949,
         0.2362, 0.3956, 0.4517, 0.8267, 0.2308, 0.8147, 0.0995, 0.2056, 0.0133,
         0.8776, 0.2987, 0.0466],
        [0.6467, 0.2359, 0.2231, 0.7313, 0.1209, 0.9475, 0.9155, 0.7417, 0.3843,
         0.2120, 0.3383, 0.9352, 0.3552, 0.4050, 0.3245, 0.1852, 0.1851, 0.7660,
         0.5722, 0.9687, 0.7058, 0.7495, 0.8237, 0.9920, 0.0795, 0.5193, 0.9399,
         0.4980, 0.9298, 0.9310],
        [0.7099, 0.3879, 0.0630, 0.1548, 0.2988, 0.3489, 0.5426, 0.5429, 0.0053,
         0.7954, 0.7221, 0.6010, 0.6003, 0.3421, 0.3372, 0.7023, 0.3833, 0.0618,
         0.8759, 0.6212, 0.4012, 0.8726, 0.6640, 0.2590, 0.0107, 0.7525, 0.7647,
         0.4657, 0.1386, 0.9199],
        [0.7878, 0.1194, 0.2869, 0.7170, 0.8040, 0.7666, 0.4866, 0.5613, 0.0383,
         0.3685, 0.0986, 0.1444, 0.4574, 0.6214, 0.2698, 0.0079, 0.7405, 0.1205,
         0.9176, 0.4667, 0.2951, 0.6139, 0.9655, 0.1361, 0.8702, 0.5596, 0.9735,
         0.0497, 0.5684, 0.1042],
        [0.6720, 0.2498, 0.9661, 0.1693, 0.5408, 0.3751, 0.2529, 0.9149, 0.5986,
         0.7934, 0.6596, 0.1031, 0.0173, 0.1154, 0.7645, 0.3253, 0.9781, 0.6819,
         0.9567, 0.0724, 0.4812, 0.1672, 0.7238, 0.5022, 0.7053, 0.9059, 0.4651,
         0.8273, 0.1330, 0.6215],
        [0.2039, 0.0403, 0.3816, 0.4831, 0.9182, 0.2643, 0.9292, 0.5994, 0.6089,
         0.9758, 0.2332, 0.4003, 0.1500, 0.0101, 0.3218, 0.1062, 0.2415, 0.5920,
         0.4912, 0.5392, 0.6454, 0.6328, 0.7161, 0.8892, 0.0376, 0.3671, 0.6306,
         0.2864, 0.8676, 0.7922]])

We'll use the threshold, edgeitems, and linewidth attributes to change how the tensor is displayed to our liking. We can also change the number of digits shown after the decimal point with the precision attribute.

# Example 1 - working
torch.set_printoptions(precision=3, threshold=10, edgeitems=5, linewidth=100, sci_mode=False)


# Printing the tensor again!
t2
tensor([[0.513, 0.575, 0.132, 0.002, 0.547,  ..., 0.938, 0.747, 0.960, 0.921, 0.321],
        [0.208, 0.277, 0.709, 0.374, 0.651,  ..., 0.769, 0.280, 0.668, 0.871, 0.096],
        [0.311, 0.285, 0.727, 0.800, 0.416,  ..., 0.600, 0.004, 0.538, 0.849, 0.236],
        [0.899, 0.782, 0.891, 0.876, 0.919,  ..., 0.583, 0.492, 0.616, 0.112, 0.059],
        [0.174, 0.015, 0.056, 0.125, 0.576,  ..., 0.368, 0.122, 0.864, 0.673, 0.473],
        ...,
        [0.647, 0.236, 0.223, 0.731, 0.121,  ..., 0.519, 0.940, 0.498, 0.930, 0.931],
        [0.710, 0.388, 0.063, 0.155, 0.299,  ..., 0.753, 0.765, 0.466, 0.139, 0.920],
        [0.788, 0.119, 0.287, 0.717, 0.804,  ..., 0.560, 0.974, 0.050, 0.568, 0.104],
        [0.672, 0.250, 0.966, 0.169, 0.541,  ..., 0.906, 0.465, 0.827, 0.133, 0.622],
        [0.204, 0.040, 0.382, 0.483, 0.918,  ..., 0.367, 0.631, 0.286, 0.868, 0.792]])

Three profiles are available for this method: default, short, and full. Passing a profile name together with explicit attributes is of limited use: the explicitly passed attributes override the corresponding settings from the profile, as the next example shows.

# Example 3 - breaking (to illustrate when it breaks)
torch.set_printoptions(precision=5, threshold=4, edgeitems=2, linewidth=50, profile="default", sci_mode=True)
t2
tensor([[5.12511e-01, 5.75067e-01,  ..., 9.21133e-01,
         3.21350e-01],
        [2.08170e-01, 2.76800e-01,  ..., 8.70968e-01,
         9.58731e-02],
        ...,
        [6.72050e-01, 2.49771e-01,  ..., 1.33016e-01,
         6.21548e-01],
        [2.03876e-01, 4.02817e-02,  ..., 8.67636e-01,
         7.92209e-01]])
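
To undo such changes, passing only a profile name restores that preset. A quick sketch:

# Restore the default display settings
torch.set_printoptions(profile="default")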

Function 4 — Tensor.backward()

Tensors simplify many common machine learning tasks. To perform gradient descent, the popular loss-minimization technique, we need to compute the gradients (recall: derivatives) of the loss function.

PyTorch simplifies this with the backward() method. Note that PyTorch only computes gradients for a tensor whose requires_grad attribute is set to True.

# Example 1 - working
m = torch.tensor(1.)
x = torch.tensor(3., requires_grad=True)
c = torch.tensor(5., requires_grad=True)
y = m * x + c  # the linear equation whose gradients we want
y.backward()   # populates .grad on every tensor with requires_grad=True

We use the linear equation y = mx + c and take the partial derivative of y with respect to each variable in the equation.

print(x.grad)
tensor(1.)

After calling the y.backward() method, we can access the .grad attribute of the tensor x and print the computed gradient.

print(m.grad)
None

Because we did not set the requires_grad option to True for m, accessing m's .grad attribute yields no result (None).

print(c.grad)
tensor(1.)

Note that calling y.backward() a second time does not compute a second derivative: the computation graph is freed after the first backward pass, so another call raises a runtime error unless you pass retain_graph=True, and gradients from repeated calls accumulate in .grad.
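
If you genuinely need to call backward() more than once on the same graph, pass retain_graph=True on the earlier calls. A minimal sketch showing how the gradients then accumulate in .grad:

x = torch.tensor(3., requires_grad=True)
y = 2 * x
y.backward(retain_graph=True)  # keep the graph alive for a second pass
y.backward()                   # the new gradient is added to the existing one
print(x.grad)                  # tensor(4.) -- 2.0 accumulated twice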

Function 5 — torch.linspace()

The linspace() method returns a one-dimensional tensor over a given range. Unlike rand(), which generates random numbers, the numbers returned by linspace() form an arithmetic sequence.

The difference between consecutive members is determined by the steps attribute together with the range (end − start); steps is the number of points, so the common difference is (end − start) / (steps − 1).

# Example 1 - working
i = torch.linspace(start = 1, end = 10, steps=50, dtype=int) 
i
tensor([ 1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  4,
         4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,  7,  7,
         7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9, 10])

The output tensor contains 50 evenly spaced numbers in the range 1–10. Since the dtype attribute is int, no decimal places are stored; the values are truncated.
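
For comparison, linspace() defaults to a floating-point dtype, which preserves the fractional spacing. A quick sketch:

torch.linspace(1, 10, steps=5)  # tensor([ 1.0000,  3.2500,  5.5000,  7.7500, 10.0000]) with default print options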

# Example 2 - working
# Alternative way to use linspace(), passing the out parameter.
torch.linspace(start = 1, end = 25, steps=50, out=i)
tensor([ 1,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
         9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
        18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25])

Note that when creating a tensor with linspace(), the dtype value must match the dtype defined for the out tensor.

# Example 3 - breaking (to illustrate when it breaks)
torch.linspace(start = 1, end = 25, steps=50, out=i, dtype=float)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-89030ac639ec> in <module>
      1 # Example 3 - breaking (to illustrate when it breaks)
----> 2 torch.linspace(start = 1, end = 25, steps=50, out=i, dtype=float)


RuntimeError: dtype Double does not match dtype of out parameter (Long)

The error occurs because the dtype attribute does not match the dtype of the out tensor.
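
The fix is to make the dtype consistent with the out tensor, or simply to omit it. A minimal sketch, reusing the integer tensor i defined above:

# i was created with dtype=int (i.e. torch.int64), so match it explicitly
torch.linspace(start=1, end=25, steps=50, out=i, dtype=torch.int64)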

Conclusion

This article covered some of the basic methods available in the PyTorch API to help you get started with it. Since much of the implementation is "borrowed" from the NumPy library, the API is easy to pick up, and Python developers can build on their existing understanding and experience.

·  END  ·
