import torch
def _demo_increase_dim():
    """Insert a size-1 axis via ``None`` indexing or ``Tensor.unsqueeze``.

    ``t[None]`` matches ``t.unsqueeze(0)``, ``t[..., None]`` matches
    ``t.unsqueeze(-1)`` and ``t[:, None]`` matches ``t.unsqueeze(1)``.
    """
    test = torch.arange(24, device='cpu').view(2, 3, 4)  # shape [2,3,4]
    return {
        "res0": test[None],          # shape [1,2,3,4]
        "res1": test[..., None],     # shape [2,3,4,1]
        "res2": test[:, None],       # shape [2,1,3,4]
        "res3": test.unsqueeze(0),   # shape [1,2,3,4]
        "res4": test.unsqueeze(-1),  # shape [2,3,4,1]
        "res5": test.unsqueeze(1),   # shape [2,1,3,4]
    }


def _demo_decrease_dim():
    """Drop size-1 axes with ``squeeze`` or collapse to 1-D with ``view(-1)`` / ``flatten``."""
    test1 = torch.arange(24, device='cpu').view(6, 1, 4, 1)  # shape [6,1,4,1]
    return {
        "ans0": test1.squeeze(),   # every size-1 axis dropped -> shape [6,4]
        "ans1": test1.squeeze(1),  # only axis 1 dropped -> shape [6,4,1]
        "ans2": test1.view(-1),    # shape [24]
        "ans3": test1.flatten(),   # shape [24]
    }


def _demo_chunk():
    """Split a tensor into N equal parts along one dim with ``torch.chunk``."""
    boxes = torch.arange(12, device='cpu').view(3, 4)  # shape [3,4]
    cols = boxes[..., None]                            # shape [3,4,1]
    x1, y1, x2, y2 = torch.chunk(cols, 4, 1)  # 4 parts along dim=1, each shape [3,1,1]
    x1y1, x2y2 = torch.chunk(cols, 2, 1)      # 2 parts along dim=1, each shape [3,2,1]
    return {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "x1y1": x1y1, "x2y2": x2y2}


def _demo_split():
    """Split along one dim into pieces of explicit sizes with ``torch.split``."""
    pred = torch.randn(16, 85, 20, 20)  # shape [16,85,20,20]
    # The section sizes must add up to pred.shape[1]: 85 = 4 + 1 + 80.
    box, score, prob = torch.split(pred, [4, 1, 80], 1)
    # box [16,4,20,20], score [16,1,20,20], prob [16,80,20,20]
    return {"box": box, "score": score, "prob": prob}


def _demo_cat():
    """Concatenate along an EXISTING dim with ``torch.cat``."""
    x3y3 = torch.arange(0, 6, device='cpu').view(3, 2)   # shape [3,2]
    x4y4 = torch.arange(6, 12, device='cpu').view(3, 2)  # shape [3,2]
    return {
        "boxes": torch.cat((x3y3, x4y4), dim=0),   # shape [6,2]
        "boxes2": torch.cat((x3y3, x4y4), dim=1),  # shape [3,4]
    }


def _demo_stack():
    """Join same-shape tensors along a NEW dim with ``torch.stack``."""
    x, y, w, h = [torch.randn(16, 20, 20) for _ in range(4)]  # each shape [16,20,20]
    return {"res": torch.stack([x, y, w, h], 1)}  # shape [16,4,20,20]


def _demo_max():
    """Reduce along a dim: ``amax`` gives values only, ``max`` gives (values, indices)."""
    predict = torch.arange(0, 6, device='cpu').view(3, 2)  # shape [3,2]
    p1 = predict.amax(1)                     # shape [3]
    m, ind = predict.max(1)                  # shape [3], shape [3]
    m2, ind2 = predict.max(1, keepdim=True)  # shape [3,1], shape [3,1]
    return {"p1": p1, "m": m, "ind": ind, "m2": m2, "ind2": ind2}


def _demo_where():
    """Use ``torch.where`` as an index finder (1 arg) and as an element-wise select (3 args)."""
    values = torch.arange(0, 6, device='cpu').view(3, 2)  # shape [3,2]
    # Row indexes [1,2,2] and col indexes [1,0,1] of entries > 2; each shape [3].
    i, j = torch.where(values > 2)
    w = torch.arange(0, 6, device='cpu').view(3, 2)
    s = torch.where(w > 2, w, torch.zeros_like(w))  # s = w > 2 ? w : 0, shape [3,2]
    return {"i": i, "j": j, "s": s}


def _demo_compare():
    """Broadcast a comparison, then multiply by the bool result to zero out entries."""
    a = torch.tensor([0.1, 2.4, 5.4, 6.2], device='cpu').view(4, 1, 1)  # shape [4,1,1]
    b = torch.tensor([0.5, 1.4, 2.4], device='cpu').view(1, 1, 3)       # shape [1,1,3]
    c = a > b  # broadcast to shape [4,1,3], dtype bool
    n = torch.arange(0, 12, device='cpu').view(4, 1, 3)  # shape [4,1,3]
    mask = n * c  # n where c is True, 0 elsewhere; shape [4,1,3]
    return {"c": c, "mask": mask}


def main():
    """Run every demo, print what the original script printed, and return all results.

    Returns:
        dict mapping each result name from the original script (``"res0"``,
        ``"ans1"``, ``"box"``, ``"i"``, ``"mask"``, ...) to its tensor, so the
        shapes documented in the comments can be checked programmatically.
    """
    print("Hello World!")
    results = {}
    for demo in (
        _demo_increase_dim,
        _demo_decrease_dim,
        _demo_chunk,
        _demo_split,
        _demo_cat,
        _demo_stack,
        _demo_max,
        _demo_where,
        _demo_compare,
    ):
        results.update(demo())
    print(results["box"].shape)    # torch.Size([16, 4, 20, 20])
    print(results["score"].shape)  # torch.Size([16, 1, 20, 20])
    print(results["prob"].shape)   # torch.Size([16, 80, 20, 20])
    print(results["i"])            # tensor([1, 2, 2])
    print(results["j"])            # tensor([1, 0, 1])
    return results


if __name__ == "__main__":
    main()
# Notes: commonly used PyTorch functions (pytorch中常用函数记录)
# First published 2023-09-18 16:04:59 (于 2023-09-18 16:04:59 首次发布)