Speeding up Inference time on GPT2 - optimizing tf.sess.run()

I am trying to optimize the inference time of GPT2. The current time to generate a sample after calling the script is 55 seconds on Google Colab. I put in timestamps to try to isolate where the bottleneck is. This is the code:

    for _ in range(nsamples // batch_size):
        out = sess.run(output, feed_dict={
            context: [context_tokens for _ in range(batch_size)]
        })[:, len(context_tokens):]
        for i in range(batch_size):
            generated += 1
            text = enc.decode(out[i])
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)
    print("=" * 80)

The line

    out = sess.run(output, feed_dict={
        context: [context_tokens for _ in range(batch_size)]
    })[:, len(context_tokens):]

is where the complexity lies. Does anyone know a way I can improve this piece of code? Thank you so much!
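
For reference, the timestamps I mentioned were wrapped around that call roughly like this (a sketch using only the standard time module; nothing else in the script changes):

    import time

    for _ in range(nsamples // batch_size):
        start = time.time()
        out = sess.run(output, feed_dict={
            context: [context_tokens for _ in range(batch_size)]
        })[:, len(context_tokens):]
        # almost all of the 55 seconds is spent inside the sess.run call above
        print("sess.run took %.1f seconds" % (time.time() - start))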

1 Answer

batch_size is set to 1 in GPT2 and there is no way to change that without crashing the process. So "[context_tokens for _ in range(batch_size)]" means "[context_tokens for _ in range(1)]", which is just "[context_tokens]". Writing it that way will not improve speed by much, but it is safe to do and makes the code a bit easier to read. The real complexity is that you have a 6 gigabyte behemoth sitting in RAM that you are accessing in that session.
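
So the feed can be simplified to the following (same behaviour, just without the comprehension):

    out = sess.run(output, feed_dict={
        context: [context_tokens]          # identical to [context_tokens for _ in range(1)]
    })[:, len(context_tokens):]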

As a practical matter, the fewer tokens you send in and the less processing those tokens take, the faster this part will execute, since each token has to be pushed through the GPT2 model. But, consequently, the less 'intelligent' the response will be.
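
As an illustration only (the cutoff of 256 below is an arbitrary number you would tune, not something from the original script), trimming the context before feeding it is the simplest way to cut the per-call work:

    MAX_CONTEXT = 256                          # hypothetical cutoff, tune to taste
    trimmed = context_tokens[-MAX_CONTEXT:]    # keep only the most recent tokens
    out = sess.run(output, feed_dict={
        context: [trimmed]
    })[:, len(trimmed):]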

By the way, // is integer division, so nsamples // batch_size = nsamples / 1 = nsamples. And from what I have seen, nsamples is also 1 (printing it with print(nsamples) shows 1), so that outer for loop is another loop over a single item and can be removed as well.
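
Putting the two observations together (batch_size == 1 and nsamples == 1), the block from the question collapses to something like:

    out = sess.run(output, feed_dict={
        context: [context_tokens]
    })[:, len(context_tokens):]
    text = enc.decode(out[0])
    print("=" * 40 + " SAMPLE 1 " + "=" * 40)
    print(text)
    print("=" * 80)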

GPT2 is just an implementation in tensorflow. Look up: how to build a graph in tensorflow, how to open a session for that graph, how to make a saver save the variables in that session, and how to use the saver to restore the session. You will learn about checkpoints, meta files and the other pieces of the implementation that will make these files make more sense.
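
The pattern looks roughly like this minimal sketch (the checkpoint path is a placeholder, not necessarily where your model lives):

    import tensorflow as tf
    tf1 = tf.compat.v1

    # import the saved graph definition and get a Saver for its variables
    saver = tf1.train.import_meta_graph('models/117M/model.ckpt.meta')  # placeholder path

    with tf1.Session() as sess:
        # restore the trained weights into this session from the checkpoint files
        saver.restore(sess, tf1.train.latest_checkpoint('models/117M'))
        # sess.run(...) against tensors in the restored graph now uses those weights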

The tensorflow module lives in Lib/site-packages/tensorflow_core (at least in the AI Dungeon 2 Henk717 fork). Most of the processing happens in the sub-directories python/ops and framework. You will see these pop up in tracebacks if your code breaks the hooks tf was expecting.

If this question concerns the implementation in AI Dungeon, the best I have been able to do is a recursive call to generator.generate that is exited by a try/except KeyboardInterrupt:, with a print(token, end='', flush=True) for each token as it is generated. This way you can watch each token as the AI produces it, rather than waiting 55 seconds for a ping sound.
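
Very roughly, and with the caveat that generator.generate here stands in for however the fork exposes generation (the exact signature differs), the pattern looks like this:

    def stream(prompt):
        # illustrative only: produce the next piece of text, show it, then recurse
        token = generator.generate(prompt)
        print(token, end='', flush=True)
        stream(prompt + token)

    try:
        stream(context_text)        # context_text is a placeholder for your prompt string
    except KeyboardInterrupt:
        pass                        # Ctrl-C breaks out of the recursion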

Also, to turn off the CUDA warnings that appear when tensorflow is imported, set the environment variable before the import: import os; os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'. Note that the value has to be the string '3', not the number 3.

Next, there are deprecation warnings that pop up from running the GPT2 implementation on tensorflow versions above 1.5.

To shut those off, tfv = tf.compat.v1 followed by tfv.logging.set_verbosity(tfv.logging.ERROR) is all you need. You don't need to import warnings.
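
Put together, the top of the script only needs these lines (a sketch; nothing else has to change):

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'      # must be set before tensorflow is imported

    import tensorflow as tf
    tfv = tf.compat.v1
    tfv.logging.set_verbosity(tfv.logging.ERROR)  # hides the deprecation warnings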

Even so, there is a long wait between the tf initialization, the initial sample generation, and the loading of the model into RAM. I added the following line in model.shape_list(x): print("_", end='', flush=True). At least while the model is being built and localized to the machine, that gives you a "progress bar" of sorts.
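
For reference, the tweak is one extra line inside that helper; shape_list is shown here roughly as it appears in the gpt-2 model.py:

    def shape_list(x):
        """Deal with dynamic shape in tensorflow cleanly."""
        print("_", end='', flush=True)    # one underscore per call: a crude progress bar while the graph builds
        static = x.shape.as_list()
        dynamic = tf.shape(x)
        return [dynamic[i] if s is None else s for i, s in enumerate(static)]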