Trouble understanding how exploration happens in Vowpal Wabbit Contextual Bandit

I'm currently building a contextual bandit to recommend actions to users on our website.

I'm using --cb_explore_adf because the set of available actions depends on the user context. For example, we wouldn't recommend signing up to a user who is already signed up.

from vowpalwabbit import Workspace
vw = Workspace("--cb_explore_adf --cb_type mtr -q PA --quiet --epsilon 0.3")

Example of a data point to run a prediction would be:

shared |Page pageViewCount:1 videoViewCount:5 language=en user_nation=US page_section=sports time_on_site:3.467392051674073
|Action a=create_oid
|Action a=recommend_content
|Action a=favorites
|Action a=download_app
|Action a=do_nothing
|Action a=survey

When the model runs predict on the above, we get something like:

[0.03333333507180214, 0.03333333507180214, 0.8333333730697632, 0.03333333507180214, 0.03333333507180214, 0.03333333507180214]

What confuses me is where the explore part of epsilon-greedy happens. If I were purely exploiting, I would take the 3rd action, but I'm not sure how to apply the exploration.

I've been searching around, but I can't find specific details on how the algorithm is meant to be used with this output, or on the best way to act on it so that we choose the best action 70% of the time and explore 30% of the time.

There is 1 best solution below

--cb_explore_adf does the following internally:

  1. Predicts the cost of every action.
  2. Generates a vector of probabilities from the vector of predicted costs.

Step 2 is the exploration. Epsilon-greedy assigns (1 - epsilon + epsilon/n) to the action with the lowest predicted cost and epsilon/n to each of the others, where n is the number of actions. The output of vw is this probability vector, and you need to sample from it yourself to apply it in a real scenario. Something like:

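import numpy as np

pmf = np.asarray(vw.predict(...), dtype=np.float64)
pmf /= pmf.sum()  # renormalize: float32 rounding can push the sum slightly off 1
chosen_action = np.random.choice(len(pmf), p=pmf)  # index of the sampled action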
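
The probabilities VW returns are single-precision values, so the vector can sum to slightly more or less than 1 (the one in the question sums to about 1.00000005), and np.random.choice can raise a ValueError when the sum falls outside its tolerance. Below is a minimal sketch of a sampler that renormalizes before drawing and uses only the standard library. `sample_action` is a hypothetical helper name, not part of VW. It also returns the probability of the chosen action, which is worth logging alongside the action if you later feed the observed cost back to the model.

```python
import random

def sample_action(pmf, rng=random):
    """Draw one action index from a probability vector.

    Returns (index, probability). The vector is renormalized first
    because float32 rounding can leave its sum slightly off 1.
    """
    total = sum(pmf)
    probs = [p / total for p in pmf]
    draw = rng.random()  # uniform in [0, 1)
    cumulative = 0.0
    for index, prob in enumerate(probs):
        cumulative += prob
        if draw < cumulative:
            return index, prob
    # Guard against rounding leaving the cumulative sum marginally below draw.
    return len(probs) - 1, probs[-1]

# The distribution from the question.
pmf = [0.03333333507180214, 0.03333333507180214, 0.8333333730697632,
       0.03333333507180214, 0.03333333507180214, 0.03333333507180214]
index, prob = sample_action(pmf, random.Random(0))
```

Over many calls the third action (index 2) is returned roughly 83% of the time and each of the others roughly 3% of the time. That is all the exploration amounts to on the caller's side.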

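In your case there is a mismatch between the command line (".. --epsilon 0.3") and the distribution. With 6 actions, --epsilon 0.3 would give 0.75 to the best action and 0.05 to each of the others, whereas 0.8333 and 0.0333 look like the output of --epsilon 0.2. Maybe you copy-pasted them from two different runs?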
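
To make the epsilon-greedy arithmetic concrete, here is a small sketch that builds the probability vector from a list of predicted costs. It illustrates the formula only and is not VW's actual code. The cost values are made up, and breaking ties by taking the first minimum is an assumption of this sketch.

```python
def epsilon_greedy_pmf(costs, epsilon):
    """Epsilon-greedy distribution over actions, given predicted costs.

    Every action gets epsilon / n; the lowest-cost action additionally
    gets the remaining 1 - epsilon.
    """
    n = len(costs)
    best = min(range(n), key=lambda i: costs[i])  # first minimum wins ties
    pmf = [epsilon / n] * n
    pmf[best] += 1.0 - epsilon
    return pmf

# Hypothetical predicted costs for six actions; the third is the cheapest.
costs = [0.9, 0.8, 0.1, 0.7, 0.95, 0.85]

for epsilon in (0.2, 0.3):
    print(epsilon, [round(p, 4) for p in epsilon_greedy_pmf(costs, epsilon)])
```

With six actions, epsilon = 0.2 yields about 0.8333 for the best action and 0.0333 for each of the rest, while epsilon = 0.3 yields 0.75 and 0.05.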