[PyTorch로 시작하는 강화학습 입문] 8편: Actor-Critic 접근 – A2C(Advantage Actor-Critic) 기초 구현

정책기반 접근(REINFORCE)은 정책을 직접 파라미터화하고, 에피소드가 끝난 후 누적보상을 이용해 정책 그래디언트를 업데이트합니다. 이 방법은 개념적으로 간단하지만, 다음과 같은 단점이 있습니다.

  • 고분산 업데이트: 에피소드 단위로 G(Return)를 계산하므로, 긴 에피소드나 복잡한 문제에서는 분산이 매우 커져 업데이트 효율이 떨어집니다.
  • 느린 반응: 에피소드가 끝나야만 업데이트가 이루어지므로 실시간 반응이 어려움.

Actor-Critic 접근은 이러한 문제를 완화합니다. 여기서 에이전트는 두 가지 신경망(또는 하나의 공유 신경망)을 갖습니다.

  1. Actor(정책 네트워크): πθ(a|s)를 파라미터화하여, 상태에서 행동 확률분포를 출력 (정책기반)
  2. Critic(가치추정 네트워크): Vψ(s)를 파라미터화하여, 현재 상태의 가치(기대 반환)를 추정 (가치기반)

Actor는 행동을 샘플링하고, Critic은 현재 상태의 가치를 근사해서 Advantage(우위) 값을 계산합니다. Advantage = Q(s,a) - V(s) ≈ r + γV(s') - V(s)를 이용해 정책 그래디언트의 분산을 줄이고, 더 빈번하고 안정적인 업데이트가 가능합니다. 에피소드가 끝나기 전에도 배치 업데이트를 부분적으로 할 수 있어, 훨씬 빠른 피드백-학습 루프를 구성할 수 있습니다.

이번 글에서는 A2C(Advantage Actor-Critic) 기법을 단순화해서 구현해보겠습니다. A2C는 여러 워커(환경)에서 수집한 경험을 평균해 업데이트하는 방식을 가정하지만, 여기서는 단일 환경으로 기본 아이디어를 살펴봅니다.

참고자료:

  • Mnih et al., "Asynchronous Methods for Deep Reinforcement Learning"(A3C 논문): https://arxiv.org/abs/1602.01783
    A3C와 유사하나 동기화(Asynchronous)를 제거한 형태가 A2C로 알려짐
  • PPO, A2C, A3C 관련 PyTorch 예제 공식 튜토리얼 등

A2C 구현 개요

  1. 매 스텝마다 (s, a, r, s') 경험을 수집.
  2. Critic을 이용해 V(s), V(s') 값을 구해 Advantage를 계산: A = r + γV(s') - V(s)
  3. Actor 업데이트: ∇θ log π(a|s)*A를 이용해 정책 개선
  4. Critic 업데이트: 가치망에 대해 (V(s) - (r + γV(s')))² 최소화
  5. 배치로 일정 스텝 단위로 모아 업데이트하거나, 스텝마다 업데이트하는 등 다양한 변형 가능.

여기서는 단순화를 위해, 일정 수의 스텝을 모아 Advantage와 목표값을 계산한 뒤 업데이트하는 형태를 구현해봅니다.

코드 예제 (A2C 단순 구현)

아래 예제에서는 CartPole 환경에서 A2C를 간단히 구현합니다. Actor와 Critic을 하나의 신경망에 넣고, 마지막에 두 출력을 분리하는 구조를 사용합니다. 일정 스텝마다(예: n-step) 샘플링한 transition을 모아 Advantage를 계산하고 Actor-Critic 업데이트를 수행합니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

###################################
# Actor-Critic 통합 네트워크
###################################
class ActorCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(ActorCriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        
        # 정책(action) 출력 계층
        self.action_head = nn.Linear(hidden_size, action_dim)
        # 가치(value) 출력 계층
        self.value_head = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action_logits = self.action_head(x)
        value = self.value_head(x)
        return action_logits, value

###################################
# A2C 에이전트
###################################
class A2CAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, update_steps=5):
        self.gamma = gamma
        self.update_steps = update_steps
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.ac_net = ActorCriticNetwork(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.ac_net.parameters(), lr=lr)
        
        # 샘플링 결과를 저장할 버퍼(짧은 n-step)
        self.states = []
        self.actions = []
        self.rewards = []
        self.done_flags = []
        
        self.action_dim = action_dim

    def select_action(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        action_logits, _ = self.ac_net(state_t)
        action_probs = F.softmax(action_logits, dim=-1)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        self.states.append(state)
        self.actions.append(action.item())
        return action.item()
    
    def store_reward(self, reward, done):
        self.rewards.append(reward)
        self.done_flags.append(done)

    def update(self, next_state):
        # next_state로 V(s') 계산
        if self.done_flags[-1]:
            # 종료 상태이면 V(s')=0
            next_value = 0.0
        else:
            next_state_t = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
            with torch.no_grad():
                _, next_value_t = self.ac_net(next_state_t)
                next_value = next_value_t.item()
        
        # n-step Returns 계산
        returns = []
        G = next_value
        for r, done in reversed(list(zip(self.rewards, self.done_flags))):
            if done:
                G = r  # done이면 G를 리셋
            else:
                G = r + self.gamma * G
            returns.insert(0, G)
        
        states_t = torch.FloatTensor(self.states).to(self.device)
        actions_t = torch.LongTensor(self.actions).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)
        
        action_logits, values_t = self.ac_net(states_t)
        values_t = values_t.squeeze(1)
        
        action_probs = F.softmax(action_logits, dim=-1)
        dist = torch.distributions.Categorical(action_probs)
        log_probs = dist.log_prob(actions_t)
        
        # Advantage = returns_t - values_t
        advantage = returns_t - values_t.detach()
        
        # Actor Loss = -logπ(a|s)*Advantage
        actor_loss = -(log_probs * advantage).mean()
        # Critic Loss = (V(s)-Returns)^2
        critic_loss = F.mse_loss(values_t, returns_t)
        
        loss = actor_loss + 0.5 * critic_loss
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 버퍼 초기화
        self.states = []
        self.actions = []
        self.rewards = []
        self.done_flags = []

def train_a2c(env_name="CartPole-v1", max_episodes=300, update_steps=5):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = A2CAgent(state_dim, action_dim, update_steps=update_steps)
    
    reward_history = []
    state = env.reset()
    for ep in range(max_episodes):
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.store_reward(reward, done)
            
            state = next_state
            total_reward += reward
            
            # 일정 스텝마다 업데이트 수행 (또는 에피소드 종료 시 수행)
            if len(agent.rewards) >= agent.update_steps or done:
                agent.update(next_state)
        
        reward_history.append(total_reward)
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}")
        
        # 에피소드 끝났으니 다음 에피소드 시작
        state = env.reset()

    env.close()

if __name__ == "__main__":
    train_a2c()

코드 해설

  • ActorCriticNetwork: 하나의 신경망에서 정책(action_logits)과 가치(value) 모두 출력. Actor와 Critic 역할을 동시 수행.
  • A2CAgent:
    • 일정 스텝(update_steps) 동안 (s,a,r,done) 수집
    • update() 호출 시 마지막 상태의 가치로부터 backward n-step return 계산
    • Advantage = Return - V(s) 이용해 Actor, Critic 동시 업데이트
  • Critic의 V(s)가 Advantage 추정에 도움을 주어, REINFORCE보다 변동성이 낮은 업데이트 가능

마무리

이번 글에서는 Actor-Critic 접근의 기본 개념과 A2C 구현 예제를 다뤘습니다. Actor-Critic은 정책기반과 가치기반의 장점을 결합해, 더 안정적이고 효율적인 학습을 지원합니다. 이후에는 A3C, PPO, SAC 등 다양한 Actor-Critic 변형 기법들로 확장할 수 있으며, 이들은 대규모 병렬 환경, 연속 동작 공간, 복잡한 태스크에도 좋은 성능을 보입니다.

다음 글에서는 PPO나 다른 고급 Actor-Critic 알고리즘을 살펴보며, 강화학습 알고리즘 선택 및 실제 응용 시 고려사항을 다룰 수 있습니다.

반응형