这个是相当于笔记的性质,是为了记录一些pytorch的基本使用。
高强度CV,但是基本代码都是经过自己的试运行的。
OneHot
这个是可以通过给定int类型的一维Tensor,然后作为索引,从指定的矩阵中取值。可以用于生成onehot label。
import torch
ones = torch.eye(10)
print(ones)
label_index = torch.Tensor([1,2,3,4]).long()
onehot = ones.index_select(0,label_index)
print(onehot)
torch.max(Tensor,dim)
输入的第一个是Tensor,第二个就是取最大值的维度。
T = torch.Tensor([[1,2,3],[3,3,9],[88,48,0]])
value,index = torch.max(T,1)
print(value)
print(index)
矩阵乘法
torch.mm(A,B)