提取从张量的特定元素在tensorflow
问题描述:
我使用tensorflow有关python 我的形状的数据张量[?,5,37],和形状的IDX张量[?,5]提取从张量的特定元素在tensorflow
我倒要提取的数据元素,并获得的形状的输出,使得[,5']:
output[i][j] = data[i][j][idx[i, j]] for all i in range(?) and j in range(5)
它看起来洛克的tf.gather_nd()函数是最接近我的需要,但我不t看看我的情况如何使用它...
谢谢!
编辑:我设法做到了与gather_nd如下所示,但有没有更好的选择? (似乎有点重手)
nRows = tf.shape(length_label)[0] ==> ?
nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) ==> 5
m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]),
shape=[nRows, nCols])
m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
shape=[nCols, nRows]))
indices = tf.pack([m2, m1, idx], axis=-1)
# indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]]
output = tf.gather_nd(data, indices=indices)
答
我设法与gather_nd
做如下图所示
nRows = tf.shape(length_label)[0] # ==> ?
nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) # ==> 5
m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]),
shape=[nRows, nCols])
m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
shape=[nCols, nRows]))
indices = tf.pack([m2, m1, idx], axis=-1)
# indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]]
output = tf.gather_nd(data, indices=indices)
您的解决方案对我来说很好。 – user1454804