본문 바로가기

프로그래밍/Unity-ML

Unity-ML 강화학습 예제 작성

작성한 코드 저장소: https://bitbucket.org/hucce/mltest

참고한 블로그:  https://puzzleleaf.tistory.com/209

 

[Unity] ML-Agents로 간단한 강화학습 예제 만들어보기

Unity-Technologies/ml-agents Unity Machine Learning Agents Toolkit. Contribute to Unity-Technologies/ml-agents development by creating an account on GitHub. github.com ML-Agents의 사용방법에 조금 익..

puzzleleaf.tistory.com

강화학습 예제로 많이 쓰이는 것이 플레이어를 구체로, 타깃객체를 찾아가는 예제를 작성했다.

보통 가장 많이 쓰는 모델은 플로어 중앙에 생성되는 구체와 랜덤한 위치에 생성되는 타깃 큐브를 찾아가는 예제를 많이 사용한다.

나는 참고한 블로그에 맞춰 모서리에 있는 4개의 큐브 중 랜덤으로 하나의 큐브를 타깃으로 지정하고 이를 찾아가는 것을 목표로 만들었다.

Unity ML의 기본구조

참고한 블로그: https://mindrich.tistory.com/29?category=755803

 

Unity ML 기초 1. ML 개요 - ML(머신러닝) Unity

★ 머신러닝(ML)이란 ★ -1. 기계학습 또는 머신러닝은 인공지능의 한 분야이다 -2. 컴퓨터가 학습할 수 있도록 하는 알고리즘과 기술을 개발하는 분야 -3. 대량의 데이터나 알고리즘을 통해 '학습' 을 시켜 수행..

mindrich.tistory.com

기본적으로 3층의 구조를 가지고 있음 Academy(아카데미), Brain(브레인), Agent(에이전트) 아카데미는 기본적인 환경을 설정, 브레인은 에이전트의 액션을 결정한다. 에이전트는 환경 내에서 액션을 취한다.

기초적인 구조를 알고 강화학습을 작성해본다.

1. 기본적인 객체들을 만듬

2. 아카데미의 작성

아카데미용 빈 오브젝트를 만들고 아카데미 코드 작성 또는 Basic Academy 컴포넌트를 추가한다. 설정은 기본설정에서 바꾸지 사용함

트레이닝 설정은 학습할 때의 설정(프레임은 -1은 무제한임)

3. 에이전트 구성

에이전트가 될 구체에 에이전트 에이전트를 구현한다. 아래는 구현한 코드이다.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;

public class BallAgent : Agent
{
    private Rigidbody playerRigidbody;  // 플레이어의 리지드바디

    [SerializeField]
    private Transform[] target;            // 목표의 위치

    private int random;

    public float moveForce = 10f;       // 플레이어를 이동시키는 힘

    private bool catchTarget = false;   // 플레이어가 타겟을 잡았는지 여부
    private bool isDead = false;        // 플레이어가 플로어에서 벗어났는지 여부

    public override void InitializeAgent()
    {
        playerRigidbody = GetComponent<Rigidbody>();
        RandomTarget();
    }

    // 에이전트가 죽었다가 다시 시작될 때의 함수
    public override void AgentReset()
    {
        // 플레이어는 훈련 유닛의 위치를 기준으로 플로어 위에 랜덤의 좌표에 위치
        transform.position = new Vector3(0, 0.79f, 0);

        isDead = false;                             // 사망상태 초기화
        catchTarget = false;
        playerRigidbody.velocity = Vector3.zero;    // 플레이어 속도 초기화

        RandomTarget(); // 타겟 초기화
    }

    // 에이전트가 수집하는 데이터 x, z 좌표에 대한 위치 속도
    public override void CollectObservations()
    {
        Vector3 distance = target[random].transform.position - this.transform.position;

        AddVectorObs(distance.x);
        AddVectorObs(distance.z);

        AddVectorObs(playerRigidbody.velocity.x);
        AddVectorObs(playerRigidbody.velocity.z);
    }

    // 브레인이 에이전트에게 내리는 지시
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        AddReward(-0.01f); // 아무런 액션을 하지 않는 것을 방지 하기 위한 패널티

        float horizontalInput = vectorAction[0];
        float verticalInput = vectorAction[1];

        // 플레이어의 속도
        playerRigidbody.AddForce(horizontalInput * moveForce, 0f, verticalInput * moveForce);

        // 타겟을 잡으면 보상하고 타겟 리셋
        if (catchTarget)
        {
            AddReward(1.0f);
            Done();
        }
        // 플레이어가 플로어 밖으로 떨어지면 종료
        else if (isDead)
        {
            AddReward(-1.0f);
            Done();
        }
    }

    private void OnTriggerEnter(Collider _collider)
    {
        if (_collider.CompareTag("goal"))
        {
            catchTarget = true;
        }
        else if (_collider.CompareTag("Finish"))
        {
            isDead = true;
        }
    }

    private void RandomTarget()
    {
        for (int i = 0; i < target.Length; i++)
        {
            target[i].tag = "wall";
            target[random].transform.localScale = new Vector3(1, 1, 1);
        }

        random = Random.Range(0, 4);
        target[random].tag = "goal";
        target[random].transform.localScale = new Vector3(1, 2, 1);
    }
}

보상과 패널티는 빠르게 타깃을 찾기 위해 계속 패널티를 받고, 플로어 밖으로 떨어지면 -1, 타깃을 잡으면 +1을 받도록 설정했다.

에이전트는 에이전트클래스에서 상속을 받는데 버추얼 메소드이기 때문에 오버라이드해서 메소드를 사용한다. 주로 AgentAction, CollectObservations, AgentReset, InitializeAgent는 많이 사용되는 메소드이며, 학습종료는 Done을 이용해서 학습을 종료한다.

다른 예제를 보았을 때 원래의 유니티의 제공해주는 함수대신 Start나 Awake, 또는 자신만의 함수를 만들어 사용하는 것도 가능하나 ML의 단계부여 등 여러 요소를 추가할 때를 생각하면 가능하면 본래 의도대로 하는 것이 좋아보인다.

해당 하는 스크립트 컴포넌트를 작성 후 구체에 컴포넌트를 추가한다.

4. 브레인 구성

학습에 필요한 브레인은 러닝 브레인으로 러닝브레인으로만 할 수 있지만, 기존 게임을 이용하는 경우를 위해서 플레이어 브레인을 사용했다. 플레이어 브레인은 주로 Input 값을 받아 조종 되는 플레이어를 브레인이 통제하기 위해서 사용되는 브레인이다.

해당하는 플레이어 브레인을 다음과 같이 설정했다. 타입은 두가지가 있는데 Discrete 또는 Continuous가 있는데, 2가지는 정수형이냐 또는 실수형 인지가 큰차이이며, 이 예제에서는 Continuous로 진행하였고, 받아오는 액션의 값은 두축을 기반으로 하도록 하였다.

마지막으로 플레이어브레인과 러닝브레인을 해당하는 에이전트와 아카데미에 추가하였다.

이제 작성이 완료된 예제를 테스트하자.

'프로그래밍 > Unity-ML' 카테고리의 다른 글

Unity-ML 카메라 환경  (0) 2019.10.15
Unity-ML의 구조  (0) 2019.10.01
Unity-ML 비디오 레코더  (0) 2019.09.26
Unity-ML 강화학습 테스트 및 수정  (0) 2019.09.18
UNITY-ML 설치 및 세팅  (1) 2019.09.18