深度学习处理数据中维度变换操作手册

好的,我们可以通过具体数字来更清楚地了解维度变换。假设以下变量有以下形状:

  • batch_size = 4
  • max_ent_cnt = 97
  • num_labels = 97
  • feature_size = 768
  • proto_dim = 768

初始形状如下:

  • ent_head: [97, 4, 768][max_ent_cnt, batch_size, feature_size]
  • ent_tail: [97, 4, 768][max_ent_cnt, batch_size, feature_size]
  • proto: [4, 97, 768][batch_size, num_labels, proto_dim]

Step-by-Step Transformation

  1. ent_headent_tail 变换
# Original shape: [max_ent_cnt, batch_size, feature_size] -> [97, 4, 768]
ent_head = ent_head.permute(1, 0, 2).unsqueeze(2).repeat(1, 1, self.max_ent_cnt, 1)
# After permute and unsqueeze: [4, 97, 1, 768]
# After repeat: [4, 97, 97, 768]

# Original shape: [max_ent_cnt, batch_size, feature_size] -> [97, 4, 768]
ent_tail = ent_tail.permute(1, 0, 2).unsqueeze(1).repeat(1, self.max_ent_cnt, 1, 1)
# After permute and unsqueeze: [4, 1, 97, 768]
# After repeat: [4, 97, 97, 768]

具体过程:

  • permute(1, 0, 2)[97, 4, 768] 变为 [4, 97, 768]
  • unsqueeze(2)[4, 97, 768] 变为 [4, 97, 1, 768](在第三维添加一个维度)
  • repeat(1, 1, self.max_ent_cnt, 1)[4, 97, 1, 768] 变为 [4, 97, 97, 768](在第三维重复 self.max_ent_cnt 次)

类似的步骤对 ent_tail 进行处理,只是 unsqueezerepeat 的维度不同。

  1. proto 变换
# Original shape: [batch_size, num_labels, proto_dim] -> [4, 97, 768]
proto = proto.unsqueeze(1).repeat(1, self.max_ent_cnt, 1, 1)
# After unsqueeze: [4, 1, 97, 768]
# After repeat: [4, 97, 97, 768]

具体过程:

  • unsqueeze(1)[4, 97, 768] 变为 [4, 1, 97, 768](在第一维添加一个维度)
  • repeat(1, self.max_ent_cnt, 1, 1)[4, 1, 97, 768] 变为 [4, 97, 97, 768](在第二维重复 self.max_ent_cnt 次)
  1. 调整 einsum 等式以匹配形状
logits = torch.einsum("xyz,bhtx,bcy,bhtz->bhtc", self.core_tensor, ent_head, proto.float(), ent_tail) + self.cls_bias

einsum 计算说明:

  • self.core_tensor 形状 [768, 768, 768] -> xyz
  • ent_head 形状 [4, 97, 97, 768] -> bhtx
  • proto 形状 [4, 97, 97, 768] -> bcy
  • ent_tail 形状 [4, 97, 97, 768] -> bhtz

einsum 将按照如下方式进行计算:

  • xx 相乘,结果的形状是 [4, 97, 97, 97]
  • yy 相乘,结果的形状是 [4, 97, 97, 97]
  • zz 相乘,结果的形状是 [4, 97, 97, 97]
  • 结果最终聚合为 [4, 97, 97, 97]

最终形状

  • logits 的最终形状为 [4, 97, 97, 97]

在使用 PyTorch 的 reshapeview 函数时,不会丢失数据。这两个函数的主要作用是改变张量的形状,而不会改变张量的数据内容。

示例代码

让我们通过一个具体的例子来说明这一点:

import torch

# 生成一个形状为 [1, 97] 的张量,使用从 0 到 96 的整数
tensor = torch.arange(97).reshape(1, 97)

# 打印生成的张量及其形状
print("Original tensor:")
print(tensor)
print("Shape:", tensor.shape)

# 将张量转换为形状 [97, 1]
reshaped_tensor = tensor.reshape(97, 1)

# 打印转换后的张量及其形状
print("\nReshaped tensor:")
print(reshaped_tensor)
print("Shape:", reshaped_tensor.shape)

输出

运行上述代码后,您会看到输出如下:

Original tensor:
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
         72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
         90, 91, 92, 93, 94, 95, 96]])
Shape: torch.Size([1, 97])

Reshaped tensor:
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18],
        [19],
        [20],
        [21],
        [22],
        [23],
        [24],
        [25],
        [26],
        [27],
        [28],
        [29],
        [30],
        [31],
        [32],
        [33],
        [34],
        [35],
        [36],
        [37],
        [38],
        [39],
        [40],
        [41],
        [42],
        [43],
        [44],
        [45],
        [46],
        [47],
        [48],
        [49],
        [50],
        [51],
        [52],
        [53],
        [54],
        [55],
        [56],
        [57],
        [58],
        [59],
        [60],
        [61],
        [62],
        [63],
        [64],
        [65],
        [66],
        [67],
        [68],
        [69],
        [70],
        [71],
        [72],
        [73],
        [74],
        [75],
        [76],
        [77],
        [78],
        [79],
        [80],
        [81],
        [82],
        [83],
        [84],
        [85],
        [86],
        [87],
        [88],
        [89],
        [90],
        [91],
        [92],
        [93],
        [94],
        [95],
        [96]])
Shape: torch.Size([97, 1])

解释

  • 原始张量:形状为 [1, 97],包含从 0 到 96 的整数。
  • 转换后的张量:形状为 [97, 1],数据内容保持不变,只是形状改变了。

结论

  • reshapeview 都不会丢失数据,它们只是改变了张量的形状。
  • 如果张量在内存中是连续的,view 可以使用;如果不确定,可以使用 reshape,它会自动处理非连续张量。

使用 reshapeview 改变张量形状时,数据内容不会被修改或丢失,因此可以安全地使用这些函数进行形状变换。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值