So I've used the following code to implement Q-learning in Unity:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

namespace QLearner
{
    public class QLearnerScript
    {
        List<float[]> QStates;  // Q states over time
        List<float[]> QActions; // Q actions over time

        float[] initialState;
        int initialActionIndex;

        float[] outcomeState;
        float outcomeActionValue;

        bool firstIteration;

        int possibleActions;
        float learningRate;   // denoted by alpha
        float discountFactor; // denoted by gamma
        float simInterval;

        System.Random r = new System.Random();

        // Updates the Q-table for the previous step, then returns the action
        // index for the current state: greedy if the state has been seen
        // before, random otherwise.
        public int main(float[] currentState, float reward)
        {
            QLearning(currentState, reward);

            // Applies a sim interval and rounds
            initialState = new float[2] {
                (float)Math.Round((double)currentState[0] / simInterval) * simInterval,
                (float)Math.Round((double)currentState[1] / simInterval) * simInterval
            };
            firstIteration = false;

            int actionIndex = r.Next(0, possibleActions);
            bool exists = false;

            if (QStates.Count > 0)
            {
                for (int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);
                    if (state[0] == initialState[0] && state[1] == initialState[1])
                    {
                        exists = true;
                        initialActionIndex = Array.IndexOf(actions, MaxFloat(actions));
                        return initialActionIndex;
                    }
                }
            }

            if (!exists)
            {
                float[] actionVals = new float[possibleActions];
                for (int i = 0; i < possibleActions; i++)
                {
                    actionVals[i] = 0f;
                }
                QStates.Add(initialState);
                QActions.Add(actionVals);
            }

            initialActionIndex = actionIndex;
            return initialActionIndex;
        }

        public QLearnerScript(int possActs)
        {
            QStates = new List<float[]>();
            QActions = new List<float[]>();
            possibleActions = possActs;
            learningRate = .5f; // Between 0 and 1
            discountFactor = 1f;
            simInterval = 1f;
            firstIteration = true;
        }

        // Applies the Q-learning update to the previous (state, action) pair,
        // using the observed reward and the rounded outcome state.
        public void QLearning(float[] outcomeStateFeed, float reward)
        {
            if (!firstIteration)
            {
                outcomeState = new float[2] {
                    (float)Math.Round((double)outcomeStateFeed[0] / simInterval) * simInterval,
                    (float)Math.Round((double)outcomeStateFeed[1] / simInterval) * simInterval
                };

                bool exists = false;
                for (int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);
                    if (state[0] == outcomeState[0] && state[1] == outcomeState[1])
                    {
                        exists = true;
                        outcomeActionValue = MaxFloat(actions);
                    }
                }

                for (int i = 0; i < QStates.Count; i++)
                {
                    float[] state = QStates.ElementAt(i);
                    float[] actions = QActions.ElementAt(i);
                    if (state[0] == initialState[0] && state[1] == initialState[1])
                    {
                        if (exists)
                        {
                            actions[initialActionIndex] += learningRate * (reward + discountFactor * outcomeActionValue - actions[initialActionIndex]);
                        }
                        if (!exists)
                        {
                            actions[initialActionIndex] += learningRate * (reward + discountFactor * 0f - actions[initialActionIndex]);
                        }
                    }
                }
            }
        }

        public int getQtableCount()
        {
            return QStates.Count;
        }

        // Returns the largest value in the given array.
        float MaxFloat(float[] numbers)
        {
            float max = numbers[0];
            for (int i = 0; i < numbers.Length; i++)
            {
                if (max < numbers[i])
                {
                    max = numbers[i];
                }
            }
            return max;
        }
    }
}
This works fine in my environment. However, I'm also trying to implement SARSA so that I can test the two algorithms against each other. I know that Q-learning is off-policy, while SARSA is on-policy, meaning I have to implement a policy to select the next action instead of simply calling
MaxFloat(actions)
However, the actual implementation confuses me: how would I modify my script to include such a policy?
With SARSA, the name of the algorithm is also the algorithm: you save a State, an Action, the Reward, and the next State and Action (S, A, R, S', A'), then use that quintuple to perform the update.
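In terms of the update line inside your QLearning method, the only difference is the bootstrap term of the target. As a rough sketch, where nextActions stands for the action-value array of the outcome state and nextActionIndex for the action your behaviour policy has just chosen there (both are illustrative names, not variables from your script):

    // Q-learning (off-policy): bootstrap from the best value in the next state
    float qLearningTarget = reward + discountFactor * MaxFloat(nextActions);

    // SARSA (on-policy): bootstrap from the value of the action actually taken next
    float sarsaTarget = reward + discountFactor * nextActions[nextActionIndex];

    // Both then apply the same update to the previous (state, action) pair
    actions[initialActionIndex] += learningRate * (sarsaTarget - actions[initialActionIndex]);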
You compute the update not at the point where you only have the current state and its reward, but at the point where you have the previous state and action, the reward that followed, the current state, and (for SARSA) the action your policy has just chosen in the current state. SARSA bootstraps from the value of that chosen action, while Q-learning replaces it with the greedy policy's prediction, i.e. the maximum action value in the current state.
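Here is a minimal sketch of what a SARSA version of your script could look like, keeping your List-based Q-table and simInterval rounding but swapping in an epsilon-greedy behaviour policy. The class and member names (SarsaLearnerScript, Step, SelectAction, epsilon, and so on) are illustrative rather than taken from your code, and the exploration rate is just a placeholder value:

    using System;
    using System.Collections.Generic;

    namespace QLearner
    {
        public class SarsaLearnerScript
        {
            List<float[]> QStates = new List<float[]>();
            List<float[]> QActions = new List<float[]>();

            int possibleActions;
            float learningRate = .5f;    // alpha
            float discountFactor = 1f;   // gamma
            float simInterval = 1f;
            float epsilon = .1f;         // exploration rate of the epsilon-greedy policy
            System.Random r = new System.Random();

            // The previous step's (state, action); SARSA updates Q for this pair
            // once it also knows the reward and the *next* state and action.
            int prevStateIndex = -1;
            int prevActionIndex = -1;

            public SarsaLearnerScript(int possActs)
            {
                possibleActions = possActs;
            }

            // Called once per step, like your main(): currentStateFeed is the new
            // observation and reward is the reward received for the previous action.
            public int Step(float[] currentStateFeed, float reward)
            {
                float[] currentState = RoundState(currentStateFeed);
                int stateIndex = FindOrAddState(currentState);

                // On-policy: pick the action we will actually execute.
                int actionIndex = SelectAction(QActions[stateIndex]);

                if (prevStateIndex >= 0)
                {
                    // SARSA update: Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))
                    float[] prevActions = QActions[prevStateIndex];
                    float target = reward + discountFactor * QActions[stateIndex][actionIndex];
                    prevActions[prevActionIndex] += learningRate * (target - prevActions[prevActionIndex]);
                }

                prevStateIndex = stateIndex;
                prevActionIndex = actionIndex;
                return actionIndex;
            }

            // Epsilon-greedy behaviour policy: explore with probability epsilon,
            // otherwise act greedily with respect to the current action values.
            int SelectAction(float[] actionValues)
            {
                if (r.NextDouble() < epsilon)
                {
                    return r.Next(0, possibleActions);
                }
                int best = 0;
                for (int i = 1; i < actionValues.Length; i++)
                {
                    if (actionValues[i] > actionValues[best])
                    {
                        best = i;
                    }
                }
                return best;
            }

            // Same rounding to the sim interval as in your Q-learning script.
            float[] RoundState(float[] s)
            {
                return new float[2] {
                    (float)Math.Round((double)s[0] / simInterval) * simInterval,
                    (float)Math.Round((double)s[1] / simInterval) * simInterval
                };
            }

            // Returns the index of the state in the table, adding a zero-initialised
            // row of action values if it has not been seen before.
            int FindOrAddState(float[] state)
            {
                for (int i = 0; i < QStates.Count; i++)
                {
                    if (QStates[i][0] == state[0] && QStates[i][1] == state[1])
                    {
                        return i;
                    }
                }
                QStates.Add(state);
                QActions.Add(new float[possibleActions]); // defaults to 0f
                return QStates.Count - 1;
            }
        }
    }

For a fair comparison you would typically give your Q-learning agent the same epsilon-greedy behaviour policy (your current main() acts greedily once a state is known), so that the only difference between the two scripts is the update target.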