欢迎关注 “小白玩转Python”,发现更多 “有趣”
Function 1 — torch.device()
PyTorch 是 Facebook 开发的一个开源库,在数据科学家中非常流行。其崛起的主要原因之一是 GPU 对开发者的内置支持。
torch.device 允许您指定负责将张量加载到内存中的设备类型,该函数需要一个指定设备类型的字符串参数。
你甚至可以传递一个序号,比如设备索引。或者不指定 PyTorch 使用当前可用的设备。
# Example 1 - working
torch.device('cuda:1')
device(type='cuda', index=1)
我们选择要在运行时存储张量的设备类型。注意,我们已经指定了我们的设备类型为 cuda,并且将序号连接到由‘ :’分隔的相同字符串中。
# Example 2 - working
device = torch.device('cuda', 1)
device
device(type='cuda', index=1)
代码 2 中可以看出:通过为设备类型和索引传入单独的参数来实现相同的结果。
t1 = torch.tensor(2.0, device=torch.device('cpu'))
torch.device() 中预期的设备类型是 cpu、 cuda、 mkldnn、 opengl、 opencl、 ideep、 hip 和 msnpu。为了正确使用此方法,设备类型应该存在于预期设备列表中。
让我们看看当我们尝试将 GPU 指定为设备类型时会发生什么。
# Example 3 - breaking
torch.device('gpu', 1)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-7b7286f1daee> in <module>
1 # Example 3 - breaking
----> 2 torch.device('gpu', 1)
RuntimeError: Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu device type at start of device string: gpu
指定的设备类型应该在运行笔记本的机器上可用。如果不这样做,将导致类似上述错误。
在下列代码中,我们使用一个需要 NVIDIA GPU 的 cuda 设备类型定义了一个张量。由于我们的机器没有任何可用的 GPU,内核抛出了一个运行时错误。
# Example 3 - breaking (to illustrate when it breaks)
torch.tensor(2.0, device=torch.device('cuda', 1))
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-8-1fabe5391fef> in <module>
1 # Example 3 - breaking (to illustrate when it breaks)
----> 2 torch.tensor(2.0, device=torch.device('cuda', 1))
/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/cuda/__init__.py in _lazy_init()
147 raise RuntimeError(
148 "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 149 _check_driver()
150 if _cudart is None:
151 raise AssertionError(
/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/cuda/__init__.py in _check_driver()
52 Found no NVIDIA driver on your system. Please check that you
53 have an NVIDIA GPU and installed a driver from
---> 54 http://www.nvidia.com/Download/index.aspx""")
55 else:
56 # TODO: directly link to the alternative bin that needs install
AssertionError:
Found no NVIDIA driver on your system. Please check that you
have an NVIDIA GPU and installed a driver from
http://www.nvidia.com/Download/index.aspx
Function 2 — torch.view()
torch.view() 方法将张量(无论是矢量、矩阵还是标量)的视图更改为所需的形状。转换后的张量修改了表示维数,但保留了相同的数据类型。
让我们来看一个例子:
定义一个样本张量 x,并将其维数转换为另一个张量 z。
# Example 1 - working
x = torch.randn(4, 5)
x
tensor([[-1.3324, 0.4927, -0.7445, 0.0132, -1.0947],
[ 1.4211, -1.6057, 0.9154, 1.1111, 0.8435],
[ 0.9480, -0.2812, 1.2160, 1.1731, 1.0137],
[ 1.8901, -1.3294, 0.5607, 0.9897, -1.3568]])
在上述示例中,我们将4 × 5矩阵的维数转换为单行向量。在默认情况下连续打印行的元素。变换后的视图中的每个行元素都以原张量中的相同顺序出现。
此外,新张量的形状必须支持原张量中相同数目的元素。你不能在一个5 x 3的视图中存储一个4 x 5形状的张量。
# Example 2 - working
x = torch.randn(4, 8)
z = x.view(-1, 16)
z.size()
torch.Size([2, 16])
如上述示例,我们将使用 -1 来表示一个维度。PyTorch 自动从其他维度解释未知维度。
# Example 3 - breaking
x.view(2, 8) # This will throw an error as the new shape will only stor 2 x 8 = 16 elements.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-42-55503410d9d7> in <module>
1 # Example 3 - breaking
----> 2 x.view(2, 8) # This will throw an error as the new shape will only stor 2 x 8 = 16 elements.
RuntimeError: shape '[2, 8]' is invalid for input of size 32
请注意,一次只能推断出一个维度。对多个维度使用 -1只会导致运行时错误!
PyTorch 最多只允许传入一个 -1。如果传递多个 -1,它将抛出一个运行时错误。
Function 3 — torch.set_printoptions()
很多时候,您希望在执行某些任务之前打印张量的内容。为此,在打印张量时可能需要更改显示表示形式。
使用 set_printoptions,您可以调整精度级别、线宽、结果阈值等属性。
在我们的例子中,我们用一个表示20 x 30矩阵的张量,这样一个巨大的矩阵是不经常需要的。打印张量变量后面的一个通用用例是查看前几行和后几行。
t2 = torch.rand(20, 30)
t2
tensor([[0.5125, 0.5751, 0.1321, 0.0023, 0.5471, 0.0895, 0.9011, 0.9573, 0.6723,
0.8132, 0.7488, 0.1860, 0.5075, 0.3734, 0.6644, 0.3321, 0.2505, 0.2124,
0.3036, 0.9942, 0.7319, 0.6374, 0.1483, 0.3731, 0.8582, 0.9378, 0.7472,
0.9595, 0.9211, 0.3213],
[0.2082, 0.2768, 0.7095, 0.3739, 0.6506, 0.6024, 0.6385, 0.1580, 0.8068,
0.8255, 0.2024, 0.8948, 0.4900, 0.2746, 0.1933, 0.8375, 0.9401, 0.2070,
0.2478, 0.3216, 0.9916, 0.8356, 0.4052, 0.6393, 0.6402, 0.7690, 0.2805,
0.6679, 0.8710, 0.0959],
[0.3106, 0.2850, 0.7273, 0.8001, 0.4156, 0.5507, 0.1308, 0.5954, 0.2465,
0.5703, 0.1347, 0.8934, 0.5148, 0.5169, 0.8892, 0.8679, 0.9969, 0.0386,
0.3858, 0.0766, 0.1805, 0.2141, 0.0704, 0.6004, 0.2985, 0.6005, 0.0040,
0.5385, 0.8494, 0.2355],
[0.8993, 0.7823, 0.8907, 0.8764, 0.9194, 0.0349, 0.3185, 0.0500, 0.6937,
0.5656, 0.1844, 0.1586, 0.7471, 0.8469, 0.8463, 0.3646, 0.0494, 0.7467,
0.4865, 0.6169, 0.9768, 0.5953, 0.7254, 0.6142, 0.5560, 0.5831, 0.4925,
0.6164, 0.1118, 0.0587],
[0.1739, 0.0149, 0.0556, 0.1247, 0.5763, 0.4129, 0.4825, 0.6093, 0.5256,
0.3989, 0.5695, 0.5199, 0.7978, 0.6284, 0.7769, 0.1155, 0.9871, 0.8168,
0.3798, 0.7590, 0.0056, 0.6959, 0.7343, 0.7611, 0.5858, 0.3682, 0.1216,
0.8639, 0.6732, 0.4727],
[0.2009, 0.6464, 0.3013, 0.3380, 0.2433, 0.2280, 0.3503, 0.4780, 0.7530,
0.6716, 0.3355, 0.5648, 0.3473, 0.8290, 0.4116, 0.0442, 0.2717, 0.3790,
0.1375, 0.4266, 0.3965, 0.0650, 0.1819, 0.0316, 0.4644, 0.4656, 0.7566,
0.9178, 0.2597, 0.3103],
[0.1869, 0.7015, 0.7595, 0.9563, 0.0684, 0.8513, 0.6640, 0.7727, 0.6334,
0.7532, 0.8356, 0.5410, 0.2169, 0.8555, 0.4102, 0.4396, 0.8390, 0.1399,
0.3835, 0.5007, 0.7783, 0.9780, 0.6950, 0.4501, 0.6164, 0.5525, 0.6139,
0.7294, 0.2223, 0.6853],
[0.7795, 0.7935, 0.7800, 0.6335, 0.4005, 0.5343, 0.0046, 0.7646, 0.0651,
0.4529, 0.3207, 0.7586, 0.1579, 0.6287, 0.1613, 0.7730, 0.6849, 0.5714,
0.9671, 0.0821, 0.7478, 0.9568, 0.8714, 0.3489, 0.1895, 0.5012, 0.0751,
0.4864, 0.7313, 0.2190],
[0.2558, 0.1372, 0.6245, 0.7994, 0.4376, 0.2629, 0.7305, 0.3682, 0.5189,
0.3257, 0.9147, 0.1795, 0.9863, 0.5621, 0.7966, 0.6758, 0.0860, 0.7501,
0.5412, 0.2903, 0.7496, 0.7252, 0.1888, 0.5759, 0.8577, 0.3962, 0.8956,
0.9218, 0.5143, 0.8753],
[0.1794, 0.0842, 0.6057, 0.9599, 0.8613, 0.0703, 0.3962, 0.6144, 0.0132,
0.5081, 0.4494, 0.1555, 0.8192, 0.3031, 0.0482, 0.0165, 0.3774, 0.2113,
0.0023, 0.5747, 0.8283, 0.6828, 0.0568, 0.7040, 0.4879, 0.4357, 0.9925,
0.0231, 0.8678, 0.8258],
[0.1071, 0.9391, 0.0361, 0.1474, 0.9148, 0.3629, 0.6864, 0.6754, 0.5313,
0.4310, 0.9060, 0.9169, 0.6318, 0.7356, 0.4184, 0.9241, 0.5656, 0.2232,
0.3448, 0.5907, 0.0555, 0.4838, 0.3880, 0.7759, 0.0140, 0.6659, 0.9488,
0.7146, 0.4457, 0.7280],
[0.0537, 0.2343, 0.3810, 0.1041, 0.2076, 0.3484, 0.5307, 0.5350, 0.1485,
0.4263, 0.3602, 0.6709, 0.1884, 0.1323, 0.0592, 0.3663, 0.5269, 0.3464,
0.7624, 0.2155, 0.2747, 0.8688, 0.2804, 0.1845, 0.8708, 0.4161, 0.1741,
0.5426, 0.4196, 0.0278],
[0.9674, 0.5630, 0.4510, 0.3879, 0.2646, 0.8017, 0.4210, 0.7576, 0.8178,
0.2889, 0.4281, 0.0607, 0.6562, 0.6429, 0.7797, 0.3468, 0.0320, 0.6623,
0.8383, 0.4069, 0.2485, 0.8856, 0.0212, 0.3830, 0.7633, 0.0271, 0.0627,
0.0907, 0.5032, 0.6868],
[0.9786, 0.8833, 0.2771, 0.8503, 0.9643, 0.2665, 0.9458, 0.1944, 0.4046,
0.3345, 0.3423, 0.3892, 0.9186, 0.1303, 0.1543, 0.3074, 0.9204, 0.6381,
0.3779, 0.9764, 0.2026, 0.4306, 0.0127, 0.8000, 0.1476, 0.4572, 0.8260,
0.1681, 0.0999, 0.0838],
[0.5961, 0.4434, 0.2229, 0.0074, 0.0565, 0.0156, 0.2243, 0.4342, 0.0322,
0.3817, 0.8383, 0.9220, 0.5081, 0.6766, 0.3270, 0.7406, 0.1145, 0.1949,
0.2362, 0.3956, 0.4517, 0.8267, 0.2308, 0.8147, 0.0995, 0.2056, 0.0133,
0.8776, 0.2987, 0.0466],
[0.6467, 0.2359, 0.2231, 0.7313, 0.1209, 0.9475, 0.9155, 0.7417, 0.3843,
0.2120, 0.3383, 0.9352, 0.3552, 0.4050, 0.3245, 0.1852, 0.1851, 0.7660,
0.5722, 0.9687, 0.7058, 0.7495, 0.8237, 0.9920, 0.0795, 0.5193, 0.9399,
0.4980, 0.9298, 0.9310],
[0.7099, 0.3879, 0.0630, 0.1548, 0.2988, 0.3489, 0.5426, 0.5429, 0.0053,
0.7954, 0.7221, 0.6010, 0.6003, 0.3421, 0.3372, 0.7023, 0.3833, 0.0618,
0.8759, 0.6212, 0.4012, 0.8726, 0.6640, 0.2590, 0.0107, 0.7525, 0.7647,
0.4657, 0.1386, 0.9199],
[0.7878, 0.1194, 0.2869, 0.7170, 0.8040, 0.7666, 0.4866, 0.5613, 0.0383,
0.3685, 0.0986, 0.1444, 0.4574, 0.6214, 0.2698, 0.0079, 0.7405, 0.1205,
0.9176, 0.4667, 0.2951, 0.6139, 0.9655, 0.1361, 0.8702, 0.5596, 0.9735,
0.0497, 0.5684, 0.1042],
[0.6720, 0.2498, 0.9661, 0.1693, 0.5408, 0.3751, 0.2529, 0.9149, 0.5986,
0.7934, 0.6596, 0.1031, 0.0173, 0.1154, 0.7645, 0.3253, 0.9781, 0.6819,
0.9567, 0.0724, 0.4812, 0.1672, 0.7238, 0.5022, 0.7053, 0.9059, 0.4651,
0.8273, 0.1330, 0.6215],
[0.2039, 0.0403, 0.3816, 0.4831, 0.9182, 0.2643, 0.9292, 0.5994, 0.6089,
0.9758, 0.2332, 0.4003, 0.1500, 0.0101, 0.3218, 0.1062, 0.2415, 0.5920,
0.4912, 0.5392, 0.6454, 0.6328, 0.7161, 0.8892, 0.0376, 0.3671, 0.6306,
0.2864, 0.8676, 0.7922]])
我们将利用阈值、边条和线宽属性来根据我们的喜好改变张量的表示方式。我们还可以使用精度属性更改小数点后显示的位数。
# Example 1 - working
torch.set_printoptions(precision=3, threshold=10, edgeitems=5, linewidth=100, sci_mode=False)
# Printing the tensor again!
t2
tensor([[0.513, 0.575, 0.132, 0.002, 0.547, ..., 0.938, 0.747, 0.960, 0.921, 0.321],
[0.208, 0.277, 0.709, 0.374, 0.651, ..., 0.769, 0.280, 0.668, 0.871, 0.096],
[0.311, 0.285, 0.727, 0.800, 0.416, ..., 0.600, 0.004, 0.538, 0.849, 0.236],
[0.899, 0.782, 0.891, 0.876, 0.919, ..., 0.583, 0.492, 0.616, 0.112, 0.059],
[0.174, 0.015, 0.056, 0.125, 0.576, ..., 0.368, 0.122, 0.864, 0.673, 0.473],
...,
[0.647, 0.236, 0.223, 0.731, 0.121, ..., 0.519, 0.940, 0.498, 0.930, 0.931],
[0.710, 0.388, 0.063, 0.155, 0.299, ..., 0.753, 0.765, 0.466, 0.139, 0.920],
[0.788, 0.119, 0.287, 0.717, 0.804, ..., 0.560, 0.974, 0.050, 0.568, 0.104],
[0.672, 0.250, 0.966, 0.169, 0.541, ..., 0.906, 0.465, 0.827, 0.133, 0.622],
[0.204, 0.040, 0.382, 0.483, 0.918, ..., 0.367, 0.631, 0.286, 0.868, 0.792]])
在这个方法中有3个配置文件可供我们使用:default,tiny 和 full。将配置文件名与其他属性一起传递是不正确的用法。在这种情况下,该函数忽略文件属性。
# Example 3 - breaking (to illustrate when it breaks)
torch.set_printoptions(precision=5, threshold=4, edgeitems=2, linewidth=50, profile="default", sci_mode=True)
t2
tensor([[5.12511e-01, 5.75067e-01, ..., 9.21133e-01,
3.21350e-01],
[2.08170e-01, 2.76800e-01, ..., 8.70968e-01,
9.58731e-02],
...,
[6.72050e-01, 2.49771e-01, ..., 1.33016e-01,
6.21548e-01],
[2.03876e-01, 4.02817e-02, ..., 8.67636e-01,
7.92209e-01]])
Function 4 — Tensor.backward()
张量用于简化机器学习中所需的常见任务。为了执行一种流行的损失最小化技术---- 梯度下降法,我们需要计算损失函数的梯度(recall - derivates)。
PyTorch 通过使用 backward()方法来简化这个过程。注意: PyTorch 只有在一个张量的 require_grad 属性设置为 True 时才会计算它的梯度。
# Example 1 - working
m = torch.tensor(1.)
x = torch.tensor(3., requires_grad=True)
c = torch.tensor(5., requires_grad=True)
我们将用线性方程 y = mx + c 来求方程中每个变量 y 的偏导数。
print(x.grad)
tensor(1.)
在调用 y.backward() 方法并打印计算梯度之后,我们可以访问张量 x 的.grad 属性。
print(m.grad)
None
因为我们没有将 require_grad 选项设置为 True,所以调用 m 的 .grad 属性时不会得到结果。
print(c.grad)
tensor(1.)
再次调用 y.backward() 将导致 y 的张量的二阶微分。
Function 5 — torch.linspace()
linspace() 方法返回一个在设定范围内的一维张量。与随机生成数字的 rand()函数不同,返回的数字是 linspace() 中的等差数列序列的成员。
每个成员之间的差异由 steps 属性和范围(end ー start)指定。
# Example 1 - working
i = torch.linspace(start = 1, end = 10, steps=50, dtype=int)
i
tensor([ 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4,
4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7,
7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 10])
输出张量包含50个等距数字,范围为1-10。dtype 属性是 int 类型,因此不存储小数位数。
# Example 2 - working
# Alternative way to use the linspace() with the out property.
torch.linspace(start = 1, end = 25, steps=50, out=i)
tensor([ 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,
9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25])
注意,在使用 linspace() 方法创建张量时,dtype 值必须与输出张量定义的 dtype 相一致。
# Example 3 - breaking (to illustrate when it breaks)
torch.linspace(start = 1, end = 25, steps=50, out=i, dtype=float)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-37-89030ac639ec> in <module>
1 # Example 3 - breaking (to illustrate when it breaks)
----> 2 torch.linspace(start = 1, end = 25, steps=50, out=i, dtype=float)
RuntimeError: dtype Double does not match dtype of out parameter (Long)
dtype属性不匹配
拓展
本文介绍了 PyTorch API 中可用的一些基本方法,以帮助您开始使用它。由于大部分实现都是从 NumPy 库中“借”用来的,因此便于 Python 开发人员的现有理解和经验基础上进行使用,所以这个 API 很容易入门。
· END ·
HAPPY LIFE