I've implemented Proximal Policy Optimization (PPO) from scratch for a discrete environment. The implementation initializes a Policy Network, a State-Value Function, and an Action-Value Function (all as neural network function approximators). It then performs policy rollouts in which the Policy Network and Value Functions are updated with gradient descent and the Advantage Function is computed at each time step.
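Since I have both Q(s,a) and V(s), the Advantage Function at each step is their difference, A(s_t, a_t) = Q(s_t, a_t) - V(s_t). Here is a minimal sketch of that step (PyTorch-style; the names and the explicit detach are illustrative, not taken from my file):

```python
import torch

def advantage(q_sa: torch.Tensor, v_s: torch.Tensor) -> torch.Tensor:
    """A(s_t, a_t) = Q(s_t, a_t) - V(s_t).

    Detached so the advantage acts as a fixed coefficient in the policy
    loss and no gradient flows back into the value networks from it.
    """
    return (q_sa - v_s).detach()

# Dummy scalar estimates, for illustration only:
q_estimate = torch.tensor(1.3)   # Q(s_t, a_t) from the action-value network
v_estimate = torch.tensor(0.9)   # V(s_t) from the state-value network
adv_t = advantage(q_estimate, v_estimate)
```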
Here's a brief overview of my implementation:
- Initialization: random parameters for the Policy Network, State-Value Function, and Action-Value Function.
- Policy Rollouts: at each time step (t), feed the state through the Policy Network and Value Functions, update the Value Functions with gradient descent, and compute the Advantage Function.
- Surrogate Objective: compute the clipped surrogate objective at each time step (t), as sketched below.
- Training: update the Policy Network with Stochastic Gradient Descent (SGD) using the sum of the clipped surrogate values over the entire episode; the loss is the negative of the expected surrogate objective.

Despite following these steps, my implementation is not learning, and I'm struggling to identify the issue. I've considered that the neural network configurations might be incorrect, but given how simple the discrete environment is, I'm not sure that's the root cause.
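For reference, this is how I understand the clipped surrogate step; a minimal PyTorch-style sketch with illustrative names (not copied from my file):

```python
import torch

def ppo_clip_loss(new_log_probs: torch.Tensor,
                  old_log_probs: torch.Tensor,
                  advantages: torch.Tensor,
                  epsilon: float = 0.2) -> torch.Tensor:
    """Negative clipped surrogate objective over one episode of steps.

    new_log_probs: log pi_theta(a_t | s_t) from the current policy
    old_log_probs: log pi_theta_old(a_t | s_t) stored during the rollout
    advantages:    per-step advantage estimates
    """
    ratios = torch.exp(new_log_probs - old_log_probs)   # r_t(theta)
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # L^CLIP takes the elementwise minimum; the loss to minimize is its
    # negative (mean over the episode is the usual choice, a sum only
    # rescales the effective learning rate).
    return -torch.min(unclipped, clipped).mean()
```

The sketch treats the old log-probabilities and the advantages as constants (detached from the graph), which I believe is required so that only the current policy's parameters receive gradients from this loss.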
Here are the configurations for each neural network:
Policy Network (Action Output):
- Input Dimension: 4
- Hidden Dimension: 32
- Output Dimension: 2
- Learning Rate: 0.0001

Value Functions (Q(s,a), V(s)):
- Input Dimension: 4
- Hidden Dimension: 32
- Output Dimension: 1
- Learning Rate: 0.0001
Training Settings:

- Number of Episodes: 500
- Episode Steps: 100
- Clipping Parameter (Epsilon): 0.2
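In code, the configurations above correspond to networks roughly like the following (a PyTorch sketch matching my stated dimensions; the hidden activation and the specific optimizer objects are assumptions, not necessarily what the linked file uses):

```python
import torch.nn as nn
import torch.optim as optim

def mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.Tanh(),                      # assumed hidden activation
        nn.Linear(hidden_dim, out_dim),
    )

policy_net = mlp(4, 32, 2)   # logits over the 2 discrete actions
q_net      = mlp(4, 32, 1)   # Q(s,a) approximator as configured (scalar output)
v_net      = mlp(4, 32, 1)   # V(s) approximator (scalar output)

policy_opt = optim.SGD(policy_net.parameters(), lr=0.0001)
q_opt      = optim.SGD(q_net.parameters(), lr=0.0001)
v_opt      = optim.SGD(v_net.parameters(), lr=0.0001)

EPSILON      = 0.2    # clipping parameter
NUM_EPISODES = 500
MAX_STEPS    = 100
```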
Here is the full code (around 200 lines and easy to follow): https://github.com/BernardoOlisan/PPO-Clip/blob/main/ppo.py
I also experimented with increasing the number of episodes to 500, but unfortunately saw no noticeable improvement in learning or convergence. Even with the extended training, the algorithm still struggles, and I'm uncertain about the root cause.
I'm seeking guidance on potential pitfalls, misconfigurations, or other factors that might hinder the learning process. Any insights or suggestions would be greatly appreciated. Thank you!