using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System.Collections;
using System.Collections.Generic;
/// <summary>
/// ML-Agents agent for a cart-pole task: it pushes the cart along the Z axis
/// and is rewarded for every step the pole stays within the allowed tilt and
/// the cart stays within the allowed track range.
/// </summary>
public class CartPoleAgent : Agent
{
    // Scale applied to the clamped [-1, 1] action to obtain the push force.
    const float ForceScale = 200f;
    // Pole tilt (degrees about local X, measured from upright) beyond which the episode fails.
    const float MaxPoleAngle = 45f;
    // Cart offset (local Z) beyond which the episode fails.
    const float MaxCartOffset = 10f;
    // Reward for a step in which neither failure condition holds.
    const float StepReward = 0.1f;
    // Reward assigned on the step that ends the episode in failure.
    const float FailureReward = -1.0f;

    /// <summary>The pole object; it must carry a <see cref="Rigidbody"/>.</summary>
    public GameObject pole;
    Rigidbody poleRB;
    Rigidbody cartRB;
    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Caches the cart and pole rigidbodies and the academy's environment
    /// parameters, then applies the initial reset parameters.
    /// </summary>
    public override void Initialize()
    {
        poleRB = pole.GetComponent<Rigidbody>();
        cartRB = GetComponent<Rigidbody>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    /// <summary>
    /// Adds four observations: cart local Z position, cart Z velocity,
    /// pole local X Euler angle, and pole X angular velocity.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition.z);
        sensor.AddObservation(cartRB.velocity.z);
        // NOTE(review): this is the raw Euler angle in [0, 360), so it jumps between
        // ~0 and ~360 around upright. Left unchanged to keep the observation space
        // compatible with any existing trained model - consider normalising.
        sensor.AddObservation(pole.transform.localRotation.eulerAngles.x);
        sensor.AddObservation(poleRB.angularVelocity.x);
    }

    /// <summary>
    /// Applies the received continuous action as a force on the cart and
    /// updates the reward / episode state.
    /// </summary>
    /// <param name="actionBuffers">Action buffers; continuous action index 1 drives the cart.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // NOTE(review): this overload reads index 1 while the float[] overload reads
        // index 0. Kept as-is because the right index depends on the Behavior
        // Parameters configured in the scene - confirm which one is intended.
        ApplyActionAndEvaluate(actionBuffers.ContinuousActions[1]);
    }

    /// <summary>
    /// Array-based action callback; behaves like the <see cref="ActionBuffers"/>
    /// overload but reads the action from index 0.
    /// </summary>
    /// <param name="verctorAction">Action array; element 0 drives the cart.</param>
    // NOTE(review): the float[] signature was deprecated and later removed from Agent
    // in newer ML-Agents releases; if the project targets such a release this override
    // must be deleted - confirm against the installed package version.
    public override void OnActionReceived(float[] verctorAction)
    {
        ApplyActionAndEvaluate(verctorAction[0]);
    }

    /// <summary>
    /// Resets the cart and pole to their start pose, clears their velocities,
    /// gives the pole a small random angular velocity about X, and re-applies
    /// the reset parameters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        transform.localPosition = Vector3.zero;
        // The cart must not carry momentum over from the previous episode.
        cartRB.velocity = Vector3.zero;
        cartRB.angularVelocity = Vector3.zero;

        pole.transform.localPosition = new Vector3(0f, 2.5f, 0f);
        pole.transform.localRotation = Quaternion.identity;
        poleRB.velocity = Vector3.zero;
        // Small random spin so the pole does not start perfectly balanced.
        poleRB.angularVelocity = new Vector3(Random.Range(-0.1f, 0.1f), 0f, 0f);

        SetResetParameters();
    }

    /// <summary>
    /// Sets the pole mass from the "mass" environment parameter (default 1.0)
    /// and restores the pole's local scale.
    /// </summary>
    public void SetPole()
    {
        poleRB.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
        pole.transform.localScale = new Vector3(0.4f, 2f, 0.4f);
    }

    /// <summary>Applies all per-episode reset parameters (currently only the pole settings).</summary>
    public void SetResetParameters()
    {
        SetPole();
    }

    // Shared step logic for both OnActionReceived overloads: push the cart, then
    // either reward the step or end the episode exactly once on failure.
    void ApplyActionAndEvaluate(float rawAction)
    {
        float forceZ = ForceScale * Mathf.Clamp(rawAction, -1f, 1f);
        cartRB.AddForce(new Vector3(0f, 0f, forceZ), ForceMode.Force);

        float cartZ = transform.localPosition.z;

        // Map the Euler angle from [0, 360) to a signed tilt around upright.
        float poleAngle = pole.transform.localRotation.eulerAngles.x;
        if (poleAngle > 180f)
        {
            poleAngle -= 360f;
        }

        bool poleFell = Mathf.Abs(poleAngle) > MaxPoleAngle;
        bool cartOutOfBounds = cartZ < -MaxCartOffset || cartZ > MaxCartOffset;

        if (poleFell || cartOutOfBounds)
        {
            // End the episode only once even when both conditions hold; a second
            // EndEpisode() would immediately terminate the freshly reset episode.
            SetReward(FailureReward);
            EndEpisode();
            return;
        }

        SetReward(StepReward);
    }
}