Tensorflow(一) —— 主要的数据类型
1 主要的数据类型
- list
- np.array 同类型数据运算,没有GPU支持
- tf.Tensor
2 Tensor的种类
- (1) Scalar: 1.1 标量
- (2) Vector: [1.1] [1.1,2.2,…]
- (3) matrix: [[1.1,2.2,],[3.3,4.4]]
- (4) Tensor: ndim > 2
- (5) 一般tensor指以上所有数据
3 基本数据类型
- (1) int float double
- (2) bool
- (3) string
4 创建数据类型实例
4.1 创建 int类型
a = tf.constant(1)
print("a:",a)
4.2 创建float类型
b = tf.constant(2.)
print("b:",b)
4.3 类型使用不当
try:
c = tf.constant(2.5,dtype = tf.int32)
except Exception as error:
print(error)
4.4 创建布尔类型
d = tf.constant([True,False])
print("d:",d)
4.5 创建字符串类型
e = tf.constant("Welcome to TensorFlow!")
print("e:",e)
5 Tensor常见属性
5.1 device属性
with tf.device("cpu"):
a = tf.constant(5)
with tf.device("gpu"):
b = tf.constant(6)
print("a",a.device)
print("b",b.device)
5.2 CPU和GPU转移
aa = a.gpu()
print("aa:",aa.device)
bb = b.cpu()
print("bb:",bb.device)
"""
GPU上的tensor只能在GPU上操作
CPU上的Tenor只能在CPU上操作
"""
try:
print(bb+aa)
except Exception as error:
print(error)
5.3 Tensor转化为numpy:
print(aa.numpy(),type(aa.numpy()))
5.4 形状和维度
print(aa.shape)
"""
标量的形状为0
"""
t1 = tf.ones([5,6])
print("t1:",t1.shape)
print("t1:",t1.ndim)
print("t1:",tf.rank(t1)) # 同时返回shape和ndim
5.5 判断一个对象是否为Tensor
print("t1:",tf.is_tensor(t1))
print("t1:",isinstance(t1,tf.Tensor))
5.6 查看数据类型
print("t1:",t1.dtype)
5.7 数据类型判断
print("t1:",t1.dtype == tf.int32)
6 数据类型之间的相互转换
6.1 numpy转tensorflow
a = np.arange(10)
print("a:",a.dtype,a)
b = tf.convert_to_tensor(a)
print("b:",b.dtype,b)
c = tf.convert_to_tensor(a,dtype = tf.int16)
print("c:",c.dtype,c)
6.2 tensor类型间互转
t1 = tf.constant(5)
print("t1:",t1)
t11 = tf.cast(t1,tf.float32)
print("t11:",t11)
6.3 整型和布尔类型之间相互转换
t2 = tf.constant([0,1])
print("t2:",t2)
t21 = tf.cast(t2,tf.bool)
print("t21:",t21)
t22 = tf.cast(t21,tf.int32)
print("t22:",t22)
6.4 特殊的数据类型 tf.Variable
t3 = tf.range(5)
print("t3:",t3)
t31 = tf.Variable(t3)
print("t31:",t31)
print("t31:",t31.name)
print("t31:",t31.trainable) # 判断是否为变量,即是否可以求导
print('t31:',tf.is_tensor(t31))
print("t31:",isinstance(t31,tf.Tensor))
"""
所以用is_tensor方法更好更为准确
"""
6.6 转换为numpy数据类型
t4 = tf.range(1,100,5)
print("t4:",t4)
t41 = t4.numpy()
print("t41:",t41.dtype,t41)
本文为参考龙龙老师的“深度学习与TensorFlow 2入门实战“课程书写的学习笔记
by CyrusMay 2022 04 06