Trouble implementing ml-agents in a Unity snooker game

I'm trying to implement a reinforcement learning snooker player and I want to train it with ml-agents in Unity. The problem is that I'm not familiar with the library and I'm not sure how to set it up: when I try to train it, the agent does nothing. Could you check what I'm doing wrong and how I can fix it? Thanks so much. This is the main code:


using System;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine;
using UnityEngine.SceneManagement;
using System.Collections;
using UnityEngine.UI;


public struct Player
{
    public int score;
    public bool isNPC;
}

public class Snooker2D : Agent
{
    public GameObject cueBall;
    public Transform stick;
    public Transform ball;
    public Transform selectionFx;
    public Slider slider;

    int currentPlayer = 1;
    int nPlayers = 2;
    Player[] players = new Player[2];
    bool follow = true;
    bool doubleShot = false;
    float ballRadius = 0;
    LayerMask layerMaskBalls = 1 << 9;
    LayerMask layerMaskBallsAndWalls = 1 << 9 | 1 << 10;
    LayerMask layerMaskWalls = 1 << 10;

    RaycastHit2D hit;
    float dist = 0;
    float minDist = 0;
    float maxDist = -3;
    float forceMultiplier = 2.5f;

    Vector3 whiteBallPos = new Vector3(-4.3111f, 0, 0);

    public override void OnEpisodeBegin()
    {
        // Reset the environment by reloading the current scene
        SceneManager.LoadScene(SceneManager.GetActiveScene().buildIndex);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and Agent positions
        sensor.AddObservation(ball.position);
        // sensor.AddObservation(stick.position);

        // All balls positions
        foreach (GameObject ball in GameObject.FindGameObjectsWithTag("Ball"))
        {
            sensor.AddObservation(ball.transform.position);
        }
    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Actions, size = 3, the force vector
        var action = actionBuffers.ContinuousActions;
        Vector3 force = new Vector3(action[0], action[1], action[2]);

        // Apply the force vector to the game environment, for example by adding it to the velocity of the cue ball
        Shoot(force);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;

        // Set the force vector using keyboard or gamepad inputs, each action is between -1 and 1
        // Each action is an axis of the force vector
        continuousActionsOut[0] = Input.GetAxis("Horizontal");
        continuousActionsOut[1] = Input.GetAxis("Vertical");
        continuousActionsOut[2] = Input.GetAxis("Jump");
    }


    void Start()
    {
        players[1].isNPC = true;

        ballRadius = ball.GetComponent<CircleCollider2D>().radius;
        minDist = -(ballRadius + ballRadius / 2);
        dist = Mathf.Clamp(maxDist / 2, maxDist, minDist);

        slider.maxValue = -maxDist;
        slider.minValue = -minDist;
        slider.value = dist + -maxDist - minDist;

        selectionFx.GetComponent<Fader>().StartFade();
    }

    void Update()
    {
        if (Input.GetKeyDown("r")) SceneManager.LoadScene(SceneManager.GetActiveScene().buildIndex);

        // PLAYER TURN
        if (follow)
        {
            Vector3 forceDir = Vector3.zero;
            // Analog Player
            if (!players[currentPlayer].isNPC)
            {
                // MOUSE: Get mouse position
                Vector3 mPos = Camera.main.ScreenToWorldPoint(new Vector3(Input.mousePosition.x, Input.mousePosition.y, 10));
                PowerAdjust();
                RotateStickAroundBall(mPos);
                ProjectTrajectory();
                forceDir = (ball.position - stick.position).normalized * -(dist - minDist - 0.02f) * forceMultiplier;
                if (Input.GetMouseButtonUp(0)) Shoot(forceDir); // && hit.collider != null)
            }
            // AI Player
            else
            {
                // TODO: Here will be the AI
                // Reward the model if there are less balls on the table
                // Apply the action to the ball
                //forceDir = new Vector3(1, 1, 1);
                //Shoot(forceDir);
                var continuousActionsOut = new ActionBuffers();
                Heuristic(continuousActionsOut);
                SetReward(1.0f / GameObject.FindGameObjectsWithTag("Ball").Length);
                if (GameObject.FindGameObjectsWithTag("Ball").Length == 1)
                {
                    // End the episode if there is only one ball left
                    EndEpisode();
                }
            }
        }

        // start following
        if (!follow)
        {
            if (AllBallsStopped() && ball.GetComponent<Rigidbody2D>().velocity.sqrMagnitude == 0.0f)
            {
                // Update player score
                // players[currentPlayer].score = UpdateScore(currentPlayer);
                // currentPlayer = (currentPlayer + 1) % nPlayers;
                follow = true;
                if (cueBall.activeSelf == false)
                {
                    SetReward(-0.1f);
                    RespawnBall();
                }
                selectionFx.GetComponent<Fader>().StartFade();
                Invoke("HideShowStick", 0.2f);
            }
        }

    } // update()
}
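
One thing I'm not sure about is whether I should be calling Heuristic() myself during the AI player's turn, or whether the turn should be driven by RequestDecision() so that ML-Agents fills in the actions and calls OnActionReceived. This is roughly what I mean, just as a minimal sketch (SnookerAgentSketch and ApplyShot() are placeholder names, not from my real script, and it assumes the Behavior Parameters component is set to 3 continuous actions):

using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

// Minimal sketch: ask ML-Agents for a decision instead of calling Heuristic() by hand.
public class SnookerAgentSketch : Agent
{
    bool waitingForShot;

    void FixedUpdate()
    {
        // Only request a new action once the previous shot has resolved.
        if (!waitingForShot)
        {
            waitingForShot = true;
            RequestDecision(); // ML-Agents then calls CollectObservations and OnActionReceived
        }
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Observations would mirror the script above (ball positions, etc.).
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        // Interpret the 3 continuous actions as a force vector, as in my script.
        var a = actions.ContinuousActions;
        Vector3 force = new Vector3(a[0], a[1], a[2]);
        ApplyShot(force);
    }

    void ApplyShot(Vector3 force)
    {
        // Placeholder for the real game's Shoot(); in the real script the flag
        // would only be cleared once all balls have stopped moving.
        waitingForShot = false;
    }
}

Is that the intended pattern, or is calling Heuristic() directly supposed to work too?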

I tried reading about this, but I couldn't find much information.
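
I also wasn't sure whether reloading the whole scene in OnEpisodeBegin is a valid way to reset the episode, since that destroys and recreates the Agent. As an alternative I thought about resetting positions manually, something like this (just a sketch that would go inside the Snooker2D class above, reusing its cueBall and whiteBallPos fields; the other balls would need the same treatment):

    public override void OnEpisodeBegin()
    {
        // Manual reset instead of a scene reload: stop the cue ball,
        // put it back on its spot and re-enable it if it was potted.
        var rb = cueBall.GetComponent<Rigidbody2D>();
        rb.velocity = Vector2.zero;
        rb.angularVelocity = 0f;
        cueBall.transform.position = whiteBallPos;
        cueBall.SetActive(true);
    }

Is one of these approaches better for training, or am I missing something else entirely?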
