Vowpal Wabbit - How to control prediction probabilities from contextual bandit model on a test sample


※I apologize if my English is hard to read or contains mistakes; I am still studying English.

I'm having trouble working out which action each value in the prediction output corresponds to. My assumption is that the output follows the order in which the actions first appear in the training data.

The setup is as follows:

  • There are 5 types of actions: Lv1, Lv2, Lv3, Lv4, Lv5.
  • The training data consists of 500 rows.
  • The execution environment is Python.

For example, let's say I have a training set named "learn_example" containing lines formatted as below.

learn_example = (
"Lv2:7.69:0.2 | A:7.58 B:20.4 C:56.41 D:4",  # <action:cost:probability | features> 
"Lv2:20.6:0.2 | A:3.18 B:21.3 C:56.41 D:4",
"Lv1:11.8:0.3 | A:24.19 B:22.8 C:52.41 D:5",
"Lv3:38.9:0.2 | A:33.27 B:24.1 C:53.41 D:5",
"Lv5:8.95:0.1 | A:56.48 B:24.4 C:48.35 D:5",
"Lv1:35.9:0.2 | A:46.21 B:25.6 C:49.85 D:5",
"Lv4:0.69:0.1 | A:25.81 B:22.4 C:50.21 D:5",
"Lv3:15.7:0.2 | A:13.33 B:21.2 C:51.71 D:5",
....)
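One thing worth checking in the data above: as far as I understand the Vowpal Wabbit contextual bandit format, the action in `action:cost:probability` is expected to be an integer id from 1 to N (here N = 5), not a string such as "Lv2". The following is only a minimal sketch of a conversion step under that assumption; `ACTION_IDS` and `to_vw_cb_line` are hypothetical names introduced for illustration and are not part of the library.

```python
# Hypothetical mapping from the level names used in the question
# to the integer action ids (1..5) that --cb_explore 5 is assumed to expect.
ACTION_IDS = {"Lv1": 1, "Lv2": 2, "Lv3": 3, "Lv4": 4, "Lv5": 5}


def to_vw_cb_line(line: str) -> str:
    """Rewrite 'Lv2:7.69:0.2 | A:...' as '2:7.69:0.2 | A:...'."""
    # Split the label part from the feature part at the first '|'.
    label, features = line.split("|", 1)
    # The label is 'name:cost:probability'.
    name, cost, prob = label.strip().split(":")
    return f"{ACTION_IDS[name]}:{cost}:{prob} |{features}"


converted = to_vw_cb_line("Lv2:7.69:0.2 | A:7.58 B:20.4 C:56.41 D:4")
print(converted)
```

Each converted line could then be passed to `vw.learn(...)` in place of the original string.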

I trained the model and ran a prediction with the following code, which outputs a probability for each of the 5 actions.

import vowpalwabbit

vw = vowpalwabbit.Workspace("--cb_explore 5 --epsilon 0.25 --cb_type dr")

for example_data in learn_example:
    vw.learn(example_data)

sample_context = "| A:18.36 B:20.1 C:49.23 D:5"

predict_result = vw.predict(sample_context)
print(predict_result)

The execution results are as follows:

[0.20761550962924957, 0.1931363344192505, 0.19554626941680908, 0.19322560727596283, 0.21047623455524445]

Based on that assumption, I believe the correspondence is as follows. Could you confirm whether this is correct?

[Lv2:0.20761550962924957, Lv1:0.1931363344192505, Lv3:0.19554626941680908, Lv5:0.19322560727596283, Lv4:0.21047623455524445]

Is it possible to align the prediction results in the order of Lv1, Lv2, Lv3, Lv4, Lv5?

[Lv1:0.1931363344192505, Lv2:0.20761550962924957, Lv3:0.19554626941680908, Lv4:0.21047623455524445, Lv5:0.19322560727596283]
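For reference, my understanding is that `--cb_explore` returns the probabilities indexed by action id, so position 0 belongs to action 1, position 1 to action 2, and so on, regardless of the order in which actions appear in the training data. If that holds, and if the model is trained with integer ids 1 to 5 standing for Lv1 to Lv5, attaching names to the output is a simple zip. The sketch below assumes exactly that; `ACTION_NAMES` and `label_pmf` are hypothetical helpers, and the list from the question is reused only as sample input.

```python
# Assumed convention: index i of the prediction corresponds to action id i + 1.
ACTION_NAMES = ["Lv1", "Lv2", "Lv3", "Lv4", "Lv5"]


def label_pmf(pmf):
    """Pair each probability with its action name, in action-id order."""
    if len(pmf) != len(ACTION_NAMES):
        raise ValueError("expected one probability per action")
    return dict(zip(ACTION_NAMES, pmf))


# Sample input: the list printed in the question.
pmf = [0.20761550962924957, 0.1931363344192505, 0.19554626941680908,
       0.19322560727596283, 0.21047623455524445]

labeled = label_pmf(pmf)
for name, prob in labeled.items():
    print(f"{name}: {prob:.4f}")
```

With this convention no reordering is needed, because the output would already be in the order Lv1, Lv2, Lv3, Lv4, Lv5.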