# @创建于 (created): 2021.10.15
# 原文介绍参见 (original tutorial): thorough-pytorch/第二章 PyTorch基础知识/
# -*- coding:UTF-8 -*-
# datetime: 2021/10/12 16:48
# software: PyCharm
"""
文件说明:
学习张量知识
https://github.com/datawhalechina/thorough-pytorch
"""
from __future__ import print_function
import torch
def test1():
    """Print three basic tensor constructions: random, zero-filled long, and from a list."""
    demos = (
        # Random 4x3 float matrix.
        ("torch random:\n", lambda: torch.rand(4, 3)),
        # 4x3 matrix of zeros with dtype torch.long.
        ("torch zeros with long type:\n", lambda: torch.zeros(4, 3, dtype=torch.long)),
        # 1-D tensor built directly from Python list data.
        ("torch tensor with list:\n", lambda: torch.tensor([5, 5, 3])),
    )
    # Each tensor is built lazily, right before it is printed, in the original order.
    for label, build in demos:
        print(label, build())
def test2():
    """Show basic tensor attributes (data, size, rank, element count, dtype) and two simple ops."""
    mat = torch.tensor([[0, 1, 2], [3, 4, 5]])
    print(mat)
    attributes = (
        ('数据是', mat),             # data
        ('大小:', mat.size()),       # size
        ('维度:', mat.dim()),        # number of dimensions
        ('元素个数:', mat.numel()),  # number of elements
        ('元素类型:', mat.dtype),    # element dtype
    )
    for prefix, value in attributes:
        print(f'{prefix}{value}')
    # Same six elements reorganised into 3 rows x 2 columns.
    print(mat.reshape(3, 2))
    # Scalar added element-wise.
    print(mat + 1)
def test3():
    """Derive new tensors from an existing one (new_ones, randn_like) and inspect their shape."""
    base = torch.rand(4, 3)
    print("torch random:", base)
    # new_ones builds a fresh 4x3 tensor of ones from `base`, with dtype overridden to double.
    ones = base.new_ones(4, 3, dtype=torch.double)
    print("torch new_ones:", ones)
    # randn_like copies the shape of `ones`; dtype is overridden to float16.
    noise = torch.randn_like(ones, dtype=torch.float16)
    print("torch randn_like:", noise)
    for label, value in (("size:", noise.size()), ("shape:", noise.shape), ("dim:", noise.dim())):
        print(label, value)
def test4():
    """Demonstrate the spellings of tensor addition: operator, torch.add, out= and in-place add_."""
    lhs = torch.rand(size=(4, 3), dtype=torch.float16)
    print("x:\n", lhs)
    rhs = torch.rand(4, 3, dtype=torch.float16)
    print("y:\n", rhs)
    print("x+y\n", lhs + rhs)
    print("add(x,y)\n", torch.add(input=lhs, other=rhs))
    # NOTE: `result` is created with the default dtype, not float16 like the inputs;
    # torch.add writes the sum into it.
    result = torch.empty(size=(4, 3))
    torch.add(input=lhs, other=rhs, out=result)
    print("add(x,y,out=z):\n", result)
    # Trailing underscore = in-place: rhs is overwritten with rhs + lhs.
    rhs.add_(other=lhs)
    print("y_add\n", rhs)
def test5():
    """Show that basic indexing returns a view: mutating a row slice also mutates the source."""
    src = torch.rand(size=(4, 3), dtype=torch.float16)
    print("1-- x:\n", src)
    print("x[:, 0] =", src[:, 0])  # first column
    row = src[1, :]  # second row; a view sharing storage with `src`
    print("1-- y = ", row)
    row += 1  # in-place on the view, so src[1] changes as well
    print("2-- y = ", row)
    print("2-- x:\n", src)
def test6():
    """Contrast view (shared storage) with clone (independent copy), then read a Python scalar."""
    grid = torch.randn(4, 4)
    print("x:\n", grid)
    flat = grid.view(16)
    halves = grid.view(-1, 8)  # -1: torch infers the leading dimension
    print(grid.size(), flat.size(), halves.size())
    grid += 10  # in-place on `grid`; visible through its view `halves`
    print("z:\n", halves)
    copy = grid.clone()  # independent storage, so `grid` is unaffected below
    copy += 10
    print("t:\n", copy)
    print("t[0,0].item():\n", copy[0, 0].item())
def test7():
    """Broadcasting: a (1, 2) row plus a (3, 1) column produces a (3, 2) result."""
    row = torch.arange(1, 3).view(1, 2)
    col = torch.arange(1, 4).view(3, 1)
    total = row + col
    print("x:\n", row)
    print("y:\n", col)
    print(f"z.size: {total.size()}, z:\n{total}")
def test10():
    """Inspect autograd metadata: grad_fn on derived tensors and enabling requires_grad in place."""
    x = torch.ones(size=(2, 2), requires_grad=True)
    y = x ** 2
    for label, value in (("x\n", x), ("y\n", y), ("y.grad_fn\n", y.grad_fn)):
        print(label, value)
    z = y * y * 3
    print("z\n", z)
    mean_z = z.mean()
    print("out\n", mean_z)
    # Built from a tensor created without requires_grad, so nothing is tracked yet.
    plain = torch.randn(size=(2, 2))
    plain = (plain * 3) / (plain - 1)
    print("1 - a.requires_grad = ", plain.requires_grad)
    plain.requires_grad_(True)  # flip the flag in place; later ops on `plain` are recorded
    print("2 - a.requires_grad = ", plain.requires_grad)
    total = (plain * plain).sum()
    print("b.requires_grad = ", total.requires_grad)
    print("b.grad_fn = ", total.grad_fn)
def test11():
    """Walk through backward(): leaf vs non-leaf gradients and gradient accumulation.

    Graph: x (leaf, all ones) -> y = x**2 -> z = 3*y*y -> out = z.sum().
    At x = 1 this gives d(out)/dx = 12*x**3 = 12 and d(out)/dy = 6*y = 6.

    ``y.retain_grad()`` is requested *before* ``backward()``. A non-leaf tensor
    keeps its ``.grad`` only if asked in advance; calling it after the backward
    pass (as an earlier version did) left ``y.grad`` as ``None``.
    """
    x = torch.ones(2, 2, requires_grad=True)
    y = x ** 2
    print("1- out y.is_leaf =", y.is_leaf)
    y.retain_grad()  # must precede backward() for y.grad to be populated
    print("2- out y.is_leaf =", y.is_leaf)  # retain_grad() does not turn y into a leaf
    z = y * y * 3
    out = z.sum()
    # backward() returns None; gradients are stored in the .grad attributes.
    print("out.backward() =", out.backward())
    print("out x.is_leaf =", x.is_leaf)
    print("out x.grad =", x.grad)
    print("out y.grad =", y.grad)

    # Gradients accumulate across backward() calls: d(x.sum())/dx = 1 is added
    # to the 12 already stored, giving 13.
    out2 = x.sum()
    out2.backward()
    print("out2 x.grad", x.grad)

    # Reset the accumulated gradient first, so x.grad holds only this pass (1).
    out3 = x.sum()
    x.grad.zero_()
    out3.backward()
    print("out3 x.grad", x.grad)
def test12():
    """Vector-Jacobian product: backward() on a non-scalar output takes an explicit vector."""
    x = torch.randn(3, requires_grad=True)
    print("x:\n", x)
    y = x * 2
    doublings = 0
    # Keep doubling until the L2 norm of y's values reaches 1000 (norm read via .data).
    while y.data.norm() < 1000:
        y = y * 2
        doublings += 1
    print("y:\n", y)
    print("i =", doublings)
    print(type(y))
    grad_vec = torch.tensor([0.1, 1, 0.0001], dtype=torch.float)
    # y is non-scalar, so backward() needs the vector v for the v^T * J product.
    y.backward(grad_vec)
    print("x.grad =", x.grad)
    print("x.requires_grad =", x.requires_grad)
    tracked = (x ** 2).requires_grad
    print(tracked)
    print((x ** 2).requires_grad)
    with torch.no_grad():
        # Operations inside no_grad() are not recorded, so this prints False.
        print((x ** 2).requires_grad)
def test13():
    """Show that writes through .data bypass autograd: y keeps old values, x.grad is unaffected."""
    x = torch.ones(2, 3, requires_grad=True)
    for label, value in (
        ("type(x) =", type(x)),
        ("type(x.data) =", type(x.data)),
        ("x =", x),
        ("x.data =", x.data),
        ("x.requires_grad =", x.requires_grad),
        ("x.data.requires_grad =", x.data.requires_grad),
    ):
        print(label, value)
    y = 2 * x
    # Overwrite the last column through .data; autograd does not record this write.
    x.data[:, 2] = torch.tensor([2, 4])
    print("x =", x)
    print("x.data =", x.data)
    print("y =", y)  # computed before the write, so still all 2s
    print("y.data =", y.data)
    y.sum().backward()
    print("x =", x)
    print("x.grad =", x.grad)  # d(sum(2x))/dx = 2 everywhere
if __name__ == "__main__":
    # Only one demo runs at a time. Point `selected` at any of
    # test1..test7 or test10..test12 to run a different one.
    selected = test13
    selected()