I have a question about this SARSA FA implementation. In input cell 142 I see this modified update:

w += alpha * (reward - discount * q_hat_next) * q_hat_grad

where q_hat_next is Q(S', a') and q_hat_grad is the gradient of Q(S, a) with respect to the weights w (assume an S, a, R, S', a' sequence).
My question is: shouldn't the update be like this?
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
What is the intuition behind the modified update?
I think you are correct. I would also have expected the update to contain the TD error term, which should be

reward + discount * q_hat_next - q_hat

For reference, this is the implementation:
And this is the pseudo-code from Reinforcement Learning: An Introduction by Sutton & Barto (page 171):
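For reference, the update line in that pseudo-code has the form

w <- w + alpha * (G - q_hat(S_tau, A_tau, w)) * grad_w q_hat(S_tau, A_tau, w)

where G is the n-step return.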
As the implementation is TD(0), n is 1. Then the update in the pseudo-code becomes, by substituting G == reward + discount * q_hat(S_{t+1}, A_{t+1}, w):

w += alpha * (reward + discount * q_hat(S_{t+1}, A_{t+1}, w) - q_hat(S_t, A_t, w)) * q_hat_grad

Or, with the variable names in the original code example:

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
I ended up with the same update formula that you have. Looks like a bug in the non-terminal state update.
Only the terminal case (if done is true) should be correct, because then q_hat_next is always 0 by definition, as the episode is over and no more reward can be gained.
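To make the corrected update concrete, here is a minimal sketch of episodic semi-gradient SARSA(0) with linear function approximation (so the gradient of Q with respect to w is just the feature vector). The env (old Gym-style reset()/step()), the features(state, action) function, and the hyper-parameter values are illustrative assumptions, not the notebook's actual code:

import numpy as np

def q_hat(w, x):
    # Linear action-value estimate: Q(s, a) = w . x(s, a)
    return np.dot(w, x)

def sarsa_episode(env, features, w, n_actions, alpha=0.01, discount=0.99, epsilon=0.1):
    # Runs one episode of semi-gradient SARSA(0), updating w in place.
    def policy(state):
        # epsilon-greedy with respect to the current approximation
        if np.random.rand() < epsilon:
            return np.random.randint(n_actions)
        return int(np.argmax([q_hat(w, features(state, a)) for a in range(n_actions)]))

    state = env.reset()
    action = policy(state)
    done = False
    while not done:
        next_state, reward, done, _ = env.step(action)
        x = features(state, action)   # for a linear Q, grad_w Q(S, A, w) == x
        if done:
            # terminal case: Q(S', a') is 0 by definition
            td_error = reward - q_hat(w, x)
        else:
            next_action = policy(next_state)
            q_next = q_hat(w, features(next_state, next_action))
            # corrected TD error: reward + discount * q_hat_next - q_hat
            td_error = reward + discount * q_next - q_hat(w, x)
        w += alpha * td_error * x     # w += alpha * td_error * q_hat_grad
        if not done:
            state, action = next_state, next_action
    return w

With this structure, the terminal branch reduces to w += alpha * (reward - q_hat) * q_hat_grad, which matches the behaviour described above for done == True, while the non-terminal branch uses the full TD error.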