PyTorch写代码的一些技巧和常用操作(持续更新)

1. Data Masked(data_sample)

import random
import torch
data = torch.FloatTensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("data:")
print(data)

num_mask = 1
sample = random.sample(range(len(data)), 1)
print(sample)
index = torch.ones(data.shape, dtype=torch.bool)
index[sample]=False
print(index)
data_sample = data[index].reshape(-1, data.shape[1])
print(data_sample)

在这里插入图片描述

2. Broadcast 广播机制

Broadcast 它能维度扩展和 expand 一样
Broadcast 是自动扩展,并且不需要拷贝数据,能够节省内存。

在这里插入图片描述
在这里插入图片描述

Broadcast存在的意义:
①实际的扩展。
②节省内存资源。
当没有维度的时候,首先添加一个size=1的维度,然后对size=1的所有维度进行扩展。

import torch

a = torch.rand(4, 32, 14, 14)
b = torch.rand(1, 32, 1, 1)
c = torch.rand(32, 1, 1)

# b [1, 32, 1, 1]=>[4, 32, 14, 14]
print((a + b).shape)
print((a+c).shape)

---------------------------------------------------------
torch.Size([4, 32, 14, 14])
torch.Size([4, 32, 14, 14])

Process finished with exit code 0

3. 合并与分割(merge or split)

  • Cat:concat 的缩写,表示拼接
  • stack:stack也是一种形式的拼接操作
  • split:split 按照长度进行拆分
  • chunk:chunk 按照数量进行拆分

3.1 cat 拼接

import torch

# 两个班级a和b,各有32个学生,8门成绩。
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# 按照班级进行合并起来。
print(torch.cat([a, b], dim=0).shape)


---------------------------------------------------------
torch.Size([9, 32, 8])

Process finished with exit code 0

在这里插入图片描述

import torch

a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(5, 3, 32, 32)
print(torch.cat([a1, a2], dim=0).shape)
print('====================================')
a3 = torch.rand(4, 1, 32, 32)
# print(torch.cat([a1, a3], dim=0))      # 这句报错。
print(torch.cat([a1, a3], dim=1).shape)

---------------------------------------------------------
torch.Size([9, 3, 32, 32])
====================================
torch.Size([4, 4, 32, 32])

Process finished with exit code 0

3.2 stack 创建新维度

stack操作两个维度必须一致

import torch
# stack操作两个维度必须一致
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(4, 3, 32, 32)
print(torch.cat([a1, a2], dim=1).shape)
print('====================================')
print(torch.stack([a1, a2], dim=1).shape)    # 各自创建一个新的维度。然后concat

a = torch.rand(32, 8)
b = torch.rand(32, 8)
print(torch.stack([a, b], dim=0).shape)

---------------------------------------------------------
torch.Size([4, 6, 32, 32])
====================================
torch.Size([4, 2, 3, 32, 32])                # 各自创建一个新的维度。然后concat
torch.Size([2, 32, 8])

Process finished with exit code 0

3.3 split 按长度拆分和 chunk 按数量拆分

.split (长度,dim) 第一参数表示拆分后的长度,第二个参数表示要拆分的维度。

import torch

c = torch.rand(2, 32, 8)
aa, bb = c.split([1, 1], dim=0)
print(aa.shape)
print(bb.shape)
print('====================================')
aa, bb = c.split(1, dim=0)
print(aa.shape)
print(bb.shape)

---------------------------------------------------------
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
====================================
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])

Process finished with exit code 0

chunk(数量,dim)第一参数表示要拆分后的数量,第二个参数表示要拆分的维度。

import torch

c = torch.rand(8, 32, 8)
aa, bb = c.chunk(2, dim=0)  # 第1个参数要拆分后的数量
print(aa.shape)
print(bb.shape)

---------------------------------------------------------
torch.Size([4, 32, 8])
torch.Size([4, 32, 8])

Process finished with exit code 0

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值