Ways to speed up Monte Carlo Tree Search in a model-based RL task

This area is still very new to me, so forgive me if I am asking dumb questions. I'm using MCTS to run a model-based reinforcement learning task. I have an agent foraging in a discrete environment, where the agent can see out some number of spaces around itself (for simplicity I assume perfect knowledge of the observation space, so the observation is the same as the state). The agent has an internal transition model of the world represented by an MLP (I'm using tf.keras). For each step in the tree, I use the model to predict the next state given the action, and I let the agent calculate how much reward it would receive based on the predicted change in state. From there it's the familiar MCTS algorithm: selection, expansion, rollout, and backpropagation.
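
For concreteness, this is roughly what a single simulated step looks like in my code (the names below, including the reward function, are simplified stand-ins rather than my exact implementation; the action is assumed to already be encoded as a vector, e.g. one-hot):

```python
import numpy as np


def reward_from_transition(state, next_state):
    # Placeholder reward: my real code scores the predicted change in state.
    return float(np.sum(next_state - state))


def rollout_step(transition_model, state, action):
    """One simulated step: the learned MLP predicts the next state from
    (state, action), and the reward comes from that predicted transition."""
    x = np.concatenate([state, action])[None, :].astype(np.float32)
    next_state = transition_model(x, training=False).numpy()[0]
    return next_state, reward_from_transition(state, next_state)
```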

Essentially, the problem is that this all runs prohibitively slowly. Profiling my code shows that a lot of the time is spent in the rollout phase, presumably because the network has to be consulted at every simulated step and each prediction takes a nontrivial amount of time. Of course, I can probably stand to clean up my code to make it run faster (e.g. better vectorization), but I was wondering:

  1. Are there ways to speed up/work around the traditional random walk done for rollout in MCTS?
  2. Are there other general ways to speed up MCTS? Or does it just not mix well with using an NN in terms of runtime?

Thanks!

1 Answer

I am working on a similar problem, and so far the following have helped me (rough sketches of each idea follow the list):

  1. Make sure you are running TensorFlow on your GPU (you will have to install CUDA).
  2. Estimate how many steps into the future your agent actually needs to simulate to still get good results, and truncate rollouts at that depth.
  3. Parallelize (the one I am currently working on).
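
For point 1, a quick way to confirm that TensorFlow actually sees the GPU (if this prints an empty list, your predictions are silently running on the CPU):

```python
import tensorflow as tf

# An empty list here means TensorFlow is falling back to the CPU.
print(tf.config.list_physical_devices('GPU'))
```

Keep in mind that with one state per prediction the GPU mostly sits idle; batching several simulated states into each call (see point 3) is what really makes it pay off.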
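
For point 2, a sketch of a depth-limited rollout: instead of simulating until termination, cut the random walk off after `max_depth` steps and discount the rewards. Here `step_fn` and `sample_action` are assumed helper functions (e.g. a per-step wrapper around your learned model), and `max_depth` and `gamma` are tuning knobs you would have to pick for your task:

```python
def depth_limited_rollout(step_fn, sample_action, state, max_depth=10, gamma=0.95):
    """Random rollout truncated after max_depth simulated steps.

    step_fn(state, action) -> (next_state, reward) wraps the learned model;
    sample_action(state) returns a random legal action. Both are assumed
    helpers for illustration, not part of any real API.
    """
    total, discount = 0.0, 1.0
    for _ in range(max_depth):
        action = sample_action(state)
        state, reward = step_fn(state, action)
        total += discount * reward
        discount *= gamma
    return total
```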
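
For point 3, one cheap form of parallelism that does not even require multiprocessing is to run several rollouts in lockstep and evaluate them with a single batched network call per step. This is only a sketch under assumed shapes (vector-encoded states and actions), not a drop-in solution:

```python
import numpy as np


def batched_rollout_step(transition_model, states, actions):
    """Advance n simulated rollouts at once with one network call.

    states:  float32 array of shape (n, state_dim)
    actions: float32 array of shape (n, action_dim), e.g. one-hot encoded
    One prediction over a batch of n rows is much cheaper than n calls.
    """
    x = np.concatenate([states, actions], axis=1).astype(np.float32)
    return transition_model(x, training=False).numpy()
```

Root- or leaf-parallel MCTS with separate processes is another option, but when the network is the bottleneck, batching its calls like this is usually the first thing worth trying.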