主要参考:Link
1. 导入torch并查看其版本
import torch
print(torch.__version__)
2. 随机种子
def set_up(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def train():
set_up(2021) //这里的2021可为任意整数
// ...
3. 查看张量的基本信息
调试代码时最多的就是查看张量的形状和维度信息
tensor = torch.zeros(2, 3, 4)
print(tensor.size()) # 查看张量的形状
print(tensor.dim()) # 查看张量的维度
4. torch与numpy转换
一般将ndarray转换为tensor比较多,因为将tensor转换到ndarray之后,运算会在cpu上运行,会大大降低运行速度。
tensor = torch.zeros(2, 3, 4)
np = tensor.cpu().numpy()
tensor = torch.from_numpy(np).float()
5. numpy数组转换为图像
常用于可视化,或加载图片数据
iamge = PIL.Image.fromarray(ndarray.astype(np.unit8)) //numpy数组转Image图像
ndarray = np.asarray(PIL.Image.open(path)) //Image图像转numpy数组
6. 张量拼接
cat:在给定维度上对输入的张量序列seq进行连接操作,所有的tensors必须为相同的shape或者为空。
x = torch.randn(2, 3)
torch.cat((x, x, x), 0)
torch.cat((x, x, x), 1)
stack:沿着新的维度拼接一个序列的tensors
torch.stack(tensors, dim=0, *, out=None) → Tensor
7. 展开张量
通过将输入重塑为一维张量来展平输入。如果传递了 start_dim 或 end_dim,则只有以 start_dim 开头并以 end_dim 结尾的尺寸被展平。输入中元素的顺序不变。
torch.flatten(input, start_dim=0, end_dim=-1)
t = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
torch.flatten(t)
# tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.flatten(t, start_dim=1)
# tensor([[1, 2, 3, 4],
# [5, 6, 7, 8]])
8.矩阵乘法
torch = torch.mm(mat1, mat2)
9. 模型定义
class ConvNet(nn.MOdule):
def __init__(ConvNet, num_classes=10):
super(ConvNet, self).__init__()
...
def forward(self, x):
...
model = ConvNet(number_classes).to(device) // 调用该类时即刻自动调用forward