I am running a model with a very large word embedding (more than 2M words). When I use tf.embedding_lookup, it expects the full embedding matrix, which is big, and when I run the model I get an out-of-GPU-memory error. If I reduce the size of the embedding, everything works fine.
Is there a way to deal with such a large embedding?
The recommended way is to use a partitioner to split this large tensor into several shards:
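A minimal sketch of what this looks like in TF 1.x graph mode (the embedding width of 64 and the variable names are placeholders):

```python
import tensorflow as tf

# Partition the embedding variable into 3 shards along axis 0.
embedding = tf.get_variable(
    "embedding",
    shape=[2000000, 64],
    partitioner=tf.fixed_size_partitioner(3, axis=0))

# embedding_lookup accepts the partitioned variable transparently.
word_ids = tf.placeholder(tf.int64, shape=[None])
embedded = tf.nn.embedding_lookup(embedding, word_ids)
```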
This will split the tensor into 3 shards along axis 0, but the rest of the program will see it as an ordinary tensor. The biggest benefit comes from using a partitioner together with parameter server replication, like this:
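A sketch of how the two can be combined (the hostnames, ports, and embedding width are placeholders; assumes TF 1.x graph mode):

```python
import tensorflow as tf

# Cluster with 3 parameter servers and 1 worker (example addresses).
cluster = tf.train.ClusterSpec({
    "ps": ["ps0:2222", "ps1:2222", "ps2:2222"],
    "worker": ["worker0:2222"]})

# replica_device_setter places each variable shard on one of the ps tasks,
# while ops stay on the worker.
with tf.device(tf.train.replica_device_setter(cluster=cluster)):
    embedding = tf.get_variable(
        "embedding",
        shape=[2000000, 64],
        partitioner=tf.fixed_size_partitioner(3, axis=0))
    word_ids = tf.placeholder(tf.int64, shape=[None])
    embedded = tf.nn.embedding_lookup(embedding, word_ids)
```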
The key function here is tf.train.replica_device_setter. It allows you to run 3 separate processes, called parameter servers, that store all of the model's variables. The large embedding tensor will be split across these servers, with each server holding one of the shards.
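For completeness, each parameter server process would run something along these lines (again a sketch; task_index is 0, 1, or 2 depending on the process, and the addresses must match the ClusterSpec above):

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    "ps": ["ps0:2222", "ps1:2222", "ps2:2222"],
    "worker": ["worker0:2222"]})

# Start this process as parameter server task 0 and serve variables
# until the job is shut down.
server = tf.train.Server(cluster, job_name="ps", task_index=0)
server.join()
```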