pytorch中常用函数记录

import torch

if __name__ == "__main__":
    print("Hello World!")

    # increase dim
    test = torch.arange(24, device='cpu').view(2, 3, 4)  # shape[2,3,4]
    res0 = test[None]  # shape[1,2,3,4]
    res1 = test[..., None]  # shape[2,3,4,1]
    res2 = test[:, None]  # shape[2,1,3,4]
    res3 = test.unsqueeze(0)  # shape[1,2,3,4]
    res4 = test.unsqueeze(-1)  # shape[2,3,4,1]
    res5 = test.unsqueeze(1)  # shape[2,1,3,4]

    # decrease dim
    test1 = torch.arange(24, device='cpu').view(6, 1, 4, 1)  # shape[6,1,4,1]
    ans0 = test1.squeeze()  # shape[6,4]
    ans1 = test1.squeeze(1)  # shape[6,4,1]
    ans2 = test1.view(-1)  # shape[24]
    ans3 = test1.flatten()  # shape[24]

    # chunk
    boxes = torch.arange(12, device='cpu').view(3, 4)  # shape[3,4]
    x1, y1, x2, y2 = torch.chunk(boxes[..., None], 4, 1)  # 4 parts  dim=1  shape[3,1,1]
    x1y1, x2y2 = torch.chunk(boxes[..., None], 2, 1)  # 2 parts  dim=1   shape[3,2,1]

    # split
    pred = torch.randn(16, 85, 20, 20)  # shape[16,85,20,20]
    box, score, prob = torch.split(pred, [4, 1, 80], 1)  # dim=1  85=4+1+80
    print(box.shape)  # shape[16,4,20,20]
    print(score.shape)  # shape[16,1,20,20]
    print(prob.shape)  # shape[16,80,20,20]

    # cat ---- concat
    x3y3 = torch.arange(0, 6, device='cpu').view(3, 2)  # shape[3,2]
    x4y4 = torch.arange(6, 12, device='cpu').view(3, 2)  # shape[3,2]
    boxes = torch.cat((x3y3, x4y4), dim=0)  # shape[6,2]
    boxes2 = torch.cat((x3y3, x4y4), dim=1)  # shape[3,4]

    # stack ---- new dim concat
    x = torch.randn(16, 20, 20)  # shape[16,20,20]
    y = torch.randn(16, 20, 20)  # shape[16,20,20]
    w = torch.randn(16, 20, 20)  # shape[16,20,20]
    h = torch.randn(16, 20, 20)  # shape[16,20,20]
    res = torch.stack([x, y, w, h], 1)  # shape[16,4,20,20]

    # amin() max()
    predict = torch.arange(0, 6, device='cpu').view(3, 2)  # shape[3,2]
    p1 = predict.amax(1)  # shape[3]
    m, ind = predict.max(1)  # shape[3]  shape[3]
    m2, ind2 = predict.max(1, keepdim=True)  # shape[3,1]  shape[3,1]

    # where
    cls = torch.arange(0, 6, device='cpu').view(3, 2)  # shape[3,2]
    i, j = torch.where(cls > 2)
    print(i)  # row indexes  [1,2,2]  shape[3]
    print(j)  # cil indexes  [1,0,1]  shape[3]
    w = torch.arange(0, 6, device='cpu').view(3, 2)
    s = torch.where(w > 2, w, torch.full_like(w, 0))  # s = w > 0 ? w : 0   shape[3,2]

    # >  <
    a = torch.tensor([0.1, 2.4, 5.4, 6.2], device='cpu').view(4, 1, 1)  # shape[4,1,1]
    b = torch.tensor([0.5, 1.4, 2.4], device='cpu').view(1, 1, 3)  # shape[1,1,3]
    c = a > b   # shape[4,1,3]  bool
    n = torch.arange(0, 12, device='cpu').view(4, 1, 3)  # shape[4,1,3]
    mask = n * c  # shape[4,1,3]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值