如何过滤基于带索引张量的张量流张量?
问题描述:
比方说,我有一个尺寸为[batch_size, 5, 10]
的张量,称为my_tensor
。 我还有一个尺寸为[batch_size, 1]
的另一个张量,其中包含一个名为selecter
的索引。如何过滤基于带索引张量的张量流张量?
我想对于过滤my_tensor
到selecter
生产规模[batch_size, 10]
新张量,即只选择珍视selecter
包含。基本上,它有点减少中间维度(其大小为5)。我觉得tf.where
是正确的选择,但不确定。 我真的很感谢你的帮助!
答
解决方法是使用tf.gather_nd
。
tf.gather_nd(
my_tensor,
tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))
如果你构建selecter
是从一开始1-d可以摆脱squeeze
的。
答
替代的解决方案,工作在Tensorflow 1.3:
max_selecter = tf.reduce_max(selecter) + 1
my_tensor = tf.boolean_mask(
outputs,
tf.logical_xor(
tf.sequence_mask(my_tensor + 1, max_selecter),
tf.sequence_mask(my_tensor, max_selecter)
)
)
这是完美的。非常感谢你! –
你用什么版本的tensorflow?我有1.3.0和我的tf.gather_nd不接受轴参数。但是,有这个tf.gather。 – omikron