tf.slice()和 tf.gather()的用法
之前讲过 tf.gather 根据索引去找相应维度的子集
tf.slice(input,begin,size,name=None) 按照指定的下标范围抽取连续区域的子集
tf.gather(input,begin,size.name=None) 按照指定的下标集合从 axis=0 中抽取子集,适合抽取不连续区域的子集。
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3],
# [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]],
# [[5, 5, 5]]]
tf.gather(t, [0, 2]) # [[[1,1,1],[2,2,2]],
[[5,5,5],[6,6,6]]]
从 t 中抽取 [[[3,3,3]]],输出在 input 中的 axis=0(竖向坐标)是1,axis=1(横向坐标)是0,axis=2的坐标是 0-2 (包了几层),所以 begin = [1,0,0], size=[1,1,3](begin就是开始坐标是1,0,0,size是取出来的[3,3,3],是[1,1,3]一行一列三个数字)
从 t 中抽取[[[3,3,3],[4,4,4]]],begin = [1,0,0], size=[1,2,3]
从 t 中抽取[[[1,1,1],[2,2,2]],[5,5,5],[6,6,6]]],axis = 0的下标是[0,2],不连续,使用tf.gather().