IndexedSlices in Tensorflow

In this post, we will discuss the IndexedSlices class of Tensorflow. We will try to answer the following questions:

What are IndexedSlices?

According to the Tensorflow documentation, IndexedSlices is a sparse representation of a set of tensor slices at given indices. At a high level, it stores only a few slices (rows) of a larger dense tensor, together with the indices at which those slices sit. Let's try to understand it with examples.

Where do we get it?

We get IndexedSlices when taking gradients of an Embedding layer. Embedding matrices can be huge (depending on the vocabulary size), but each batch contains only a small fraction of the tokens. So while computing the gradient of the loss with respect to the embedding layer, in each pass we only have to consider the embeddings of the tokens present in that batch. Naturally, a sparse representation is a better option for recording those gradients, and Tensorflow records them as IndexedSlices. We will show that below using a contrived example.

import tensorflow as tf
print("Tensorflow version: ", tf.__version__)
Tensorflow version:  2.4.0
model = tf.keras.models.Sequential([
    # Vocab size: 10, Embedding dimension: 4, Input_shape size: (batch_size, num_words). As usual, batch_size is omitted.
    tf.keras.layers.Embedding(10, 4, input_shape = (5,)),
    tf.keras.layers.Flatten(), 
    tf.keras.layers.Dense(1)
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 5, 4)              40        
_________________________________________________________________
flatten (Flatten)            (None, 20)                0         
_________________________________________________________________
dense (Dense)                (None, 1)                 21        
=================================================================
Total params: 61
Trainable params: 61
Non-trainable params: 0
_________________________________________________________________
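
As a quick sanity check on the summary: the Embedding layer has 10 × 4 = 40 weights, and the Dense layer has 20 × 1 + 1 = 21 parameters (kernel plus bias), which gives the 61 total parameters reported above.
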
data = tf.random.uniform(shape = (1, 5), minval = 0, maxval = 10, dtype = tf.int32) # Batch size is 1.
data
<tf.Tensor: shape=(1, 5), dtype=int32, numpy=array([[6, 1, 1, 4, 8]])>
model.variables  # A list of 3 variables: 1 from the Embedding layer and 2 from the Dense layer (kernel and bias)
[<tf.Variable 'embedding/embeddings:0' shape=(10, 4) dtype=float32, numpy=
 array([[ 4.10897247e-02, -2.48962641e-03,  1.26880072e-02,
          3.39310430e-02],
        [ 3.28579657e-02,  3.90318781e-03,  2.81411521e-02,
          3.09719704e-02],
        [ 1.16247907e-02, -1.41257644e-02, -3.36343870e-02,
         -4.41543460e-02],
        [-4.67238426e-02,  2.42819674e-02, -4.26802635e-02,
         -2.59207971e-02],
        [ 2.28367783e-02, -2.09717881e-02,  1.05572566e-02,
          3.33249308e-02],
        [-3.37148309e-02, -4.61939685e-02, -2.61853095e-02,
         -4.10162285e-03],
        [-3.59787717e-02,  2.78765075e-02, -3.16200405e-02,
          4.54976298e-02],
        [-4.67344411e-02, -1.30221620e-02,  1.52915232e-02,
          2.22466923e-02],
        [-1.03901625e-02,  2.40740217e-02, -1.24427900e-02,
          4.47194651e-03],
        [-3.57637033e-02,  4.28059734e-02, -2.59280205e-05,
          4.09286283e-02]], dtype=float32)>,
 <tf.Variable 'dense/kernel:0' shape=(20, 1) dtype=float32, numpy=
 array([[ 0.42870212],
        [ 0.04779923],
        [ 0.4126016 ],
        [-0.13294601],
        [-0.3175783 ],
        [-0.46080017],
        [-0.23412797],
        [ 0.30137837],
        [-0.5197849 ],
        [-0.10935467],
        [ 0.5087845 ],
        [-0.06930307],
        [ 0.10028934],
        [-0.11278141],
        [-0.21269777],
        [-0.0214209 ],
        [ 0.12959635],
        [-0.13330323],
        [-0.23972857],
        [ 0.23718971]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1)
loss_object = tf.keras.losses.MeanSquaredError()
target = tf.constant([2.5], shape = (1,1))
for _ in range(2):   # Let's run gradient descent for two batches of the same input data. (It's a contrived example.)
    with tf.GradientTape() as tape:
        output = model(data) # Output has shape: (batch_size, 1). Here batch_size is 1. So output shape is (1,1)
        loss_value = loss_object(target, output)  # Mean squared error loss against an arbitrary target.
    grads = tape.gradient(loss_value, model.trainable_variables)

    # Gradient descent step
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
len(grads)
3
grads[0]
<tensorflow.python.framework.indexed_slices.IndexedSlices at 0x16c04d4a970>
print(grads[0])
IndexedSlices(indices=tf.Tensor([6 1 1 4 8], shape=(5,), dtype=int32), values=tf.Tensor(
[[-0.9495101  -0.14344962 -0.91739434  0.25398374]
 [ 0.69607455  1.0615798   0.5085497  -0.7338184 ]
 [ 1.1639317   0.24842012 -1.2103697   0.12384857]
 [-0.25895947  0.2856651   0.47968888  0.01028775]
 [-0.28760925  0.28005898  0.56933826 -0.5540699 ]], shape=(5, 4), dtype=float32), dense_shape=tf.Tensor([10  4], shape=(2,), dtype=int32))

An IndexedSlices object has three main attributes, illustrated with a small hand-built example after this list:

  • indices
  • values, and
  • dense_shape
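
We can also construct an IndexedSlices object by hand to see how these three attributes fit together. The snippet below is a minimal sketch of my own (the values are arbitrary), mirroring the structure of the gradient printed above:

# Two slices (rows 6 and 1) of a notional (10, 4) dense tensor.
manual_slices = tf.IndexedSlices(values = tf.ones((2, 4)),            # the non-zero rows
                                 indices = tf.constant([6, 1]),       # where those rows live
                                 dense_shape = tf.constant([10, 4]))  # shape of the full dense tensor
print(manual_slices)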

How to convert IndexedSlices to Tensors?

Before we do the conversion, let's answer a relevant question: why convert IndexedSlices to tensors at all, given that Tensorflow can apply a gradient descent step directly to IndexedSlices? In the last section, we ran two gradient descent steps without ever worrying about them.

The problem occurs when we want to do some processing on the gradient values ourselves. One such processing step is gradient clipping: if the norm of the gradients exceeds a given threshold, the gradients are rescaled to reduce their magnitude. To do any such clipping, we have to access the gradient tensors, and this is precisely where we would like to convert IndexedSlices to tensors. Embedding layers are common in deep learning models, and applying gradient clipping is also a common practice. We will show two approaches to do the conversion.

Easiest approach

tf.convert_to_tensor(grads[0])
<tf.Tensor: shape=(10, 4), dtype=float32, numpy=
array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 1.8600063 ,  1.31      , -0.70182   , -0.60996985],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.25895947,  0.2856651 ,  0.47968888,  0.01028775],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.9495101 , -0.14344962, -0.91739434,  0.25398374],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.28760925,  0.28005898,  0.56933826, -0.5540699 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],
      dtype=float32)>

What just happened in the last step?

Though the last approach is an elegant one-line solution, it hides many things. How is the conversion actually done? The code below shows how we can do the conversion manually. Note that the index 1 appears twice in indices (token 1 occurs twice in the batch), so the two corresponding value slices get added together in row 1 of the dense result.

check_grad = tf.zeros_like(model.variables[0]).numpy()   # Create a dense tensor of all zeros
for i, ind in enumerate(grads[0].indices):
    check_grad[ind] = check_grad[ind] + grads[0].values[i]   # Accumulate each slice into its row (repeated indices add up)
check_grad
array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 1.8600063 ,  1.31      , -0.70182   , -0.60996985],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.25895947,  0.2856651 ,  0.47968888,  0.01028775],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.9495101 , -0.14344962, -0.91739434,  0.25398374],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.28760925,  0.28005898,  0.56933826, -0.5540699 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]],
      dtype=float32)
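
Putting the pieces together, here is a minimal sketch (my own, not from any Tensorflow or D2L code) of a user-defined gradient clipping helper that densifies any IndexedSlices before clipping by global norm. A helper along these lines is exactly the kind of fix for the situation described in the motivation below.

# Convert any IndexedSlices to a dense tensor, then clip all gradients by their global norm.
def clip_dense_gradients(grads, clip_norm = 1.0):
    dense_grads = [tf.convert_to_tensor(g) if isinstance(g, tf.IndexedSlices) else g
                   for g in grads]
    clipped, _ = tf.clip_by_global_norm(dense_grads, clip_norm)
    return clipped

clipped_grads = clip_dense_gradients(grads, clip_norm = 1.0)
optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables))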

This brings us to the end of this post. I hope it has demystified a few things about IndexedSlices.

Motivation for this post: While writing TF 2 code for the Attention Mechanisms chapter of the D2L book, I encountered an error involving IndexedSlices. After spending a good deal of time trying to figure out what was going on, I finally found that the error was caused by a user-defined gradient clipping function that didn't handle IndexedSlices properly. The model involved embedding layers, as it dealt with a machine translation task. So I thought of writing this post with the hope that it would help readers who are struggling to figure out what IndexedSlices are.
