import torch # 包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作。另外,它也提供了多种工具,其中一些可以更有效地对张量和任意类型进行序列化
import torchvision # torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
# ================================================================== #
# 1. Basic autograd example 1 #
# ================================================================== #
# Create tensors
x = torch.tensor(1., requires_grad = True) # 有一点因为只有float型才能求梯度
w = torch.tensor(2., requires_grad = True)
b = torch.tensor(3., requires_grad = True)
# Build a computational graph.
y = w * x + b
# Compute gradients.
y.backward()
# Print out the gradients.
print(x.grad) # x.grad = 2
print(w.grad) # w.grad = 1
print(b.grad)
# ================================================================== #
# 2. Basic autograd example 2 #
# ================
pytorch_basics
最新推荐文章于 2024-05-06 08:04:02 发布