I'm trying to implement a Reinforcement Learning snooker player and I want to train it using ml-agents in Unity. The thing is that I am not used to this library and I am not sure how should I do it. So when I try to train it, the agent does nothing. Could you check what am I doing wrong and how can I solve it? Thanks so much. This is the main code:
using System;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine;
using System.Collections;
using UnityEngine.UI;
public struct Player
{
public int score;
public bool isNPC;
}
public class Snooker2D : Agent
{
public GameObject cueBall;
public Transform stick;
public Transform ball;
public Transform selectionFx;
public Slider slider;
int currentPlayer = 1;
int nPlayers = 2;
Player[] players = new Player[2];
bool follow = true;
bool doubleShot = false;
float ballRadius = 0;
LayerMask layerMaskBalls = 1 << 9;
LayerMask layerMaskBallsAndWalls = 1 << 9 | 1 << 10;
LayerMask layerMaskWalls = 1 << 10;
RaycastHit2D hit;
float dist = 0;
float minDist = 0;
float maxDist = -3;
float forceMultiplier = 2.5f;
Vector3 whiteBallPos = new Vector3(-4.3111f, 0, 0);
public override void OnEpisodeBegin()
{
// Reset the environment
Application.LoadLevel(Application.loadedLevel);
}
public override void CollectObservations(VectorSensor sensor)
{
// Target and Agent positions
sensor.AddObservation(ball.position);
// sensor.AddObservation(stick.position);
// All balls positions
foreach (GameObject ball in GameObject.FindGameObjectsWithTag("Ball"))
{
sensor.AddObservation(ball.transform.position);
}
}
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// Actions, size = 3, the force vector
var action = actionBuffers.ContinuousActions;
Vector3 force = new Vector3(action[0], action[1], action[2]);
// Apply the force vector to the game environment, for example by adding it to the velocity of the cue ball
Shoot(force);
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
// Set the force vector using keyboard or gamepad inputs, each action is between -1 and 1
// Each action is an axis of the force vector
continuousActionsOut[0] = Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetAxis("Vertical");
continuousActionsOut[2] = Input.GetAxis("Jump");
}
void Start()
{
players[1].isNPC = true;
ballRadius = ball.GetComponent<CircleCollider2D>().radius;
minDist = -(ballRadius + ballRadius / 2);
dist = Mathf.Clamp(maxDist / 2, maxDist, minDist);
slider.maxValue = -maxDist;
slider.minValue = -minDist;
slider.value = dist + -maxDist - minDist;
selectionFx.GetComponent<Fader>().StartFade();
}
void Update()
{
if (Input.GetKeyDown("r")) Application.LoadLevel(Application.loadedLevel);
// PLAYER TURN
if (follow)
{
Vector3 forceDir = Vector3.zero;
// Analog Player
if (!players[currentPlayer].isNPC)
{
// MOUSE: Get mouse position
Vector3 mPos = Camera.main.ScreenToWorldPoint(new Vector3(Input.mousePosition.x, Input.mousePosition.y, 10));
PowerAdjust();
RotateStickAroundBall(mPos);
ProjectTrajectory();
forceDir = (ball.position - stick.position).normalized * -(dist - minDist - 0.02f) * forceMultiplier;
if (Input.GetMouseButtonUp(0)) Shoot(forceDir); // && hit.collider != null)
}
// AI Player
else
{
// TODO: Here will be the AI
// Reward the model if there are less balls on the table
// Apply the action to the ball
//forceDir = new Vector3(1, 1, 1);
//Shoot(forceDir);
var continuousActionsOut = new ActionBuffers();
Heuristic(continuousActionsOut);
SetReward(1.0f / GameObject.FindGameObjectsWithTag("Ball").Length);
if (GameObject.FindGameObjectsWithTag("Ball").Length == 1)
{
// End the episode if there is only one ball left
EndEpisode();
}
}
}
// start following
if (!follow)
{
if (AllBallsStopped() && ball.GetComponent<Rigidbody2D>().velocity.sqrMagnitude == 0.0f)
{
// Update player score
// players[currentPlayer].score = UpdateScore(currentPlayer);
// currentPlayer = (currentPlayer + 1) % nPlayers;
follow = true;
if (cueBall.activeSelf == false)
{
SetReward(-0.1f);
RespawnBall();
}
selectionFx.GetComponent<Fader>().StartFade();
Invoke("HideShowStick", 0.2f);
}
}
} // update()
}
I tried reading about this, but I couldn't find much information.