1. Data masking (data_sample)
import random
import torch

data = torch.FloatTensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("data:")
print(data)

num_mask = 1  # number of rows to mask out
sample = random.sample(range(len(data)), num_mask)  # randomly chosen row indices
print(sample)

index = torch.ones(data.shape, dtype=torch.bool)
index[sample] = False  # mark the sampled rows as masked
print(index)

data_sample = data[index].reshape(-1, data.shape[1])  # keep only the unmasked rows
print(data_sample)
2. Broadcasting
Broadcasting expands tensor dimensions, just like expand.
Unlike calling expand by hand, broadcasting happens automatically and does not copy any data, so it saves memory.
Why broadcasting exists:
① it performs the dimension expansion implicitly;
② it saves memory.
The rule: when a tensor has fewer dimensions, size-1 dimensions are first inserted at the front; then every size-1 dimension is expanded to match the other tensor.
import torch
a = torch.rand(4, 32, 14, 14)
b = torch.rand(1, 32, 1, 1)
c = torch.rand(32, 1, 1)
# b: [1, 32, 1, 1] => [4, 32, 14, 14]
print((a + b).shape)
# c: [32, 1, 1] => [1, 32, 1, 1] => [4, 32, 14, 14]
print((a + c).shape)
---------------------------------------------------------
torch.Size([4, 32, 14, 14])
torch.Size([4, 32, 14, 14])
Process finished with exit code 0
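The rule above can be sketched in plain Python, with no torch needed; `broadcast_shape` here is a hypothetical helper written for illustration, not a PyTorch API (PyTorch applies the same logic internally):

```python
def broadcast_shape(a, b):
    """Sketch of the broadcasting rule: align shapes from the right,
    treat missing dims as size 1, then expand every size-1 dim."""
    n = max(len(a), len(b))
    # Prepend size-1 dims so both shapes have the same length.
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible sizes {da} and {db}")
    return tuple(out)

# Mirrors a + b and a + c from the example above.
print(broadcast_shape((4, 32, 14, 14), (1, 32, 1, 1)))  # (4, 32, 14, 14)
print(broadcast_shape((4, 32, 14, 14), (32, 1, 1)))     # (4, 32, 14, 14)
```

Note that a dimension which is neither equal nor 1 (e.g. 4 vs 2) raises an error, which is exactly when a real PyTorch addition would fail to broadcast.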
3. Merging and splitting (merge / split)
- cat: short for concatenate; joins tensors along an existing dimension
- stack: also a form of concatenation, but it creates a new dimension
- split: splits a tensor by piece length
- chunk: splits a tensor by number of pieces
3.1 cat: concatenation
import torch
# Two groups of classes a and b; each class has 32 students with 8 scores.
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# Merge the two groups along the class dimension.
print(torch.cat([a, b], dim=0).shape)
---------------------------------------------------------
torch.Size([9, 32, 8])
Process finished with exit code 0
import torch
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(5, 3, 32, 32)
print(torch.cat([a1, a2], dim=0).shape)
print('====================================')
a3 = torch.rand(4, 1, 32, 32)
# print(torch.cat([a1, a3], dim=0))  # error: sizes in dim 1 differ (3 vs 1)
print(torch.cat([a1, a3], dim=1).shape)
---------------------------------------------------------
torch.Size([9, 3, 32, 32])
====================================
torch.Size([4, 4, 32, 32])
Process finished with exit code 0
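The shape rule cat follows can also be sketched in plain Python; `cat_shape` is a hypothetical helper for illustration, not a PyTorch function:

```python
def cat_shape(shapes, dim):
    """Sketch of torch.cat's shape rule: every dimension except `dim`
    must match; sizes along `dim` are summed."""
    ref = list(shapes[0])
    for s in shapes[1:]:
        for i, (x, y) in enumerate(zip(ref, s)):
            if i != dim and x != y:
                raise ValueError(f"size mismatch in dim {i}: {x} vs {y}")
    ref[dim] = sum(s[dim] for s in shapes)
    return tuple(ref)

# Mirrors the two cat calls above.
print(cat_shape([(4, 3, 32, 32), (5, 3, 32, 32)], dim=0))  # (9, 3, 32, 32)
print(cat_shape([(4, 3, 32, 32), (4, 1, 32, 32)], dim=1))  # (4, 4, 32, 32)
```

Running the commented-out failing case (`dim=0` with sizes 3 vs 1 in dim 1) through this helper raises the same kind of size-mismatch error.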
3.2 stack: creating a new dimension
For stack, the tensors must have exactly the same shape.
import torch
# For stack, the tensors must have exactly the same shape.
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(4, 3, 32, 32)
print(torch.cat([a1, a2], dim=1).shape)
print('====================================')
print(torch.stack([a1, a2], dim=1).shape)  # each tensor gets a new dimension, then they are concatenated
a = torch.rand(32, 8)
b = torch.rand(32, 8)
print(torch.stack([a, b], dim=0).shape)
---------------------------------------------------------
torch.Size([4, 6, 32, 32])
====================================
torch.Size([4, 2, 3, 32, 32])
torch.Size([2, 32, 8])
Process finished with exit code 0
3.3 split (by length) and chunk (by count)
.split(split_size, dim): the first argument is the length of each piece (a single length, or a list of lengths); the second is the dimension to split along.
import torch
c = torch.rand(2, 32, 8)
aa, bb = c.split([1, 1], dim=0)
print(aa.shape)
print(bb.shape)
print('====================================')
aa, bb = c.split(1, dim=0)
print(aa.shape)
print(bb.shape)
---------------------------------------------------------
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
====================================
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
Process finished with exit code 0
.chunk(chunks, dim): the first argument is the number of chunks to produce; the second is the dimension to split along.
import torch
c = torch.rand(8, 32, 8)
aa, bb = c.chunk(2, dim=0)  # first argument: the number of chunks
print(aa.shape)
print(bb.shape)
---------------------------------------------------------
torch.Size([4, 32, 8])
torch.Size([4, 32, 8])
Process finished with exit code 0
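The piece sizes the two methods produce can be computed in plain Python; `split_sizes` and `chunk_sizes` are hypothetical helpers for illustration, not PyTorch functions:

```python
import math

def split_sizes(total, split_size):
    """split: pieces of length `split_size`; the last piece may be shorter."""
    sizes = [split_size] * (total // split_size)
    if total % split_size:
        sizes.append(total % split_size)
    return sizes

def chunk_sizes(total, chunks):
    """chunk: up to `chunks` pieces of length ceil(total / chunks);
    the last piece takes whatever remains."""
    return split_sizes(total, math.ceil(total / chunks))

# Mirrors the examples above.
print(split_sizes(2, 1))  # [1, 1]  -> c.split(1, dim=0)
print(chunk_sizes(8, 2))  # [4, 4]  -> c.chunk(2, dim=0)
print(chunk_sizes(5, 2))  # [3, 2]  -> uneven case: last chunk is smaller
```

The uneven case shows the practical difference: split fixes the piece length and lets the piece count vary, while chunk fixes the piece count and lets the last piece shrink.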