np.tensordot 的理解和使用

本文深入解析Numpy中tensordot函数的使用方法,通过实例演示如何处理多维数组,尤其在卷积神经网络的卷积运算中,如何高效进行数组间的点乘操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Numpy是使用最广的科学计算库,对于多维数组的操作更是在实践中用的最多,而且也是比较困惑的地方,但是用好了事半功倍,今天讲一下numpy的 tensordot 的使用,这个函数在卷积神经网络的卷积中用到。

数组的基本属性

数组基本属性:维度、形状、strides(跨越数组各个维度所需要经过的字节数)、数组元素个数、元素占用字节数、数组占用空间,用以下例子说明:

>>> X = np.random.randint(0,9,(3,4,5))
>>> X
array([[[5, 1, 3, 6, 5],
        [5, 1, 8, 0, 5],
        [8, 5, 7, 8, 5],
        [8, 1, 5, 1, 4]],

       [[7, 7, 7, 7, 6],
        [0, 3, 4, 4, 6],
        [8, 4, 2, 1, 1],
        [6, 3, 4, 5, 4]],

       [[0, 2, 8, 0, 7],
        [6, 5, 8, 2, 2],
        [0, 1, 2, 3, 5],
        [7, 8, 7, 7, 6]]])
>>> X.ndim
3
>>> X.shape
(3, 4, 5)
>>> X.strides
(160, 40, 8)
>>> X.size
60
>>> X.itemsize
8
>>> X.nbytes
480
多维数组轴向取值

数组的取值看似简单但是在高纬度下,还是需要注意一下取法.
最原始取法,如取第一个元素

>>> X[0][0][0]
5

按轴取值则不同,取出来的值可能是数组,仍以上述为例,X.shape为(3,4,5),说明是3维数组,或者说有三个轴0,1,2. 第0轴上3个元素,第1轴上4个元素,第2轴上5个元素,如果要取轴上元素如何写?看以下例子。以下取第0轴第一个元素。

>>> X[0]
array([[5, 1, 3, 6, 5],
       [5, 1, 8, 0, 5],
       [8, 5, 7, 8, 5],
       [8, 1, 5, 1, 4]])
>>> X[1]
array([[7, 7, 7, 7, 6],
       [0, 3, 4, 4, 6],
       [8, 4, 2, 1, 1],
       [6, 3, 4, 5, 4]])
>>> X[2]
array([[0, 2, 8, 0, 7],
       [6, 5, 8, 2, 2],
       [0, 1, 2, 3, 5],
       [7, 8, 7, 7, 6]])
>>> X[4]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 4 is out of bounds for axis 0 with size 3

取1轴上的元素

>>> X[:,0,:]
array([[5, 1, 3, 6, 5],
      [7, 7, 7, 7, 6],
      [0, 2, 8, 0, 7]])
>>> X[:,5,:]
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
IndexError: index 5 is out of bounds for axis 1 with size 4

可以看到,按轴取出的元素实际上是一个子数组!

Tensordot的使用

进入正题, 运行如下代码:

>>> np.random.seed(10)
>>> A = np.random.randint(0,9,(3,4,5))
>>> B = np.random.randint(0,9,(4,5,2))
>>> np.tensordot(A, B, [(1,2), (0,1)])
array([[233,  89],
       [250, 234],
       [199, 244]])

解释:
(1,2) 是对A而言,不是取第1,2轴,而是除去1,2 轴,所以要取的是第0轴
(0,1) 是对B而言,不是取第0,1轴,而是除去0,1 轴,所以要取的是第2轴

以上两句是精华

A的形状是(3,4,5),第0轴上有3个元素,取法上面讲了;B的形状(4,5,2),第2轴上有2个元素,所以结果形状是(3,2)

Tensordot 的作用就是把取出的子数组做点乘操作,即是 np.sum(a*b) 操作。
我们来验证一下,上述的说法看结果形状(3,2)的第一个元素:A第0轴上第一个元素与B第2轴上的第一个元素点乘。

>>> A[0]
array([[4, 0, 1, 0, 1],
       [8, 0, 8, 6, 4],
       [3, 0, 4, 6, 8],
       [1, 8, 4, 1, 3]])
>>> B[:,:,0]
array([[8, 2, 5, 2, 3],
       [4, 0, 3, 2, 0],
       [0, 0, 1, 0, 5],
       [4, 6, 2, 3, 6]])
>>> np.sum(A[0]*B[:,:,0])
233

结果完全正确!就是这么简单,多说都是废话!

<think>嗯,用户想了解NumPy中dot方法的用法相关问题。首先,我需要回忆一下numpy.dot()的基本功能。它主要用于向量矩阵的点积运算。对于一维数组,计算的是内积;二维数组的话就是矩阵乘法。可能用户需要具体的例子来理解不同维度的输入会产生什么结果。 接下来,我得考虑用户可能的实际应用场景。比如在机器学习中,矩阵乘法很常见,或者在线性代数运算中处理向量矩阵。用户可能需要知道如何正确使用dot方法,避免维度不匹配的错误。常见的错误可能包括数组形状不符合点积要求,这时候需要提醒用户注意数组的维度。 然后,应该提供一些代码示例。比如两个一维数组的点积,二维数组的矩阵乘法,以及高维数组的用法。同时,需要说明np.dot@运算符的区别,以及何时使用哪一个更合适。可能用户在使用过程中会遇到性能问题,或者需要处理更复杂的高维数组情况,这时候可以提到np.tensordot作为补充。 另外,用户的问题可能涉及到点积的应用场景,比如神经网络中的权重输入的计算,图像处理中的卷积操作等。这些实际例子能帮助用户更好地理解dot方法的重要性。同时,注意引用用户提供的参考文献,比如引用[1]中提到的创建数组的方法,可以结合到示例代码中,增强回答的相关性。 最后,要检查是否有遗漏点,比如参数说明、常见错误处理,以及可能的相关问题扩展。比如性能优化方法,或者替代函数的使用场景。确保回答结构清晰,分点说明,并自然地加入引用标记。同时,生成的相关问题需要围绕dot方法展开,可能涉及性能、应用、替代方案等,帮助用户进一步学习。</think>### numpy.dot方法详解 NumPy的`dot`方法用于计算两个数组的点积,支持向量、矩阵高维张量的运算。以下是关键要点示例: 1. **基本定义** - 一维数组:计算向量内积,公式为$\sum_i a_i b_i$ - 二维数组:执行矩阵乘法,公式为$C_{ij} = \sum_k A_{ik}B_{kj}$ - 高维数组:对最后一个轴倒数第二个轴进行乘积求 2. **核心参数** ```python numpy.dot(a, b, out=None) ``` 参数说明:`a``b`必须是相同数据类型的数组(支持自动类型转换) 3. **示例代码** ```python import numpy as np # 引用[1]中的数组创建方法 # 一维向量 v1 = np.array([1,2,3]) v2 = np.array([4,5,6]) print(np.dot(v1, v2)) # 1*4 + 2*5 + 3*6 = 32 # 二维矩阵 m1 = np.array([[1,2],[3,4]]) m2 = np.array([[5,6],[7,8]]) print(np.dot(m1, m2)) # [[19 22] # [43 50]] # 高维数组 a = np.arange(3*4*5).reshape(3,4,5) b = np.arange(3*5*2).reshape(3,5,2) print(np.dot(a, b).shape) # (3,4,3,2) ``` 4. **常见问题处理** - **维度不匹配错误**:确保`a`的最后一个维度与`b`的倒数第二个维度相等 - **广播机制**:与`np.matmul`不同,`dot`不支持广播 - **性能优化**:对于大矩阵建议使用`@`运算符(Python 3.5+)或`np.matmul` 5. **典型应用场景** - 神经网络前向传播计算:$WX + b$ - 图像处理中的卷积运算 - 物理仿真中的张量运算 6. **替代方法对比** | 方法 | 特点 | 支持广播 | 高维处理 | |---------------|------------------------|----------|-----------------| | `np.dot` | 通用点积 | 否 | 特殊规则 | | `@`运算符 | 矩阵乘法专用 | 是 | 遵循广播规则 | | `np.tensordot`| 指定收缩轴 | 否 | 灵活指定维度 | [^1]: 引用[1]中展示了NumPy数组的基本创建方法,这是进行点积运算的前提条件
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

科技与文明

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值