import tensorflow as tf
import numpy as np
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
一、数据类型
1.数值类型
# 标量
num_scalar = tf.constant(1)
print('标量:', num_scalar)
# 向量
num_vector = tf.constant([1, 2])
print('向量:',num_vector)
# 矩阵 2行2列
num_matrix = tf.constant([[1,2], [3,4]])
print('矩阵:',num_matrix)
# 张量
num_tensor = tf.constant([[[1], [2]], [[3], [4]]])
print('张量:',num_tensor)
标量: tf.Tensor(1, shape=(), dtype=int32)
向量: tf.Tensor([1 2], shape=(2,), dtype=int32)
矩阵: tf.Tensor(
[[1 2]
[3 4]], shape=(2, 2), dtype=int32)
张量: tf.Tensor(
[[[1]
[2]]
[[3]
[4]]], shape=(2, 2, 1), dtype=int32)
(1)查看属性
# 查看变量的类型
type(num_scalar)
tensorflow.python.framework.ops.EagerTensor
# 判断是否属于tensor
tf.is_tensor(num_matrix)
True
# 查看tenfor的形状
num_tensor.shape
TensorShape([2, 2, 1])
# 返回numpy.array类型的数据
num_tensor.numpy()
array([[[1],
[2]],
[[3],
[4]]])
(2)数值精度
常用的精度类型:tf.int16, tf.int32, tf.int64, tf.float16, tf.float32, tf.float64(tf.double)
对于大部分深度学习算法,一般使用 tf.int32, tf.float32 可满足运算精度要求,部分对精度要求较高的算法,如强化学习,可以选择使用 tf.int64, tf.float64 精度保存张量。
# 创建张量时,可以指定数值精度,如指定Π的精度
pi_32 = tf.constant(np.pi, dtype=tf.float32)
pi_32
pi_64 = tf.constant(np.pi, dtype=tf.double)
pi_64
<tf.Tensor: shape=(), dtype=float32, numpy=3.1415927>
<tf.Tensor: shape=(), dtype=float64, numpy=3.141592653589793>
# 读取数值精度
print(pi_32.dtype)
print(pi_64.dtype)
<dtype: 'float32'>
<dtype: 'float64'>
2.字符串类型
string_tensor = tf.constant('Hello, Tensorflow!')
string_tensor
<tf.Tensor: shape=(), dtype=string, numpy=b'Hello, Tensorflow!'>
工具: tf.strings提供了字符串处理工具
print('查看字符串长度:')
tf.strings.length(string_tensor)
print('划分字符串:')
tf.strings.split(string_tensor)
print('字符串全部小写:')
tf.strings.lower(string_tensor)
print('字符串全部大写:')
tf.strings.upper(string_tensor)
查看字符串长度:
<tf.Tensor: shape=(), dtype=int32, numpy=18>
划分字符串:
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Hello,', b'Tensorflow!'], dtype=object)>
字符串全部小写:
<tf.Tensor: shape=(), dtype=string, numpy=b'hello, tensorflow!'>
字符串全部大写:
<tf.Tensor: shape=(), dtype=string, numpy=b'HELLO, TENSORFLOW!'>
3.布尔类型
bool_tensor = tf.constant(True)
bool_tensor
bool_tensors = tf.constant([True, False])
bool_tensors
<tf.Tensor: shape=(), dtype=bool, numpy=True>
<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
4.数据类型转换
# 精度转换:tf.cast函数
pi_32
print('提高数据精度:')
tf.cast(pi_32, tf.double)
pi_64
print('降低精度:')
tf.cast(pi_64, tf.float32)
<tf.Tensor: shape=(), dtype=float32, numpy=3.1415927>
提高数据精度:
<tf.Tensor: shape=(), dtype=float64, numpy=3.1415927410125732>
<tf.Tensor: shape=(), dtype=float64, numpy=3.141592653589793>
降低精度:
<tf.Tensor: shape=(), dtype=float32, numpy=3.1415927>
# 布尔型与数值转换: 在 TensorFlow 中,非 0 数字都视为 True
print('布尔型转数值型:')
num_bool = tf.cast(bool_tensors, tf.int32)
num_bool
print('数值型转布尔型:')
tf.cast(num_bool, tf.bool)
布尔型转数值型:
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 0])>
数值型转布尔型:
<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
5.待优化张量:需要计算梯度的张量
# 将普通张量转换为待优化张量
a = tf.constant([-1, 0, 1, 2])
aa = tf.Variable(a)
# 直接创建待优化张量
b = tf.Variable([-1, 0, 1, 2])
b
<tf.Variable 'Variable:0' shape=(4,) dtype=int32, numpy=array([-1, 0, 1, 2])>
# 待优化张量新增属性
# name属性用于命名计算图中的变量
aa.name
# trainable表征当前张量是否需要被优化,默认为True
aa.trainable
'Variable:0'
True
二、创建张量
# 1.从numpy或list对象创建张量
tf.constant(np.array([1,2,3]))
tf.convert_to_tensor(np.array([1,2,3]))
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3])>
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3])>
# 2.创建全0或全1张量
a = tf.zeros(shape=(2,3))
b = tf.ones(shape=3)
# 创建与已知张量形状相同的全0或全1张量
c = tf.zeros_like(a)
d = tf.ones_like(b)
# 3.创建自定义数值张量:tf.fill(shape, value)
tf.fill([2,3], 5)
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[5, 5, 5],
[5, 5, 5]])>
# 4.创建已知分布的张量
# 正态分布
tf.random.normal([2,3], mean=1, stddev=2)
# 均匀分布
tf.random.uniform([3,4], minval=0, maxval=10)
tf.random.uniform([3,4], minval=0, maxval=10, dtype=tf.int32)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 2.0144641, -2.352346 , -1.1124156],
[ 0.7475734, 2.289733 , 1.5603683]], dtype=float32)>
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[8.628323 , 9.460016 , 7.159786 , 6.278372 ],
[2.9533637, 7.196653 , 2.380352 , 3.0045676],
[2.2084749, 3.1557035, 1.7376423, 7.997342 ]], dtype=float32)>
<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[5, 7, 6, 6],
[7, 5, 7, 7],
[4, 7, 5, 8]])>
# 5.创建序列
tf.range(10)
# 设置步长
tf.range(10, delta=2)
# 设置起始值
tf.range(1, 10, delta=3)
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])>
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 2, 4, 6, 8])>
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 4, 7])>
三、张量的应用举例
# 1.标量的典型应用:训练误差、准度性的度量
#随机模拟网络输出
out = tf.random.uniform([4,10])
# 随机构造样本真实标签
y = tf.constant([2,3,2,0])
# one-hot 编码
y = tf.one_hot(y, depth=10)
# 计算每个样本的 MSE
loss = tf.keras.losses.mse(y, out)
# 平均 MSE
loss = tf.reduce_mean(loss)
print(loss)
tf.Tensor(0.39109215, shape=(), dtype=float32)
# 2.向量的典型应用;偏置项张量b
fc = tf.keras.layers.Dense(3) # 创建一层 Wx+b,输出节点为 3
fc.build(input_shape=(2,4)) # 通过 build 函数创建 W,b 张量,输入节点为 4
fc.bias # 查看偏置
# fc.weights
<tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>
# 3.矩阵的典型应用:神经网络权重的设置
# 例一:
x = tf.random.uniform([2,4], minval=0, maxval=2, dtype=tf.int32)
x = tf.cast(x, tf.float32)
w = tf.ones([4,3]) # 定义 W 张量
b = tf.zeros([3]) # 定义 b 张量
o = x@w+b # X@W+b 运算
print('ouput:', o)
# 例二:
fc = tf.keras.layers.Dense(3) # 定义全连接层的输出节点为 3
fc.build(input_shape=(2,4)) # 定义全连接层的输入节点为 4
print('权重与偏置项', fc.kernel) # 查看权重W与偏置项b
ouput: tf.Tensor(
[[4. 4. 4.]
[2. 2. 2.]], shape=(2, 3), dtype=float32)
权重与偏置项 <tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
array([[ 0.39876783, 0.18011749, 0.34121704],
[-0.7244823 , 0.47297895, -0.1585576 ],
[-0.3187186 , 0.2042054 , -0.362638 ],
[ 0.13634324, 0.4525478 , -0.7070261 ]], dtype=float32)>
# 4.三维张量的应用:
# (1)序列信号 X = [b, sequence_len, feature_len], 分别表示序列信号的数量,时间维度上的采样点数,每个点的特征长度
# (2)自然语言处理中句子的表示:
# 自动加载 IMDB 电影评价数据集
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.imdb.load_data(num_words=10000)
print('句子原始形状:', x_train.shape)
# 将句子填充、截断为等长 80 个单词的句子
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=80)
print('填充、截断句子形状:', x_train.shape)
# 创建词向量 Embedding 层类,将数字编码的单词转换为长度为 100 个词向量
embedding=tf.keras.layers.Embedding(10000, 100)
# 将数字编码的单词转换为词向量
out = embedding(x_train)
print('编码后句子形状:', out.shape)
句子原始形状: (25000,)
填充、截断句子形状: (25000, 80)
编码后句子形状: (25000, 80, 100)
# 5.四维张量的应用:卷积神经网络
# [b, h , w, c]分别表示输入数量,图片高度, 图片宽度,图的通道数
# 创建 32x32 的彩色图片输入,个数为 4
x = tf.random.normal([4,32,32,3])
# 创建卷积神经网络
layer = tf.keras.layers.Conv2D(64,kernel_size=(7,5))
# 前向计算
out = layer(x)
# 输出大小
out.shape
# 其中卷积核张量也是 4 维张量,可以通过 kernel 成员变量访问:
layer.kernel.shape
TensorShape([4, 26, 28, 64])
TensorShape([7, 5, 3, 64])
四、索引与切片
# X 为 4 张 32x32 大小的彩色图片
x = tf.random.normal([4,32,32,3])
索引
# 取第 1 张图片的数据:
x[0].shape
x[0]
TensorShape([32, 32, 3])
<tf.Tensor: shape=(32, 32, 3), dtype=float32, numpy=
array([[[-1.4066522 , -1.6301475 , -0.6626721 ],
[ 0.4622193 , 1.1874142 , -1.0525869 ],
[ 1.1119522 , -1.336452 , -0.36873984],
...,
[ 0.76140714, 0.7055328 , 0.60365176],
[-0.69791937, -0.21970169, 0.6287722 ],
[-0.7085722 , -0.7537329 , -1.1888791 ]],
[[ 0.05302444, 1.088168 , 1.5901968 ],
[-1.1256545 , -0.00637588, -0.476244 ],
[ 0.20962416, 0.1299456 , -0.23336564],
...,
[ 0.2610068 , 0.43658018, -0.7139866 ],
[ 0.12355808, 0.41040298, -2.1643631 ],
[ 1.0014896 , -0.39356777, -0.11626529]],
[[ 0.6359369 , 1.4638699 , -0.5852146 ],
[-1.2632309 , 0.6778309 , 1.769848 ],
[ 2.042887 , 0.10215978, 0.14071363],
...,
[-0.5496323 , 0.8159724 , -0.37511298],
[-0.58229995, 0.31804055, 0.18805313],
[ 0.19327746, -0.91908634, 0.7599561 ]],
...,
[[ 0.57062095, -1.0924951 , 1.1462951 ],
[-0.17186154, -0.6705764 , -0.28675246],
[ 1.0782579 , -0.65768707, 0.18618874],
...,
[-1.129991 , -0.9912013 , 0.34741646],
[ 0.9312622 , -1.6201673 , -0.8323024 ],
[ 1.1002243 , -0.8553034 , 0.0289506 ]],
[[ 0.44621316, 0.56181026, 0.52481025],
[-0.5058765 , -0.06046636, 0.09103815],
[ 0.8362112 , -0.01185856, -0.10881247],
...,
[-0.26027304, 0.20625828, 0.45409092],
[-0.9735298 , 0.24121265, 1.4221593 ],
[ 0.29431424, 0.27458662, 1.2655455 ]],
[[ 0.69613606, -0.60457057, -0.08587647],
[ 0.69580555, 0.7907061 , -0.61136305],
[ 0.03631109, 0.5971228 , -0.27422765],
...,
[-0.38135353, 0.556312 , -1.0376903 ],
[ 0.9262408 , -0.57603973, -0.5694517 ],
[ 0.27046308, 1.4378041 , 2.0465028 ]]], dtype=float32)>
# 取第 1 张图片的第 2 行:
x[0][1]
<tf.Tensor: shape=(32, 3), dtype=float32, numpy=
array([[ 0.05302444, 1.088168 , 1.5901968 ],
[-1.1256545 , -0.00637588, -0.476244 ],
[ 0.20962416, 0.1299456 , -0.23336564],
[ 0.18930745, -1.3828539 , 2.3443792 ],
[-0.82742244, 0.11228334, -2.3803499 ],
[-0.12717138, 0.64265406, -0.6452635 ],
[ 0.8297432 , 0.01237224, 0.8067096 ],
[-0.36701548, 1.8592048 , -1.8694289 ],
[ 1.8069752 , -0.10527242, 1.041227 ],
[-0.10494869, 1.8063564 , -1.35875 ],
[ 0.93015295, 0.64609325, 1.1821879 ],
[ 1.1264399 , -1.0496296 , -1.2273068 ],
[ 0.2710233 , -0.77052 , 1.2700952 ],
[ 0.60882455, -0.4986755 , 0.11282647],
[-0.17114055, -0.26709092, -0.3601687 ],
[ 1.1210955 , 0.6035169 , -0.50023425],
[ 1.1446507 , 0.30143014, 0.81295407],
[-0.03041968, 0.81120336, 0.69047695],
[-2.2849495 , 0.2028147 , -0.02289553],
[-1.0418856 , -0.37714505, -0.8413859 ],
[-0.97956747, 1.2297201 , -0.80886227],
[ 0.599066 , 0.50864047, -0.0591111 ],
[-0.1478625 , 0.7542056 , -0.5288424 ],
[-0.31700444, 0.660806 , -0.16487013],
[ 0.21907088, -0.12597667, -0.67567325],
[-0.6869729 , -0.65090865, -0.27894628],
[-0.66607517, -1.1051095 , -1.4386888 ],
[ 0.12348855, -0.29393324, -0.25880232],
[ 0.54234713, 0.27453753, 0.0945635 ],
[ 0.2610068 , 0.43658018, -0.7139866 ],
[ 0.12355808, 0.41040298, -2.1643631 ],
[ 1.0014896 , -0.39356777, -0.11626529]], dtype=float32)>
# 取第 1 张图片,第 2 行,第 3 列的像素:
x[0][1][2]
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 0.20962416, 0.1299456 , -0.23336564], dtype=float32)>
# 取第 3 张图片,第 2 行,第 1 列的像素,B 通道(第 2 个通道)颜色强度值:
x[2][1][0][1]
<tf.Tensor: shape=(), dtype=float32, numpy=-0.09115129>
# 取第 2 张图片,第 10 行,第 3 列:
x[1, 9, 2]
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 0.97449315, -0.95866084, 1.5910561 ], dtype=float32)>
切片
# 读取第 2,3 张图片
x[1:3]
<tf.Tensor: shape=(2, 32, 32, 3), dtype=float32, numpy=
array([[[[-1.31716 , 0.8927563 , -0.19094189],
[-0.733248 , -0.77084136, -1.0931517 ],
[-1.0336089 , 0.6535478 , 1.1548346 ],
...,
[ 0.30430856, -0.11049574, -0.7938342 ],
[-1.7981858 , -2.225117 , -0.83562917],
[ 0.93750566, -0.67847425, 0.9224359 ]],
[[ 2.499113 , 1.6312623 , 0.12778619],
[ 1.1692234 , 0.0189321 , -1.0309207 ],
[-0.62398046, -0.38072154, -1.1224715 ],
...,
[-0.32941985, 0.10051683, 1.836634 ],
[-0.7381755 , -1.4565924 , 0.70586956],
[ 1.6443965 , -0.58757645, -2.0862925 ]],
[[-0.04231573, -1.0543141 , -0.07294127],
[-1.3289794 , 1.7232455 , 0.34619242],
[ 0.430539 , 0.751165 , 1.1801926 ],
...,
[ 1.3924987 , 0.8397924 , 0.5397981 ],
[ 0.8138152 , 1.8297083 , 0.9749301 ],
[-0.52527463, 0.0285456 , -1.0811192 ]],
...,
[[ 1.3123864 , -0.59172046, -0.25115427],
[-0.5665263 , -0.4391703 , -0.10928307],
[ 0.07010391, -0.58571684, 1.4293606 ],
...,
[ 2.1118972 , 0.23430105, -0.44872212],
[ 0.32074732, -1.6758326 , -0.7835137 ],
[ 0.10189884, 0.3735951 , -0.2619103 ]],
[[-1.153347 , -1.3645796 , 1.0547861 ],
[-0.60928154, -1.3566399 , 0.2504541 ],
[ 0.5597972 , 1.2221434 , -0.18595591],
...,
[ 0.39691612, 2.0617461 , 0.03640852],
[-0.80052674, 0.13028495, 0.36886504],
[ 0.53274167, 1.8119245 , 0.2898678 ]],
[[-0.42815545, -0.4821952 , -0.01792378],
[ 0.7978528 , -1.4759514 , -0.34798223],
[ 1.7242622 , -0.295613 , 0.3123493 ],
...,
[ 2.0360522 , 0.56676894, -0.24085432],
[-1.8507787 , 0.24397703, 1.2036182 ],
[-0.72610736, 0.27573547, -0.34260753]]],
[[[ 1.6538881 , -1.2828442 , -0.03862259],
[ 0.6386152 , 1.1786485 , -0.4696664 ],
[-0.9385108 , 2.177739 , 0.27142107],
...,
[ 2.6009357 , -0.3140693 , 1.3819468 ],
[ 0.4681579 , -0.59527147, -0.7687865 ],
[ 0.7524163 , -0.85284793, -1.2497554 ]],
[[-1.0554148 , -0.09115129, -0.75751054],
[-1.0196933 , 1.114568 , -0.26536736],
[-0.33567974, -1.3310446 , -0.7331455 ],
...,
[ 0.7729232 , 1.0367365 , 0.30283388],
[ 2.2981248 , 1.1215141 , -0.05669821],
[ 0.00336637, -0.16251333, -2.7162955 ]],
[[-1.5069393 , -0.01386867, -0.8434442 ],
[-0.75924855, -2.37886 , 0.96467763],
[ 0.7158936 , -0.24094783, -1.7288772 ],
...,
[-0.29446223, -0.2361764 , -1.0333688 ],
[-0.4777188 , -1.6288629 , -0.03284854],
[ 1.4216906 , -1.92074 , -0.22618869]],
...,
[[-0.4792214 , 0.3826494 , 1.2373917 ],
[-0.6446881 , 0.7737799 , 0.10003889],
[ 0.6713453 , -2.136748 , -0.4338728 ],
...,
[-0.16307087, 0.357456 , -0.52207917],
[ 1.3242189 , -1.0513958 , 1.1873055 ],
[-0.36773506, -0.4142033 , -0.16103096]],
[[ 0.50397587, 0.9386607 , 2.8985884 ],
[-0.8921362 , -0.06021392, -0.19666769],
[ 0.4871909 , 0.18244128, -0.2612286 ],
...,
[-1.8799183 , 0.38841182, 3.0188472 ],
[ 0.1810747 , -0.61945873, -1.7169424 ],
[ 1.3666738 , 0.32145166, 0.8407051 ]],
[[ 0.06317475, 0.15326717, 0.1563328 ],
[ 0.01047049, 1.2195497 , -0.53379405],
[-0.50812846, 0.96157503, -0.20100513],
...,
[-0.3401454 , 0.9257043 , -1.1202906 ],
[ 0.96981984, -0.37571394, 0.90335006],
[-0.4275013 , -0.98839325, -0.3697131 ]]]], dtype=float32)>
# 读取第 1 张图片的所有行
x[0,::]
<tf.Tensor: shape=(32, 32, 3), dtype=float32, numpy=
array([[[-1.4066522 , -1.6301475 , -0.6626721 ],
[ 0.4622193 , 1.1874142 , -1.0525869 ],
[ 1.1119522 , -1.336452 , -0.36873984],
...,
[ 0.76140714, 0.7055328 , 0.60365176],
[-0.69791937, -0.21970169, 0.6287722 ],
[-0.7085722 , -0.7537329 , -1.1888791 ]],
[[ 0.05302444, 1.088168 , 1.5901968 ],
[-1.1256545 , -0.00637588, -0.476244 ],
[ 0.20962416, 0.1299456 , -0.23336564],
...,
[ 0.2610068 , 0.43658018, -0.7139866 ],
[ 0.12355808, 0.41040298, -2.1643631 ],
[ 1.0014896 , -0.39356777, -0.11626529]],
[[ 0.6359369 , 1.4638699 , -0.5852146 ],
[-1.2632309 , 0.6778309 , 1.769848 ],
[ 2.042887 , 0.10215978, 0.14071363],
...,
[-0.5496323 , 0.8159724 , -0.37511298],
[-0.58229995, 0.31804055, 0.18805313],
[ 0.19327746, -0.91908634, 0.7599561 ]],
...,
[[ 0.57062095, -1.0924951 , 1.1462951 ],
[-0.17186154, -0.6705764 , -0.28675246],
[ 1.0782579 , -0.65768707, 0.18618874],
...,
[-1.129991 , -0.9912013 , 0.34741646],
[ 0.9312622 , -1.6201673 , -0.8323024 ],
[ 1.1002243 , -0.8553034 , 0.0289506 ]],
[[ 0.44621316, 0.56181026, 0.52481025],
[-0.5058765 , -0.06046636, 0.09103815],
[ 0.8362112 , -0.01185856, -0.10881247],
...,
[-0.26027304, 0.20625828, 0.45409092],
[-0.9735298 , 0.24121265, 1.4221593 ],
[ 0.29431424, 0.27458662, 1.2655455 ]],
[[ 0.69613606, -0.60457057, -0.08587647],
[ 0.69580555, 0.7907061 , -0.61136305],
[ 0.03631109, 0.5971228 , -0.27422765],
...,
[-0.38135353, 0.556312 , -1.0376903 ],
[ 0.9262408 , -0.57603973, -0.5694517 ],
[ 0.27046308, 1.4378041 , 2.0465028 ]]], dtype=float32)>
# 取所有图片,隔行采样,隔列采样,所有通道信息,相当于在图片的高宽各缩放至原来的 50%
x[:, ::2, ::2, :]
<tf.Tensor: shape=(4, 16, 16, 3), dtype=float32, numpy=
array([[[[-1.4066522 , -1.6301475 , -0.6626721 ],
[ 1.1119522 , -1.336452 , -0.36873984],
[-1.1216803 , 1.4819081 , 0.19174473],
...,
[ 0.9616238 , -0.127504 , -0.05364741],
[-1.502996 , 1.4380702 , 0.973189 ],
[-0.69791937, -0.21970169, 0.6287722 ]],
[[ 0.6359369 , 1.4638699 , -0.5852146 ],
[ 2.042887 , 0.10215978, 0.14071363],
[-1.7389532 , 1.0412104 , -0.3376094 ],
...,
[ 0.8654285 , -1.2371039 , -0.29329225],
[-0.6470304 , 0.1699918 , -0.4201792 ],
[-0.58229995, 0.31804055, 0.18805313]],
[[ 1.514192 , 1.0639831 , -1.6242381 ],
[ 1.6208409 , 0.61728925, 1.6865913 ],
[-1.6289519 , -0.09172544, 0.41329512],
...,
[ 0.47612637, -0.5496553 , 0.19886701],
[ 0.9083845 , -0.46511444, -0.23636743],
[ 1.0820898 , 0.3515731 , 1.3956282 ]],
...,
[[-0.9186473 , 0.48401567, -0.20875901],
[-0.53344274, -0.52955556, 1.8328873 ],
[-0.7733719 , 0.29834586, 1.0668224 ],
...,
[-1.3580707 , -0.6614838 , 1.238765 ],
[ 0.4973609 , 0.01429008, -1.6891569 ],
[ 0.17739452, -0.7004075 , 1.2233291 ]],
[[-0.6449783 , -1.1857752 , -0.95025533],
[-0.04647965, 1.6509633 , 0.70183164],
[-0.5234073 , -0.40614322, -0.44905052],
...,
[ 1.1192619 , -0.35078907, -0.17611592],
[-1.5070972 , 0.0043169 , -1.2150966 ],
[-0.28248122, -0.14355616, -0.15214995]],
[[ 0.44621316, 0.56181026, 0.52481025],
[ 0.8362112 , -0.01185856, -0.10881247],
[ 0.04107153, -0.09329593, -0.30412927],
...,
[ 0.49422017, -1.3666263 , -1.2329497 ],
[-0.30040884, 0.3625095 , 0.40492785],
[-0.9735298 , 0.24121265, 1.4221593 ]]],
[[[-1.31716 , 0.8927563 , -0.19094189],
[-1.0336089 , 0.6535478 , 1.1548346 ],
[ 2.1178985 , 0.44546792, 1.4923142 ],
...,
[-1.2079732 , 0.21183792, -0.47470355],
[ 1.8716471 , 0.38424423, 0.45972088],
[-1.7981858 , -2.225117 , -0.83562917]],
[[-0.04231573, -1.0543141 , -0.07294127],
[ 0.430539 , 0.751165 , 1.1801926 ],
[-1.2387655 , 1.2330005 , -1.4386145 ],
...,
[-3.191042 , -0.1695644 , -2.7725818 ],
[-1.6863043 , -1.6077073 , 1.5912145 ],
[ 0.8138152 , 1.8297083 , 0.9749301 ]],
[[-1.1073742 , 0.30253428, -0.5618196 ],
[-1.1719611 , -1.6837698 , -0.5742327 ],
[ 0.35131663, 0.00361567, -0.7844822 ],
...,
[-1.2514129 , -0.22500975, 0.41273192],
[ 1.0914025 , -1.1279705 , 0.8975912 ],
[ 0.56316924, -0.00613224, -0.9379784 ]],
...,
[[ 0.1403622 , -0.37658456, -0.5574308 ],
[-0.88762856, -0.20102285, -0.49638802],
[-1.1747687 , 0.64913094, 0.07465671],
...,
[-0.2831462 , 1.4277062 , -0.6432949 ],
[ 0.8692745 , -0.9615009 , -0.94364834],
[ 1.0578264 , -0.9055548 , -1.3884908 ]],
[[ 1.6147898 , -0.76326054, 1.0747921 ],
[ 1.199795 , -0.90053177, -0.8892275 ],
[ 1.5152236 , 1.2341576 , -1.6124233 ],
...,
[-1.4800897 , -0.18380113, -0.88795835],
[ 1.0938817 , 0.03855901, -0.4113766 ],
[-0.34892878, -0.20459495, 0.5686236 ]],
[[-1.153347 , -1.3645796 , 1.0547861 ],
[ 0.5597972 , 1.2221434 , -0.18595591],
[-1.0038961 , -1.0488672 , -1.3085092 ],
...,
[-0.04479567, -1.3941933 , -0.2229883 ],
[-0.52149624, 0.6197523 , 0.5761725 ],
[-0.80052674, 0.13028495, 0.36886504]]],
[[[ 1.6538881 , -1.2828442 , -0.03862259],
[-0.9385108 , 2.177739 , 0.27142107],
[ 0.4192225 , 1.7302549 , -0.19198844],
...,
[ 0.57185555, -0.97126657, -1.3499206 ],
[-0.35321546, -0.88202614, -0.8655192 ],
[ 0.4681579 , -0.59527147, -0.7687865 ]],
[[-1.5069393 , -0.01386867, -0.8434442 ],
[ 0.7158936 , -0.24094783, -1.7288772 ],
[-0.41253087, 0.5482663 , 0.2615464 ],
...,
[-1.0946057 , -0.47949368, 2.71418 ],
[ 0.5390706 , -0.7334311 , -0.283474 ],
[-0.4777188 , -1.6288629 , -0.03284854]],
[[ 1.0887016 , -0.7577311 , -0.537817 ],
[-1.0302697 , -0.19131397, -0.79996824],
[ 1.3039442 , 0.31810725, 0.7159762 ],
...,
[ 0.4652573 , -0.51346284, 0.6437446 ],
[-0.8540984 , 0.85202444, 1.338621 ],
[-0.9454326 , 0.34911907, -0.5797377 ]],
...,
[[ 0.08243454, 1.1988658 , 1.0254418 ],
[ 0.63155466, 1.0079315 , 0.5422959 ],
[-0.21402298, 0.70890206, 1.554218 ],
...,
[ 0.8999925 , -0.21716611, -0.31117567],
[-0.52009517, 1.9389265 , 0.6983924 ],
[-1.1133703 , -0.21071957, -0.6773069 ]],
[[-0.93739384, -0.00981612, -0.9172169 ],
[-0.38341337, -0.1792818 , -1.0535135 ],
[-0.22962493, 0.41208678, -0.93959725],
...,
[ 0.5471708 , -0.05138006, 0.31088695],
[-1.4627103 , -1.9810112 , 0.9871064 ],
[-0.5954719 , 0.82192105, -0.61218923]],
[[ 0.50397587, 0.9386607 , 2.8985884 ],
[ 0.4871909 , 0.18244128, -0.2612286 ],
[ 2.1795106 , -0.763841 , -1.3280549 ],
...,
[-0.25995788, 0.56037647, -0.5507551 ],
[ 0.8771116 , -2.09035 , 1.101149 ],
[ 0.1810747 , -0.61945873, -1.7169424 ]]],
[[[ 0.796133 , 0.5410011 , 0.9382214 ],
[ 0.41060907, -0.20038849, 0.6409274 ],
[ 0.14664018, -0.48961687, -2.7682152 ],
...,
[ 0.5555239 , 0.07243114, -2.026096 ],
[ 0.3944784 , -1.8553514 , -0.24285497],
[-0.9552354 , -1.3100492 , 0.29391095]],
[[ 1.8699164 , -0.73711467, 1.5590305 ],
[-1.597196 , 0.881057 , 0.47084296],
[-0.27228594, -0.8634398 , 1.4645449 ],
...,
[-0.72307605, 0.15268996, -1.089991 ],
[ 1.0803598 , -1.3153291 , 0.05169139],
[ 0.520886 , 1.0219852 , 0.9212156 ]],
[[ 2.4779575 , -0.01737644, -0.5229786 ],
[-0.43438995, -0.86700696, -0.6404235 ],
[ 0.08717794, 1.0277884 , 0.10849019],
...,
[-0.1702712 , -0.98193866, 2.5673513 ],
[ 0.44088537, 1.6612606 , -1.6000991 ],
[-0.19167791, -1.3604112 , -1.2433398 ]],
...,
[[-0.5087639 , 0.8378956 , 1.2819554 ],
[ 0.649141 , -0.521055 , -0.4728266 ],
[-1.4423455 , -0.10846883, -0.5847404 ],
...,
[ 1.3345126 , 0.5776808 , 1.7416115 ],
[ 0.22405353, 0.40215632, -0.13983175],
[-0.80241174, -0.48636377, -0.5027973 ]],
[[-0.97974837, -2.1886742 , 0.04100842],
[ 0.7036939 , -0.6549178 , -0.00676515],
[ 0.6745989 , -2.3424578 , 0.15672654],
...,
[ 0.22248814, 0.01724804, -1.9109946 ],
[-2.0762608 , 0.2616049 , -0.36650392],
[-0.60376364, 1.0571078 , -0.2624517 ]],
[[-0.5025821 , -1.7865878 , -0.8327457 ],
[ 0.5313025 , 0.3924203 , -0.47739467],
[-0.7475535 , 0.2349816 , 0.19985357],
...,
[-0.06519561, -1.7647893 , -0.03842835],
[ 0.09419745, 0.5742389 , -1.2229947 ],
[ 0.78169495, 0.33737284, 0.5514 ]]]], dtype=float32)>
# 逆序读取
y = tf.range(9)
y[8:0:-1]
<tf.Tensor: shape=(8,), dtype=int32, numpy=array([8, 7, 6, 5, 4, 3, 2, 1])>
# 逆序间隔采样
y[::-2]
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([8, 6, 4, 2, 0])>
# 读取每张图片的所有通道,其中行按着逆序隔行采样,列按着逆序隔行采样:
x[:, ::-2, ::-2]
<tf.Tensor: shape=(4, 16, 16, 3), dtype=float32, numpy=
array([[[[ 2.70463079e-01, 1.43780410e+00, 2.04650283e+00],
[-3.81353527e-01, 5.56312025e-01, -1.03769028e+00],
[ 5.33211112e-01, 3.10682386e-01, 2.25513265e-01],
...,
[-1.31894922e+00, 6.72469199e-01, -4.24834251e-01],
[-3.99455070e-01, 1.13034093e+00, 9.75108087e-01],
[ 6.95805550e-01, 7.90706098e-01, -6.11363053e-01]],
[[ 1.10022426e+00, -8.55303407e-01, 2.89506000e-02],
[-1.12999105e+00, -9.91201282e-01, 3.47416461e-01],
[ 1.42449439e-01, -6.43958688e-01, -1.17649388e+00],
...,
[-6.41264737e-01, 6.58568740e-01, -2.83147812e-01],
[ 1.38702735e-01, 8.82205665e-01, 8.34312618e-01],
[-1.71861544e-01, -6.70576394e-01, -2.86752462e-01]],
[[ 1.61077237e+00, 1.47547674e+00, 1.15066981e+00],
[ 5.25802255e-01, -9.59300995e-01, 3.04066658e-01],
[-1.17115307e+00, 9.95538235e-02, 8.75156969e-02],
...,
[-8.60012919e-02, 1.82863545e+00, -4.53949630e-01],
[-2.08558694e-01, 1.69803965e+00, 5.54584622e-01],
[ 1.54844689e+00, -4.60885614e-01, 9.74037468e-01]],
...,
[[ 2.09663808e-01, 5.03294021e-02, -2.39943981e+00],
[-4.78137374e-01, -8.00408661e-01, 1.40211105e+00],
[-4.69818562e-02, 1.45516551e+00, -2.26509511e-01],
...,
[ 7.13527143e-01, -1.70955324e+00, -1.20516419e+00],
[-2.56722748e-01, -3.58597577e-01, -4.07044530e-01],
[-2.27065563e+00, -9.72444355e-01, -3.04375201e-01]],
[[ 1.24975836e+00, -2.18750909e-01, 9.84029830e-01],
[ 8.56332004e-01, -2.98688143e-01, -1.12196422e+00],
[ 2.38927782e-01, 3.77413809e-01, 1.92893334e-02],
...,
[ 5.45535684e-01, -1.14007747e+00, 1.54953802e+00],
[ 4.77895200e-01, -7.00159848e-01, 1.64433730e+00],
[ 1.99938035e+00, -1.29522908e+00, 2.33394933e+00]],
[[ 1.00148964e+00, -3.93567771e-01, -1.16265289e-01],
[ 2.61006802e-01, 4.36580181e-01, -7.13986576e-01],
[ 1.23488545e-01, -2.93933243e-01, -2.58802325e-01],
...,
[-1.27171382e-01, 6.42654061e-01, -6.45263493e-01],
[ 1.89307451e-01, -1.38285387e+00, 2.34437919e+00],
[-1.12565446e+00, -6.37588464e-03, -4.76244003e-01]]],
[[[-7.26107359e-01, 2.75735468e-01, -3.42607528e-01],
[ 2.03605223e+00, 5.66768944e-01, -2.40854323e-01],
[-7.50414968e-01, 6.54022247e-02, 1.16425622e+00],
...,
[-6.79210007e-01, 3.59585524e-01, 3.78254205e-01],
[ 3.92789185e-01, 1.85335711e-01, 1.56548560e+00],
[ 7.97852814e-01, -1.47595143e+00, -3.47982228e-01]],
[[ 1.01898842e-01, 3.73595089e-01, -2.61910290e-01],
[ 2.11189723e+00, 2.34301046e-01, -4.48722124e-01],
[ 4.13936704e-01, -4.67168331e-01, -5.99880695e-01],
...,
[ 4.64589655e-01, -4.13822114e-01, 9.78477240e-01],
[ 6.55454457e-01, 1.26392174e+00, -1.17857707e+00],
[-5.66526294e-01, -4.39170301e-01, -1.09283067e-01]],
[[ 6.91863060e-01, -2.36324593e-01, -5.28543532e-01],
[-7.22647667e-01, -2.03845069e-01, 1.39277232e+00],
[-4.04149055e-01, 1.64806932e-01, -4.09776747e-01],
...,
[-1.54772982e-01, 7.30288863e-01, -3.09380054e-01],
[ 3.57673019e-01, 4.36691731e-01, -4.12629247e-01],
[-3.86441737e-01, -2.03366137e+00, 1.29613638e+00]],
...,
[[-1.55937403e-01, -3.73228174e-03, 5.61697364e-01],
[ 5.48917413e-01, -1.29789817e+00, -1.20505917e+00],
[ 8.88521254e-01, -2.13731945e-01, 1.21868335e-01],
...,
[-3.20012212e-01, -1.07445729e+00, -2.69981891e-01],
[ 2.15791321e+00, -1.41033661e+00, 1.41868103e+00],
[ 5.08717120e-01, -1.54603338e+00, -1.61875099e-01]],
[[-4.50789124e-01, 7.74175227e-01, 5.72529621e-02],
[-5.89209199e-01, -7.64407158e-01, -1.09742761e+00],
[ 1.47474158e+00, -6.16571426e-01, -4.08598870e-01],
...,
[ 1.52155185e+00, 4.03917313e-01, 9.51666415e-01],
[-1.44588232e-01, -9.81349219e-03, 8.75104964e-01],
[-1.25682402e+00, 6.22186244e-01, -8.84938061e-01]],
[[ 1.64439654e+00, -5.87576449e-01, -2.08629251e+00],
[-3.29419851e-01, 1.00516833e-01, 1.83663404e+00],
[ 2.32498741e+00, -2.54601419e-01, 4.45122093e-01],
...,
[-1.80637136e-01, 6.99466169e-01, 1.13905704e+00],
[ 3.61730218e-01, -4.99152839e-01, -9.73379612e-01],
[ 1.16922343e+00, 1.89320967e-02, -1.03092074e+00]]],
[[[-4.27501291e-01, -9.88393247e-01, -3.69713098e-01],
[-3.40145409e-01, 9.25704300e-01, -1.12029064e+00],
[ 8.14169586e-01, 6.05460405e-01, 2.60386181e+00],
...,
[ 3.99615139e-01, 8.46545994e-01, 6.75601780e-01],
[-6.43654764e-02, 1.81750131e+00, -3.27244759e-01],
[ 1.04704909e-02, 1.21954966e+00, -5.33794045e-01]],
[[-3.67735058e-01, -4.14203286e-01, -1.61030963e-01],
[-1.63070872e-01, 3.57455999e-01, -5.22079170e-01],
[-1.51063931e+00, -5.99683881e-01, -3.53558123e-01],
...,
[ 1.12916040e+00, -4.65052277e-01, 9.40200329e-01],
[-2.14853287e-01, 1.43285513e+00, 8.45519543e-01],
[-6.44688129e-01, 7.73779929e-01, 1.00038894e-01]],
[[-1.43473673e+00, 1.14696689e-01, -7.76587903e-01],
[-6.79917216e-01, 6.04902387e-01, -1.90379530e-01],
[-2.54757714e+00, -5.84724963e-01, 1.33324730e+00],
...,
[-1.21236987e-01, -1.77558792e+00, -1.57310629e+00],
[ 2.36512259e-01, 4.44737852e-01, -2.10810041e+00],
[-8.26636851e-01, 1.05821848e+00, -7.64270127e-01]],
...,
[[ 2.52616316e-01, -6.08234167e-01, -5.97587347e-01],
[ 7.67554343e-01, -1.60571194e+00, -7.14286208e-01],
[ 6.90349340e-02, -5.88057041e-01, -6.89475596e-01],
...,
[-6.25573397e-01, -6.76500872e-02, -1.19496965e+00],
[ 3.81320029e-01, 8.77332389e-01, 2.21356135e-02],
[ 8.41787027e-04, 2.19724393e+00, -7.00479388e-01]],
[[-4.44216371e-01, 1.18845844e+00, 1.44990933e+00],
[-5.47546744e-01, 1.59971833e+00, -6.22707093e-03],
[ 5.15164100e-02, -6.15957201e-01, -3.54217738e-02],
...,
[ 2.54230648e-01, 4.45920741e-04, 3.17009151e-01],
[ 1.03450857e-01, 1.00563419e+00, 1.39238334e+00],
[ 3.25153410e-01, -1.88359126e-01, -2.31335235e+00]],
[[ 3.36636533e-03, -1.62513331e-01, -2.71629548e+00],
[ 7.72923172e-01, 1.03673649e+00, 3.02833885e-01],
[ 1.23411253e-01, -9.16424617e-02, -5.28129518e-01],
...,
[ 1.77046025e+00, 3.56886685e-01, 6.95330650e-02],
[ 1.12724043e-01, 3.19157867e-03, -1.71475136e+00],
[-1.01969326e+00, 1.11456800e+00, -2.65367359e-01]]],
[[[-1.15891248e-01, -1.41290545e-01, 7.66257346e-02],
[ 2.44068727e-01, 1.61916220e+00, 4.72523719e-01],
[ 3.12106758e-01, 1.61140931e+00, 9.27360475e-01],
...,
[-3.70257556e-01, -1.31743479e+00, -1.36546338e+00],
[ 2.73258053e-02, -2.65831351e-01, -4.82090898e-02],
[-5.41559577e-01, -3.85325640e-01, -2.57567883e+00]],
[[ 1.45899221e-01, -5.85888699e-02, 1.57103166e-01],
[-2.42469311e+00, -9.04979110e-01, -1.10454786e+00],
[ 1.96863651e+00, -4.83983159e-02, 1.13655746e+00],
...,
[ 2.83160806e+00, -8.30558240e-01, -1.90577090e+00],
[-1.14942360e+00, 6.83438405e-02, -1.07411752e-02],
[-3.82427797e-02, 2.01849031e+00, -2.47480541e-01]],
[[-2.42553130e-01, 2.61052269e-02, 1.20750403e+00],
[ 2.47576356e-01, 1.12260306e+00, 3.69094312e-01],
[ 1.61395982e-01, -8.77945542e-01, -9.10170496e-01],
...,
[-1.51026994e-01, -4.52619284e-01, -7.39579916e-01],
[ 8.30707669e-01, 7.73450017e-01, -2.06688240e-01],
[-1.13437545e+00, -1.92892087e+00, 2.31980339e-01]],
...,
[[ 1.37320280e+00, -6.70619905e-01, 5.35866261e-01],
[-1.09410036e+00, 6.84671581e-01, -1.73805761e+00],
[ 1.18337500e+00, -2.54578614e+00, -1.66087568e-01],
...,
[ 3.93906295e-01, -6.10780060e-01, 1.77272546e+00],
[-2.91264564e-01, 7.29625762e-01, 6.12191856e-01],
[-1.75670922e-01, 1.71077716e+00, 2.27649868e-01]],
[[ 2.04488921e+00, 9.83502865e-01, 5.24477720e-01],
[-1.39695513e+00, 1.16861604e-01, -9.29624975e-01],
[ 6.60978496e-01, -4.00158763e-01, 9.86123011e-02],
...,
[ 2.79058069e-01, 3.78170907e-01, 1.68269187e-01],
[-9.26116526e-01, 3.31941932e-01, -1.72549598e-02],
[-9.71043169e-01, -2.22090840e+00, -2.24909329e+00]],
[[ 8.13005805e-01, -2.22718850e-01, -1.81065726e+00],
[ 1.96232474e+00, 8.83163810e-01, 1.72268629e-01],
[ 3.68520975e-01, 1.91279411e+00, 1.43520033e+00],
...,
[ 1.23894322e+00, 9.32551265e-01, -2.03992203e-01],
[-7.83080101e-01, 1.52017325e-01, 2.69884020e-01],
[-1.22491968e+00, -2.73782104e-01, -1.00873554e+00]]]],
dtype=float32)>
# 读取第 1-2 张图片的 G/B 通道数据:
x[0:2,...,1:]
<tf.Tensor: shape=(2, 32, 32, 2), dtype=float32, numpy=
array([[[[-1.6301475 , -0.6626721 ],
[ 1.1874142 , -1.0525869 ],
[-1.336452 , -0.36873984],
...,
[ 0.7055328 , 0.60365176],
[-0.21970169, 0.6287722 ],
[-0.7537329 , -1.1888791 ]],
[[ 1.088168 , 1.5901968 ],
[-0.00637588, -0.476244 ],
[ 0.1299456 , -0.23336564],
...,
[ 0.43658018, -0.7139866 ],
[ 0.41040298, -2.1643631 ],
[-0.39356777, -0.11626529]],
[[ 1.4638699 , -0.5852146 ],
[ 0.6778309 , 1.769848 ],
[ 0.10215978, 0.14071363],
...,
[ 0.8159724 , -0.37511298],
[ 0.31804055, 0.18805313],
[-0.91908634, 0.7599561 ]],
...,
[[-1.0924951 , 1.1462951 ],
[-0.6705764 , -0.28675246],
[-0.65768707, 0.18618874],
...,
[-0.9912013 , 0.34741646],
[-1.6201673 , -0.8323024 ],
[-0.8553034 , 0.0289506 ]],
[[ 0.56181026, 0.52481025],
[-0.06046636, 0.09103815],
[-0.01185856, -0.10881247],
...,
[ 0.20625828, 0.45409092],
[ 0.24121265, 1.4221593 ],
[ 0.27458662, 1.2655455 ]],
[[-0.60457057, -0.08587647],
[ 0.7907061 , -0.61136305],
[ 0.5971228 , -0.27422765],
...,
[ 0.556312 , -1.0376903 ],
[-0.57603973, -0.5694517 ],
[ 1.4378041 , 2.0465028 ]]],
[[[ 0.8927563 , -0.19094189],
[-0.77084136, -1.0931517 ],
[ 0.6535478 , 1.1548346 ],
...,
[-0.11049574, -0.7938342 ],
[-2.225117 , -0.83562917],
[-0.67847425, 0.9224359 ]],
[[ 1.6312623 , 0.12778619],
[ 0.0189321 , -1.0309207 ],
[-0.38072154, -1.1224715 ],
...,
[ 0.10051683, 1.836634 ],
[-1.4565924 , 0.70586956],
[-0.58757645, -2.0862925 ]],
[[-1.0543141 , -0.07294127],
[ 1.7232455 , 0.34619242],
[ 0.751165 , 1.1801926 ],
...,
[ 0.8397924 , 0.5397981 ],
[ 1.8297083 , 0.9749301 ],
[ 0.0285456 , -1.0811192 ]],
...,
[[-0.59172046, -0.25115427],
[-0.4391703 , -0.10928307],
[-0.58571684, 1.4293606 ],
...,
[ 0.23430105, -0.44872212],
[-1.6758326 , -0.7835137 ],
[ 0.3735951 , -0.2619103 ]],
[[-1.3645796 , 1.0547861 ],
[-1.3566399 , 0.2504541 ],
[ 1.2221434 , -0.18595591],
...,
[ 2.0617461 , 0.03640852],
[ 0.13028495, 0.36886504],
[ 1.8119245 , 0.2898678 ]],
[[-0.4821952 , -0.01792378],
[-1.4759514 , -0.34798223],
[-0.295613 , 0.3123493 ],
...,
[ 0.56676894, -0.24085432],
[ 0.24397703, 1.2036182 ],
[ 0.27573547, -0.34260753]]]], dtype=float32)>
# 读取最后 2 张图片:
x[2:,...]
<tf.Tensor: shape=(2, 32, 32, 3), dtype=float32, numpy=
array([[[[ 1.6538881 , -1.2828442 , -0.03862259],
[ 0.6386152 , 1.1786485 , -0.4696664 ],
[-0.9385108 , 2.177739 , 0.27142107],
...,
[ 2.6009357 , -0.3140693 , 1.3819468 ],
[ 0.4681579 , -0.59527147, -0.7687865 ],
[ 0.7524163 , -0.85284793, -1.2497554 ]],
[[-1.0554148 , -0.09115129, -0.75751054],
[-1.0196933 , 1.114568 , -0.26536736],
[-0.33567974, -1.3310446 , -0.7331455 ],
...,
[ 0.7729232 , 1.0367365 , 0.30283388],
[ 2.2981248 , 1.1215141 , -0.05669821],
[ 0.00336637, -0.16251333, -2.7162955 ]],
[[-1.5069393 , -0.01386867, -0.8434442 ],
[-0.75924855, -2.37886 , 0.96467763],
[ 0.7158936 , -0.24094783, -1.7288772 ],
...,
[-0.29446223, -0.2361764 , -1.0333688 ],
[-0.4777188 , -1.6288629 , -0.03284854],
[ 1.4216906 , -1.92074 , -0.22618869]],
...,
[[-0.4792214 , 0.3826494 , 1.2373917 ],
[-0.6446881 , 0.7737799 , 0.10003889],
[ 0.6713453 , -2.136748 , -0.4338728 ],
...,
[-0.16307087, 0.357456 , -0.52207917],
[ 1.3242189 , -1.0513958 , 1.1873055 ],
[-0.36773506, -0.4142033 , -0.16103096]],
[[ 0.50397587, 0.9386607 , 2.8985884 ],
[-0.8921362 , -0.06021392, -0.19666769],
[ 0.4871909 , 0.18244128, -0.2612286 ],
...,
[-1.8799183 , 0.38841182, 3.0188472 ],
[ 0.1810747 , -0.61945873, -1.7169424 ],
[ 1.3666738 , 0.32145166, 0.8407051 ]],
[[ 0.06317475, 0.15326717, 0.1563328 ],
[ 0.01047049, 1.2195497 , -0.53379405],
[-0.50812846, 0.96157503, -0.20100513],
...,
[-0.3401454 , 0.9257043 , -1.1202906 ],
[ 0.96981984, -0.37571394, 0.90335006],
[-0.4275013 , -0.98839325, -0.3697131 ]]],
[[[ 0.796133 , 0.5410011 , 0.9382214 ],
[ 1.627531 , 1.4550391 , -1.0155681 ],
[ 0.41060907, -0.20038849, 0.6409274 ],
...,
[-1.4245912 , -1.0895957 , 0.6149582 ],
[-0.9552354 , -1.3100492 , 0.29391095],
[ 0.2587053 , 1.6795714 , 0.97186834]],
[[-1.6926944 , -1.088903 , -1.6040022 ],
[-1.2249197 , -0.2737821 , -1.0087355 ],
[ 0.10098427, 0.02635551, -1.5736842 ],
...,
[ 1.9623247 , 0.8831638 , 0.17226863],
[ 0.6759971 , -0.11783656, 0.9076593 ],
[ 0.8130058 , -0.22271885, -1.8106573 ]],
[[ 1.8699164 , -0.73711467, 1.5590305 ],
[ 0.00324396, 1.3924849 , -1.1840723 ],
[-1.597196 , 0.881057 , 0.47084296],
...,
[-1.7410468 , 0.35400176, 0.835881 ],
[ 0.520886 , 1.0219852 , 0.9212156 ],
[ 1.3913602 , -0.14510114, 0.38555175]],
...,
[[ 0.33566687, -2.0631125 , 0.6803144 ],
[-0.03824278, 2.0184903 , -0.24748054],
[-0.38573766, -0.22726655, 0.38307485],
...,
[-2.424693 , -0.9049791 , -1.1045479 ],
[-0.33129185, 1.0232396 , -1.6151114 ],
[ 0.14589922, -0.05858887, 0.15710317]],
[[-0.5025821 , -1.7865878 , -0.8327457 ],
[-0.539799 , -0.25900158, 0.31972417],
[ 0.5313025 , 0.3924203 , -0.47739467],
...,
[-0.18798655, -2.2831001 , 0.12444557],
[ 0.78169495, 0.33737284, 0.5514 ],
[-0.54365623, 0.56584615, 0.3617644 ]],
[[ 0.85397774, -1.1071197 , 1.4945678 ],
[-0.5415596 , -0.38532564, -2.5756788 ],
[ 1.2383386 , 1.6852653 , -0.44956163],
...,
[ 0.24406873, 1.6191622 , 0.47252372],
[-0.69918287, 1.6312635 , 1.0421402 ],
[-0.11589125, -0.14129055, 0.07662573]]]], dtype=float32)>
# 读取 R/G 通道数据:
x[...,:2]
<tf.Tensor: shape=(4, 32, 32, 2), dtype=float32, numpy=
array([[[[-1.4066522 , -1.6301475 ],
[ 0.4622193 , 1.1874142 ],
[ 1.1119522 , -1.336452 ],
...,
[ 0.76140714, 0.7055328 ],
[-0.69791937, -0.21970169],
[-0.7085722 , -0.7537329 ]],
[[ 0.05302444, 1.088168 ],
[-1.1256545 , -0.00637588],
[ 0.20962416, 0.1299456 ],
...,
[ 0.2610068 , 0.43658018],
[ 0.12355808, 0.41040298],
[ 1.0014896 , -0.39356777]],
[[ 0.6359369 , 1.4638699 ],
[-1.2632309 , 0.6778309 ],
[ 2.042887 , 0.10215978],
...,
[-0.5496323 , 0.8159724 ],
[-0.58229995, 0.31804055],
[ 0.19327746, -0.91908634]],
...,
[[ 0.57062095, -1.0924951 ],
[-0.17186154, -0.6705764 ],
[ 1.0782579 , -0.65768707],
...,
[-1.129991 , -0.9912013 ],
[ 0.9312622 , -1.6201673 ],
[ 1.1002243 , -0.8553034 ]],
[[ 0.44621316, 0.56181026],
[-0.5058765 , -0.06046636],
[ 0.8362112 , -0.01185856],
...,
[-0.26027304, 0.20625828],
[-0.9735298 , 0.24121265],
[ 0.29431424, 0.27458662]],
[[ 0.69613606, -0.60457057],
[ 0.69580555, 0.7907061 ],
[ 0.03631109, 0.5971228 ],
...,
[-0.38135353, 0.556312 ],
[ 0.9262408 , -0.57603973],
[ 0.27046308, 1.4378041 ]]],
[[[-1.31716 , 0.8927563 ],
[-0.733248 , -0.77084136],
[-1.0336089 , 0.6535478 ],
...,
[ 0.30430856, -0.11049574],
[-1.7981858 , -2.225117 ],
[ 0.93750566, -0.67847425]],
[[ 2.499113 , 1.6312623 ],
[ 1.1692234 , 0.0189321 ],
[-0.62398046, -0.38072154],
...,
[-0.32941985, 0.10051683],
[-0.7381755 , -1.4565924 ],
[ 1.6443965 , -0.58757645]],
[[-0.04231573, -1.0543141 ],
[-1.3289794 , 1.7232455 ],
[ 0.430539 , 0.751165 ],
...,
[ 1.3924987 , 0.8397924 ],
[ 0.8138152 , 1.8297083 ],
[-0.52527463, 0.0285456 ]],
...,
[[ 1.3123864 , -0.59172046],
[-0.5665263 , -0.4391703 ],
[ 0.07010391, -0.58571684],
...,
[ 2.1118972 , 0.23430105],
[ 0.32074732, -1.6758326 ],
[ 0.10189884, 0.3735951 ]],
[[-1.153347 , -1.3645796 ],
[-0.60928154, -1.3566399 ],
[ 0.5597972 , 1.2221434 ],
...,
[ 0.39691612, 2.0617461 ],
[-0.80052674, 0.13028495],
[ 0.53274167, 1.8119245 ]],
[[-0.42815545, -0.4821952 ],
[ 0.7978528 , -1.4759514 ],
[ 1.7242622 , -0.295613 ],
...,
[ 2.0360522 , 0.56676894],
[-1.8507787 , 0.24397703],
[-0.72610736, 0.27573547]]],
[[[ 1.6538881 , -1.2828442 ],
[ 0.6386152 , 1.1786485 ],
[-0.9385108 , 2.177739 ],
...,
[ 2.6009357 , -0.3140693 ],
[ 0.4681579 , -0.59527147],
[ 0.7524163 , -0.85284793]],
[[-1.0554148 , -0.09115129],
[-1.0196933 , 1.114568 ],
[-0.33567974, -1.3310446 ],
...,
[ 0.7729232 , 1.0367365 ],
[ 2.2981248 , 1.1215141 ],
[ 0.00336637, -0.16251333]],
[[-1.5069393 , -0.01386867],
[-0.75924855, -2.37886 ],
[ 0.7158936 , -0.24094783],
...,
[-0.29446223, -0.2361764 ],
[-0.4777188 , -1.6288629 ],
[ 1.4216906 , -1.92074 ]],
...,
[[-0.4792214 , 0.3826494 ],
[-0.6446881 , 0.7737799 ],
[ 0.6713453 , -2.136748 ],
...,
[-0.16307087, 0.357456 ],
[ 1.3242189 , -1.0513958 ],
[-0.36773506, -0.4142033 ]],
[[ 0.50397587, 0.9386607 ],
[-0.8921362 , -0.06021392],
[ 0.4871909 , 0.18244128],
...,
[-1.8799183 , 0.38841182],
[ 0.1810747 , -0.61945873],
[ 1.3666738 , 0.32145166]],
[[ 0.06317475, 0.15326717],
[ 0.01047049, 1.2195497 ],
[-0.50812846, 0.96157503],
...,
[-0.3401454 , 0.9257043 ],
[ 0.96981984, -0.37571394],
[-0.4275013 , -0.98839325]]],
[[[ 0.796133 , 0.5410011 ],
[ 1.627531 , 1.4550391 ],
[ 0.41060907, -0.20038849],
...,
[-1.4245912 , -1.0895957 ],
[-0.9552354 , -1.3100492 ],
[ 0.2587053 , 1.6795714 ]],
[[-1.6926944 , -1.088903 ],
[-1.2249197 , -0.2737821 ],
[ 0.10098427, 0.02635551],
...,
[ 1.9623247 , 0.8831638 ],
[ 0.6759971 , -0.11783656],
[ 0.8130058 , -0.22271885]],
[[ 1.8699164 , -0.73711467],
[ 0.00324396, 1.3924849 ],
[-1.597196 , 0.881057 ],
...,
[-1.7410468 , 0.35400176],
[ 0.520886 , 1.0219852 ],
[ 1.3913602 , -0.14510114]],
...,
[[ 0.33566687, -2.0631125 ],
[-0.03824278, 2.0184903 ],
[-0.38573766, -0.22726655],
...,
[-2.424693 , -0.9049791 ],
[-0.33129185, 1.0232396 ],
[ 0.14589922, -0.05858887]],
[[-0.5025821 , -1.7865878 ],
[-0.539799 , -0.25900158],
[ 0.5313025 , 0.3924203 ],
...,
[-0.18798655, -2.2831001 ],
[ 0.78169495, 0.33737284],
[-0.54365623, 0.56584615]],
[[ 0.85397774, -1.1071197 ],
[-0.5415596 , -0.38532564],
[ 1.2383386 , 1.6852653 ],
...,
[ 0.24406873, 1.6191622 ],
[-0.69918287, 1.6312635 ],
[-0.11589125, -0.14129055]]]], dtype=float32)>
五、维度变换
基本的维度变换包含了改变视图 reshape,插入新维度 expand_dims,删除维度
squeeze,交换维度 transpose,复制数据 tile 等
# 1.改变视图
x = tf.range(96)
x
x = tf.reshape(x, [2,4,4,3])
x
<tf.Tensor: shape=(96,), dtype=int32, numpy=
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])>
<tf.Tensor: shape=(2, 4, 4, 3), dtype=int32, numpy=
array([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23]],
[[24, 25, 26],
[27, 28, 29],
[30, 31, 32],
[33, 34, 35]],
[[36, 37, 38],
[39, 40, 41],
[42, 43, 44],
[45, 46, 47]]],
[[[48, 49, 50],
[51, 52, 53],
[54, 55, 56],
[57, 58, 59]],
[[60, 61, 62],
[63, 64, 65],
[66, 67, 68],
[69, 70, 71]],
[[72, 73, 74],
[75, 76, 77],
[78, 79, 80],
[81, 82, 83]],
[[84, 85, 86],
[87, 88, 89],
[90, 91, 92],
[93, 94, 95]]]])>
# 可以通过张量的 ndim 和 shape 成员属性获得张量的维度数和形状:
x.ndim
x.shape
4
TensorShape([2, 4, 4, 3])
# 2.增删维度:只能增删长度为1的维度
x = tf.random.uniform([8,8], maxval=10, dtype=tf.int32)
print('原始维度;', x.shape)
# 增加维度
x = tf.expand_dims(x, axis=2)
print('在第2个索引前增加维度;', x.shape)
x = tf.expand_dims(x, axis=-3)
print('在倒数第3个索引前增加维度;', x.shape)
# 删除维度
x = tf.squeeze(x, axis=-1)
print('删除最后一个维度;', x.shape)
# 不指定维度参数axis时,默认删除所有长度为1的维度
x = tf.squeeze(x)
print('删除所有长度为1的维度;', x.shape)
原始维度; (8, 8)
在第2个索引前增加维度; (8, 8, 1)
在倒数第3个索引前增加维度; (8, 1, 8, 1)
删除最后一个维度; (8, 1, 8)
删除所有长度为1的维度; (8, 8)
# 3.交换维度:会改变张量存储顺序
x = tf.random.uniform([2,8,3,5], maxval=10, dtype=tf.int32)
x.shape
# perm表示新维度的索引顺序
tf.transpose(x, perm=[0, 3, 1, 2])
TensorShape([2, 8, 3, 5])
<tf.Tensor: shape=(2, 5, 8, 3), dtype=int32, numpy=
array([[[[7, 8, 9],
[2, 9, 0],
[0, 6, 2],
[2, 7, 7],
[8, 5, 1],
[8, 6, 0],
[7, 9, 0],
[2, 3, 2]],
[[3, 0, 9],
[3, 2, 9],
[8, 2, 4],
[1, 5, 7],
[1, 0, 6],
[2, 5, 8],
[7, 2, 8],
[6, 5, 8]],
[[4, 0, 2],
[1, 8, 2],
[8, 3, 6],
[9, 5, 2],
[2, 6, 5],
[9, 4, 3],
[6, 1, 0],
[7, 8, 4]],
[[2, 2, 0],
[2, 6, 5],
[3, 4, 8],
[0, 7, 2],
[1, 6, 3],
[7, 1, 7],
[5, 8, 7],
[6, 4, 9]],
[[0, 3, 1],
[6, 2, 9],
[6, 8, 5],
[1, 7, 2],
[2, 5, 0],
[4, 2, 1],
[9, 7, 3],
[0, 8, 4]]],
[[[4, 9, 2],
[9, 7, 0],
[4, 9, 4],
[6, 0, 9],
[5, 9, 8],
[5, 6, 7],
[7, 5, 7],
[7, 0, 9]],
[[4, 3, 2],
[1, 0, 9],
[0, 6, 7],
[5, 0, 0],
[0, 6, 5],
[4, 7, 7],
[8, 5, 0],
[1, 2, 6]],
[[1, 8, 9],
[1, 3, 1],
[6, 3, 6],
[3, 0, 7],
[9, 6, 8],
[5, 9, 2],
[9, 1, 6],
[7, 7, 6]],
[[4, 9, 2],
[7, 4, 0],
[2, 7, 0],
[6, 3, 1],
[4, 0, 8],
[7, 6, 6],
[0, 5, 5],
[9, 1, 5]],
[[6, 0, 8],
[8, 7, 4],
[9, 0, 5],
[6, 3, 6],
[5, 9, 4],
[2, 2, 4],
[2, 5, 7],
[2, 7, 0]]]])>
# 4.数据复制:tf.tile(x, multiples), multiples 分别指定了每个维度上面的复制倍数
# 例1
x = tf.range(6)
x = tf.reshape(x, [2, 3])
tf.tile(x, multiples=[1, 2])
tf.tile(x, multiples=[2, 1])
# 例2
y = tf.constant([1,2])
y.shape
y = tf.expand_dims(y, axis=0)
y.shape
tf.tile(y, multiples=[2, 1])
<tf.Tensor: shape=(2, 6), dtype=int32, numpy=
array([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]])>
<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])>
TensorShape([2])
TensorShape([1, 2])
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[1, 2]])>
六、Broadcasting
# Broadcasting 机制都能通过优化手段避免实际复制数据而完成逻辑运算
A = tf.random.normal([5, 1])
A
tf.broadcast_to(A, [2, 5, 5, 4])
# 基本运算自动调用broadcast
tf.constant([[0,1,2,3]])+tf.constant([[0],[1],[2]])
<tf.Tensor: shape=(5, 1), dtype=float32, numpy=
array([[ 1.546976 ],
[ 1.080959 ],
[-1.1580259 ],
[-0.4453214 ],
[ 0.27686113]], dtype=float32)>
<tf.Tensor: shape=(2, 5, 5, 4), dtype=float32, numpy=
array([[[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]]],
[[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]],
[[ 1.546976 , 1.546976 , 1.546976 , 1.546976 ],
[ 1.080959 , 1.080959 , 1.080959 , 1.080959 ],
[-1.1580259 , -1.1580259 , -1.1580259 , -1.1580259 ],
[-0.4453214 , -0.4453214 , -0.4453214 , -0.4453214 ],
[ 0.27686113, 0.27686113, 0.27686113, 0.27686113]]]],
dtype=float32)>
<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5]])>
七、数学运算
a = tf.range(1,5)
b = tf.constant(2)
# 加
tf.add(a, b)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([3, 4, 5, 6])>
# 减
tf.subtract(a, b)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([-1, 0, 1, 2])>
# 乘
tf.multiply(a, b)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 4, 6, 8])>
# 除
tf.divide(a, b)
<tf.Tensor: shape=(4,), dtype=float64, numpy=array([0.5, 1. , 1.5, 2. ])>
# 整除//
a//b
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 1, 2])>
# 余除%
a%b
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 0, 1, 0])>
# 开方运算(浮点数)
b = tf.cast(a, tf.float32)
tf.pow(b, 0.5)
b**0.5
# 平方根运算(浮点数)
tf.sqrt(b)
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([1. , 1.4142135, 1.7320508, 2. ], dtype=float32)>
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([1. , 1.4142135, 1.7320508, 2. ], dtype=float32)>
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.99999994, 1.4142134 , 1.7320508 , 1.9999999 ], dtype=float32)>
# 平方运算
tf.square(a)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 1, 4, 9, 16])>
# 指数运算
tf.pow(a, 3)
a**3
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 1, 8, 27, 64])>
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 1, 8, 27, 64])>
# 自然指数(浮点数)
tf.exp(1.)
# 自然对数(浮点数)
tf.math.log(3.)
# 计算以10为底数的对数
x = tf.constant([1., 2.])
tf.math.log(x)/tf.math.log(10.)
<tf.Tensor: shape=(), dtype=float32, numpy=2.7182817>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0986123>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0. , 0.30102998], dtype=float32)>
# 矩阵相乘:@或者tf.matmul()
# 张量维度大于2时,对最后两个维度进行矩阵相乘
a = tf.random.uniform([2,2,5,3], maxval=2, dtype=tf.int32)
b = tf.random.uniform([2,2,3,1], maxval=3, dtype=tf.int32)
a@b
# 矩阵相乘支持broadcast机制
a = tf.random.normal([4,7,32])
b = tf.random.normal([32,3])
tf.matmul(a,b)
<tf.Tensor: shape=(4, 7, 3), dtype=float32, numpy=
array([[[ -1.2468381 , -2.0166807 , 7.9159975 ],
[ 0.13711262, 6.2233095 , -11.160522 ],
[ 8.794029 , -1.8568348 , 2.6700351 ],
[ 6.054277 , -3.0855184 , 4.326208 ],
[ 5.6040177 , 2.3851852 , 4.5705433 ],
[ 11.296453 , -11.705642 , -3.8169818 ],
[ 1.0899665 , 4.670338 , -1.3397212 ]],
[[ -3.009362 , 3.5301056 , 11.140211 ],
[ -7.292059 , -2.47621 , 3.587782 ],
[ 1.6744664 , -2.3763402 , -7.3275967 ],
[ 2.894739 , 6.116 , 2.2546031 ],
[ -0.32890558, -1.567178 , -0.5077249 ],
[ 7.892525 , 0.49388677, -3.1761234 ],
[ 2.5883071 , -2.1991527 , 2.1398368 ]],
[[ -8.399379 , 2.7203753 , 7.176408 ],
[ 0.4008199 , -4.5117955 , -0.8481548 ],
[ 1.8558321 , 4.4049177 , -0.5501193 ],
[ -0.44232368, -1.6504501 , -16.607563 ],
[-12.964595 , 0.9423628 , -1.4471778 ],
[ 1.2180306 , 7.1369076 , -3.0553536 ],
[-10.172453 , 0.11261189, 6.4622445 ]],
[[-11.501997 , 3.9855576 , 6.041157 ],
[ -3.1693828 , 5.0808306 , -5.7351456 ],
[ -5.2635756 , -3.1858814 , -2.893431 ],
[ -2.303886 , -0.47508943, 2.8748848 ],
[ -4.2150383 , -0.06244525, 0.10371459],
[ 6.105155 , -10.30731 , 1.856529 ],
[ -4.4830313 , 1.4824102 , 1.9637489 ]]], dtype=float32)>
八、前向传播实战:完成三层神经网络的实现
# o𝑢𝑡 = 𝑟𝑒𝑙𝑢{𝑟𝑒𝑙𝑢{𝑟𝑒𝑙𝑢[𝑋@𝑊1 + 𝑏1]@𝑊2 + 𝑏2}@𝑊 + 𝑏}
# 首先创建每层的W、b张量
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
# 在前向计算时,首先将 shape 为[𝑏, 28,28]的输入数据 Reshape 为[𝑏, 784]
x = tf.random.normal([100, 28, 28])
x = tf.reshape(x, [-1, 784])
# 样本标签
y = tf.random.uniform([100], minval=0, maxval=10, dtype=tf.int32)
y_onehot = tf.one_hot(y, depth=10)
# 保存计算图信息,方便方向求导
with tf.GradientTape() as tape:
# 前向传播·
h1 = x@w1 + b1
h1 = tf.nn.relu(h1)
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
out = h2@w3 + b3
# 计算损失
loss = tf.square(y_onehot - out)
loss = tf.reduce_mean(loss)
# # tape.gradient()函数求得网络参数到梯度信息
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
# 来更新网络参数
# assign_sub()将原地(In-place)减去给定的参数值,实现参数的自我更新操作
lr = tf.constant(0.001)
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5])
侵删