函数定义reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int)
介绍:该函数可以用非常直观的方式对张量进行一系列处理。
Parameters: 参数: tensor: 要处理的张量 pattern:维度变换 reduction: 要执行的操作,可以是('min', 'max', 'sum', 'mean', 'prod') axes_lengths: any additional specifications for dimensions
举例:
import torch
from einops import rearrange, reduce
a=torch.tensor([
[[1,1,1],
[2,2,2]],
[[3,3,3],
[4,4,4]]
],dtype=float)
1、张量创建好了,先尝试着任意做一个max操作。在此处t b c分别代表三个维度,转换为b c就意味着一个维度被抹消掉了,后面的max表示沿着t维度求最大值。
b=reduce(a, 't b c -> b c', 'max')
print(b)
结果:
tensor([[3., 3., 3.],
[4., 4., 4.]], dtype=torch.float64)
再尝试这次抹消掉第三个维度
b=reduce(a, 't b c -> t b', 'max')