SARSA value approximation for Cart Pole


I have a question about this SARSA function approximation (FA) implementation.

In input cell 142, I see this modified update:

w += alpha * (reward - discount * q_hat_next) * q_hat_grad

where q_hat_next is Q(S', a') and q_hat_grad is the gradient of Q(S, a) with respect to the weights w (assuming the sequence S, a, R, S', a').

My question is: shouldn't the update be like this?

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
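
This would match the episodic semi-gradient SARSA update as it is usually written, with the full TD error in brackets:

$$w \leftarrow w + \alpha \left[ R + \gamma \, \hat{q}(S', a', w) - \hat{q}(S, a, w) \right] \nabla_w \hat{q}(S, a, w)$$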

What is the intuition behind the modified update?

1 Answer

I think you are correct: I would also expect the update to contain the TD error term, which should be reward + discount * q_hat_next - q_hat.

For reference, this is the implementation:

if done:  # terminal state reached
    w += alpha * (reward - q_hat) * q_hat_grad
    break
else:
    next_action = policy(env, w, next_state, epsilon)
    q_hat_next = approx(w, next_state, next_action)
    # the update in question: compare with reward + discount * q_hat_next - q_hat
    w += alpha * (reward - discount * q_hat_next) * q_hat_grad
    state = next_state

And this is the pseudo-code from Reinforcement Learning: An Introduction by Sutton & Barto (page 171):

[Image: n-step semi-gradient SARSA pseudo-code from the book]

Since the implementation is TD(0), n is 1, and the update in the pseudo-code can be simplified:

w <- w + a[G - v(S_t,w)] * dv(S_t,w)

becomes, after substituting the one-step return G = reward + discount*v(S_t+1, w),

w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)

Or with the variable names in the original code example:

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad

I ended up with the same update formula that you have. Looks like a bug in the non-terminal state update.
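
For what it's worth, here is a minimal sketch of how that branch could look with the full TD error, assuming the same approx, policy, q_hat and q_hat_grad variables that the notebook already defines:

if done:  # terminal: Q(S', a') is 0 by definition, so the target is just the reward
    w += alpha * (reward - q_hat) * q_hat_grad
    break
else:
    next_action = policy(env, w, next_state, epsilon)
    q_hat_next = approx(w, next_state, next_action)
    # full TD error: reward + discount * Q(S', a') - Q(S, a)
    w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
    state = next_state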

Only the terminal case (when done is true) is correct, because there q_hat_next is 0 by definition: the episode is over, so no more reward can be gained.
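
To make the two cases concrete, here is a small self-contained sketch (plain NumPy, not the notebook's code; the one-hot-per-action feature layout is just an assumption for illustration) of the semi-gradient SARSA(0) update for a linear Q function, where a terminal transition is simply the same update with q_hat_next set to 0:

import numpy as np

def features(state, action, n_actions=2):
    # Linear features: the state vector occupies the slice belonging to the chosen action.
    phi = np.zeros(len(state) * n_actions)
    phi[action * len(state):(action + 1) * len(state)] = state
    return phi

def q_value(w, state, action):
    return w @ features(state, action)

def sarsa_update(w, state, action, reward, next_state, next_action,
                 alpha=0.1, discount=1.0, done=False):
    # Semi-gradient SARSA(0) for a linear approximation: the gradient of
    # Q(s, a) with respect to w is just the feature vector phi(s, a).
    q_hat = q_value(w, state, action)
    q_hat_grad = features(state, action)
    q_hat_next = 0.0 if done else q_value(w, next_state, next_action)  # 0 at terminal states
    td_error = reward + discount * q_hat_next - q_hat
    return w + alpha * td_error * q_hat_grad

# Tiny smoke test with CartPole-sized (4-dimensional) states.
w = np.zeros(8)
s = np.array([0.0, 0.1, -0.05, 0.2])
s_next = np.array([0.01, 0.15, -0.06, 0.25])
w = sarsa_update(w, s, action=0, reward=1.0, next_state=s_next, next_action=1)
print(w)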