
In TensorFlow, How To Use tf.gather() For The Last Dimension?

I am trying to gather slices of a tensor along the last dimension to implement partial connections between layers. Because the output tensor's shape is [batch_size, h, w, depth], I wan…

Solution 1:

As of TensorFlow 1.3, tf.gather has an axis parameter, so the various workarounds below are no longer necessary.

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather
https://github.com/tensorflow/tensorflow/issues/11223
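A minimal sketch of the axis approach for the question's [batch_size, h, w, depth] case, assuming TensorFlow 2.x with eager execution (the thread itself uses TF 1.x sessions). The tensor contents and the channel indices are made up for illustration:

```python
import numpy as np
import tensorflow as tf  # assumes TF 2.x (eager execution)

# Hypothetical feature map shaped [batch_size, h, w, depth] = [2, 3, 3, 8].
features = tf.reshape(tf.range(2 * 3 * 3 * 8), [2, 3, 3, 8])

# Hypothetical channels (last-dimension slices) to keep for the partial connection.
channel_idx = [0, 2, 5]

# axis=-1 gathers along the last dimension; the leading dimensions are preserved.
selected = tf.gather(features, channel_idx, axis=-1)

print(selected.shape)  # (2, 3, 3, 3)
```

Negative axis values count from the end, so axis=-1 selects the depth dimension whatever the tensor's rank is.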

Solution 2:

There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206

For now you can:

  1. transpose your matrix so that the dimension to gather is first (transpose is expensive)

  2. reshape your tensor to 1-D (reshape is cheap), convert your column indices into a list of linear (flattened) element indices, gather, then reshape back

  3. use gather_nd. You will still need to turn your column indices into a list of individual element indices.
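Options 1 and 3 can be sketched as follows, assuming TensorFlow 2.x eager execution and a tensor whose rank is known statically. Both helper names are hypothetical, and their results are checked against the built-in tf.gather(..., axis=-1) from Solution 1:

```python
import numpy as np
import tensorflow as tf  # assumes TF 2.x (eager execution)

def gather_last_axis_transpose(params, indices):
    """Option 1: move the last axis to the front, gather on axis 0, move it back."""
    rank = len(params.shape)                  # assumes a statically known rank
    perm = [rank - 1] + list(range(rank - 1))  # e.g. [3, 0, 1, 2]
    inv_perm = list(range(1, rank)) + [0]      # e.g. [1, 2, 3, 0]
    front = tf.transpose(params, perm)         # [depth, ...]
    gathered = tf.gather(front, indices)       # [k, ...]
    return tf.transpose(gathered, inv_perm)    # [..., k]

def gather_last_axis_gather_nd(params, indices):
    """Option 3: build explicit (row, column) index pairs for tf.gather_nd."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    depth = tf.shape(params)[-1]
    rows = tf.reshape(params, [-1, depth])              # [n, depth]
    n = tf.shape(rows)[0]
    k = tf.shape(indices)[0]
    row_ids = tf.tile(tf.range(n)[:, None], [1, k])     # [n, k] row numbers
    col_ids = tf.tile(indices[None, :], [n, 1])         # [n, k] column numbers
    pairs = tf.stack([row_ids, col_ids], axis=-1)       # [n, k, 2]
    picked = tf.gather_nd(rows, pairs)                  # [n, k]
    out_shape = tf.concat([tf.shape(params)[:-1], tf.reshape(k, [1])], axis=0)
    return tf.reshape(picked, out_shape)                # [..., k]

x = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
idx = [0, 2, 3, 8]
ref = tf.gather(x, idx, axis=-1)  # built-in since TF 1.3

out_t = gather_last_axis_transpose(x, idx)
out_nd = gather_last_axis_gather_nd(x, idx)
print(np.array_equal(out_t.numpy(), ref.numpy()),
      np.array_equal(out_nd.numpy(), ref.numpy()))  # True True
```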

Solution 3:

With gather_nd you can now do this as follows:

cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)

Also, as reported by user Nova in a thread referenced by @Yaroslav Bulatov:

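import tensorflow as tf  # TF 1.x (uses tf.Session)

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])  # one column index per row
# element (i, idx[i]) sits at linear position i * num_cols + idx[i] in the flattened tensor
idx_flattened = tf.range(0, tf.shape(x)[0]) * tf.shape(x)[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print(y.eval())  # [2 4 9]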

The gist is to flatten the tensor and use strided 1-D addressing with tf.gather(...).
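The same strided addressing extends to the question's 4-D case, where the same channel indices are picked at every [batch, h, w] position: treat the tensor as n rows of length depth, so element (row r, column c) lives at r * depth + c. This is a sketch assuming TensorFlow 2.x eager execution; the helper name gather_last_axis_flat is hypothetical:

```python
import numpy as np
import tensorflow as tf  # assumes TF 2.x (eager execution)

def gather_last_axis_flat(params, indices):
    """Flatten to 1-D and gather with strided linear indices (hypothetical helper)."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    shape = tf.shape(params)
    depth = shape[-1]
    n = tf.size(params) // depth  # number of length-`depth` rows
    # Linear index of (row r, column c) in the flat tensor is r * depth + c.
    flat_idx = tf.range(n)[:, None] * depth + indices[None, :]  # [n, k]
    picked = tf.gather(tf.reshape(params, [-1]), flat_idx)       # [n, k]
    out_shape = tf.concat([shape[:-1], tf.shape(indices)], axis=0)
    return tf.reshape(picked, out_shape)                         # [..., k]

x = tf.reshape(tf.range(2 * 2 * 2 * 10), [2, 2, 2, 10])
out = gather_last_axis_flat(x, [0, 2, 3, 8])
print(out.numpy()[1, 1, 1])  # [70 72 73 78]
```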

Solution 4:

Yet another solution, using tf.unstack(...), tf.gather(...) and tf.stack(...):

Code:

import tensorflow as tf
import numpy as np

shape = [2, 2, 2, 10] 
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)

indices = [0, 2, 3, 8]
axis = -1  # last dimension

def gather_axis(params, indices, axis=0):
    # unstack along `axis`, pick the wanted slices, then stack them back along `axis`
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)

print(L)
with tf.Session() as sess:
    partL = sess.run(gather_axis(L, indices, axis))
    print(partL)

Result:

L = 
[[[[ 0  1  2  3  4  5  6  7  8  9]
   [10 11 12 13 14 15 16 17 18 19]]

  [[20 21 22 23 24 25 26 27 28 29]
   [30 31 32 33 34 35 36 37 38 39]]]


 [[[40 41 42 43 44 45 46 47 48 49]
   [50 51 52 53 54 55 56 57 58 59]]

  [[60 61 62 63 64 65 66 67 68 69]
   [70 71 72 73 74 75 76 77 78 79]]]]

partL = 
[[[[ 0  2  3  8]
   [10 12 13 18]]

  [[20 22 23 28]
   [30 32 33 38]]]


 [[[40 42 43 48]
   [50 52 53 58]]

  [[60 62 63 68]
   [70 72 73 78]]]]
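For comparison, the same helper also runs under TensorFlow 2.x eager execution (an assumption here; the answer itself uses a TF 1.x Session) and gives the same result as the built-in axis argument from Solution 1. Note that tf.unstack needs the size of the unstacked axis to be known statically and produces one tensor per slice along it, so the single built-in call is the simpler choice when it is available:

```python
import numpy as np
import tensorflow as tf  # assumes TF 2.x (eager execution)

def gather_axis(params, indices, axis=0):
    # unstack along `axis`, pick the wanted slices, then stack them back along `axis`
    return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)

L = np.arange(80).reshape([2, 2, 2, 10])
indices = [0, 2, 3, 8]

part_unstack = gather_axis(L, indices, axis=-1)
part_builtin = tf.gather(L, indices, axis=-1)  # axis parameter, TF >= 1.3

print(np.array_equal(part_unstack.numpy(), part_builtin.numpy()))  # True
print(part_builtin.numpy()[0, 0, 0])  # [0 2 3 8]
```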

Solution 5:

A correct version of @Andrei's answer would read:

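# matrix: [n, m] tensor; indices_for_dim1: int32 tensor of shape [n] (one column index per row)
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)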
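The difference matters because tf.concat(..., axis=0) joins the two length-n vectors end to end into one length-2n vector, whereas tf.gather_nd expects an [n, 2] tensor of (row, column) pairs, which is what tf.stack(..., axis=1) builds. A runnable sketch with the 3x3 matrix from Solution 3, assuming TensorFlow 2.x eager execution; the batch_dims variant at the end is an alternative available in TF 2.x, not part of the original answer:

```python
import tensorflow as tf  # assumes TF 2.x (eager execution)

matrix = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
indices_for_dim1 = tf.constant([1, 0, 2])  # column to pick in each row (int32)

# Pair each row number with its column index: [[0, 1], [1, 0], [2, 2]].
cat_idx = tf.stack([tf.range(0, tf.shape(matrix)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
print(result.numpy())  # [2 4 9]

# Alternative: treat the row dimension as a batch dimension of tf.gather.
result_bd = tf.gather(matrix, indices_for_dim1, batch_dims=1)
print(result_bd.numpy())  # [2 4 9]
```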
