tf.strided_slice

 tf.strided_slice  |  TensorFlow Core v2.8.0Extracts a strided slice of a tensor (generalized Python array indexing).https://tensorflow.google.cn/api_docs/python/tf/strided_slice

import tensorflow as tf
import numpy as np
import warnings
warnings.filterwarnings(action="ignore")


input = np.array(
    [[["000", "001", "002", "003"], 
      ["010", "011", "012", "013"], 
      ["020", "021", "022", "023"]],
    [["100", "101", "102", "103"], 
      ["110", "111", "112", "113"], 
      ["120", "121", "122", "123"]]])   # 2x3x4


# if strides is None: strides = np.ones_like(begin), 设为全1
# strides can be any integer but 0


# ------------- 正向切片
# begin[i]/end[i] 取值范围[-input_shape[i], input_shape[i]-1], 超出边界的按照边界值处理
# strides[i]>0, begin[i]<end[i]
# 切片取值区间:[begin[i], end[i]),每间隔 (strides[i]-1) 取一个点
begin = np.array([0, 0, 0])   
end = np.array([2, 3, -1])
strides = np.array([1, 2, 1])
# (2,2,3)


# ------------- 反向切片
# strides[i]<0,begin[i]>end[i]
# 切片取值区间:(end[i], begin[i]], 倒着取
begin = np.array([1, 0, 3])
end = np.array([-4, 3, 0])
strides = np.array([-1, 1, -1])
# (2,3,3)


# end是开区间,-1表示‘3’轴上的元素,因此轴上最后一个点取不到,结果不带'3'
begin = np.array([0, 0, 0])
end = np.array([-1, -1, -1])
strides = np.array([1, 1, 1])
# (1,2,3)


# end开区间,-1表示‘0’轴上的元素,因此取不到带'0'的元素
begin = np.array([1, -1, -1])
end = np.array([0, 0, 0])
strides = np.array([-1, -1, -1])
# (1,2,3)


# mask =default: (1,2,2)
# 掩码转换成二进制格式,从末位取,倒叙,第i个值为1,则表示begin[i]/end[i]是否无效
# begin_mask=-3[101]: (2,2,3)
# begin_mask=-2[011]: (1,2,3)
# begin_mask= 1[110]: (2,2,2)
# begin_mask= 2[010]: (1,2,2)
begin = np.array([1, 0, 1])
end = np.array([2, 2, 3])
strides = np.array([1, 1, 1])


# new_axis_mask 掩码升维,default:(1,2,3),插入的索引有效范围[0,1,2],越界不处理
# new_axis_mask=1(100): (1,2,3,4), begin[0:],end[0:],stride[0:] ignored
# new_axis_mask=2(010): (1,1,3,4), begin[1:],end[1:],stride[1:] ignored
# new_axis_mask=4(001): (1,2,1,4), begin[2:],end[2:],stride[2:] ignored
# new_axis_mask=8(0001):(1,2,3), begin[3:],end[3:],stride[3:] ignored,插入的索引越界
# new_axis_mask=13(1011):(1,2,1,3,4), 二进制末位索引越界,不处理
# new_axis_mask=23(11101):(1,1,1,2,3,4) 二进制末两位索引越界,不处理
begin = np.array([0, 0, 0])
end = np.array([1, 2, 3])
strides = np.array([1, 1, 1])


# shrink_axis_mask 掩码降维,default:(1,2,3),索引越界不处理
# shrink_axis_mask=1(100): (2,3)
# shrink_axis_mask=2(010): (1,3)
# shrink_axis_mask=4(001): (1,2)
# shrink_axis_mask=3(110): (3,)
# shrink_axis_mask=7(111): scaler, '000'
# shrink_axis_mask=11(1101): (3,), 二进制末位索引越界,不处理
begin = np.array([0, 0, 0])
end = np.array([1, 2, 3])
strides = np.array([1, 1, 1])


# ellipsis_mask: 省略掩码,表示二进制位数为1的index,out_shape[index]=":"
# 例:A_shape=[2,3,4], A[1:2,:,3:4].shape =[1,3,1], 

res = tf.strided_slice(
    input_=input,
    begin=begin,
    end=end,
    strides=strides,
    begin_mask=0,
    end_mask=0,
    ellipsis_mask=0,
    new_axis_mask=4,
    shrink_axis_mask=0,
    var=None,
    name=None
)

with tf.Session() as sess:
    out = sess.run(res)
    print(out)
    print(out.shape)

官方样例理解

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值