Pytorch常用代码(不时更新)

主要参考:Link

1. 导入torch并查看其版本

import torch
print(torch.__version__)

2. 随机种子

def set_up(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    
def train():    
	set_up(2021) //这里的2021可为任意整数  
    // ...

3. 查看张量的基本信息

调试代码时最多的就是查看张量的形状和维度信息

tensor = torch.zeros(2, 3, 4)
print(tensor.size())  # 查看张量的形状
print(tensor.dim()) # 查看张量的维度

4. torch与numpy转换

一般将ndarray转换为tensor比较多,因为将tensor转换到ndarray之后,运算会在cpu上运行,会大大降低运行速度。

tensor = torch.zeros(2, 3, 4)
np = tensor.cpu().numpy()
tensor = torch.from_numpy(np).float()

5. numpy数组转换为图像

常用于可视化,或加载图片数据

iamge = PIL.Image.fromarray(ndarray.astype(np.unit8))  //numpy数组转Image图像

ndarray = np.asarray(PIL.Image.open(path)) //Image图像转numpy数组

6. 张量拼接

cat:在给定维度上对输入的张量序列seq进行连接操作,所有的tensors必须为相同的shape或者为空。

x = torch.randn(2, 3)
torch.cat((x, x, x), 0)
torch.cat((x, x, x), 1)

stack:沿着新的维度拼接一个序列的tensors

torch.stack(tensors, dim=0, *, out=None) → Tensor

7. 展开张量

通过将输入重塑为一维张量来展平输入。如果传递了 start_dim 或 end_dim,则只有以 start_dim 开头并以 end_dim 结尾的尺寸被展平。输入中元素的顺序不变。

torch.flatten(input, start_dim=0, end_dim=-1)

t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])
torch.flatten(t)
# tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.flatten(t, start_dim=1)
# tensor([[1, 2, 3, 4],
#        [5, 6, 7, 8]])

8.矩阵乘法

torch = torch.mm(mat1, mat2)

9. 模型定义

class ConvNet(nn.MOdule):
    def __init__(ConvNet, num_classes=10):
        super(ConvNet, self).__init__()
        ...
    def forward(self, x):
        ...

model = ConvNet(number_classes).to(device)  // 调用该类时即刻自动调用forward
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值