[PyTorch로 시작하는 강화학습 입문] 9편: 안정적인 정책 업데이트 – PPO(Proximal Policy Optimization) 소개 및 구현

A2C까지는 정책과 가치를 동시에 학습하는 Actor-Critic 방법론의 기본을 익혔습니다. 그러나 A2C나 A3C, TRPO 같은 알고리즘들은 정책 업데이트 과정에서 제한이 명확하지 않아, 큰 갱신으로 인한 성능 퇴보가 발생할 수 있습니다.

PPO(Proximal Policy Optimization)는 이를 개선하기 위해 다음과 같은 핵심 아이디어를 제안합니다.

  • 정책 업데이트 시, 새로운 정책과 기존 정책의 차이를 '클리핑(clipping)'하여, 정책이 한 번에 크게 바뀌지 않도록 제약
  • 이로써 안정적인 학습이 가능해지고, 복잡한 수학적 보증이 필요한 TRPO보다 구현이 단순하며, 널리 사용되는 SOTA급 RL 알고리즘으로 자리매김

핵심 개념:

  • Probability Ratio (r):
    r(θ) = π_θ(a|s) / π_θ_old(a|s)
    현재 정책과 이전 정책의 행동 확률 비율
  • Clipped Objective:
    L^{CLIP}(θ) = E[ min(r(θ)*A, clip(r(θ), 1-ε, 1+ε)*A ) ]
    여기서 A는 Advantage, ε는 작은 클리핑 범위(예: 0.1 ~ 0.2)
  • 클립으로 r(θ)를 1±ε 범위에 묶어, 업데이트 시 너무 극단적으로 정책이 변하지 않게 함

또한 PPO는 Advantage 추정(일반화된 Advantage 추정 GAE 등), 배치 학습, 여러 Epoch 동안 같은 데이터로 정책 업데이트 등 다양한 실전 테크닉과 잘 어울려서 뛰어난 성능을 보여줍니다.

여기서는 기본적인 PPO 알고리즘 흐름을 간략히 구현해보겠습니다. 실제로 PPO를 최적으로 구현하려면 GAE, 여러 병렬 환경, 마스크 처리, 배치/미니배치 반복 업데이트 등의 최적화가 필요하지만, 이 글에서는 개념 전달과 기본 구조를 익히는 데 초점을 둡니다.

추가 참고자료:

PPO 구현 개요

  1. 일정 스텝동안 (s,a,r,s',done) 수집
  2. Critic을 이용해 V(s) 추정, Advantage = G - V(s) 계산 (여기서는 n-step Return 사용, 단순화)
  3. 원래 정책 π_old를 고정한 상태에서 새로운 정책 π_θ에 대해 Ratio = π_θ(a|s)/π_old(a|s) 계산
  4. Clipped Objective 이용해 Actor 업데이트
  5. Critic은 (V(s)-Return)^2 최소화로 업데이트
  6. π_old ← π_θ 동기화 후 다음 iteration

예제 코드 (간단 PPO)

아래 예제는 한 번에 일정 스텝만큼 샘플링한 후, 그 데이터를 이용해 다중 epoch 업데이트하는 단순 PPO 구현을 보여줍니다. 실제로 PPO는 배치 사이즈, epoch 횟수, mini-batch slicing 등 다양한 하이퍼파라미터 튜닝이 필요하지만, 여기서는 기본 흐름만 파악합니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

###################################
# Actor-Critic 네트워크 (PPO용)
###################################
class PPOActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(PPOActorCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        
        self.action_head = nn.Linear(hidden_size, action_dim)
        self.value_head = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.action_head(x)
        value = self.value_head(x)
        return logits, value

    def get_action_value(self, state):
        # 상태 -> 행동 확률, 상태 가치, 행동 샘플 및 log_prob 반환
        logits, value = self.forward(state)
        action_probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        return action, dist.log_prob(action), value, action_probs

    def get_logprob_value(self, state, action):
        logits, value = self.forward(state)
        action_probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(action_probs)
        return dist.log_prob(action), value

###################################
# PPO 에이전트
###################################
class PPOAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=3e-4, clip_epsilon=0.2, k_epochs=4, rollout_steps=2048):
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.k_epochs = k_epochs
        self.rollout_steps = rollout_steps
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.ac = PPOActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.ac.parameters(), lr=lr)
        
        # 롤아웃 저장용 버퍼
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.done_flags = []
        self.values = []

    def select_action(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action, logprob, value, _ = self.ac.get_action_value(state_t)
        self.states.append(state)
        self.actions.append(action.item())
        self.logprobs.append(logprob.item())
        self.values.append(value.item())
        return action.item()

    def store_reward(self, reward, done):
        self.rewards.append(reward)
        self.done_flags.append(done)

    def compute_returns_advantages(self, next_state):
        # GAE 등 사용 가능하지만 여기서는 단순 n-step Return
        # next_value가 done이면 0, 아니면 value(s')
        if self.done_flags[-1]:
            next_value = 0
        else:
            state_t = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
            with torch.no_grad():
                _, next_value_t = self.ac.forward(state_t)
            next_value = next_value_t.item()
        
        returns = []
        G = next_value
        for r, done, val in reversed(list(zip(self.rewards, self.done_flags, self.values))):
            if done:
                G = r
            else:
                G = r + self.gamma * G
            returns.insert(0, G)
        
        returns = np.array(returns)
        values = np.array(self.values)
        advantages = returns - values
        # 정규화(Optional)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-9)
        return returns, advantages

    def update(self, next_state):
        returns, advantages = self.compute_returns_advantages(next_state)
        
        states_t = torch.FloatTensor(self.states).to(self.device)
        actions_t = torch.LongTensor(self.actions).to(self.device)
        old_logprobs_t = torch.FloatTensor(self.logprobs).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)

        # 여러 epoch 업데이트
        for _ in range(self.k_epochs):
            logprobs, values_t = self.ac.get_logprob_value(states_t, actions_t)
            ratio = torch.exp(logprobs - old_logprobs_t)
            
            surr1 = ratio * advantages_t
            surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * advantages_t
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(values_t.squeeze(), returns_t)
            loss = actor_loss + 0.5 * critic_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        
        # 버퍼 정리
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.done_flags = []
        self.values = []

def train_ppo(env_name="CartPole-v1", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = PPOAgent(state_dim, action_dim)
    reward_history = []
    state = env.reset()
    episode_rewards = 0
    for ep in range(max_episodes):
        for step in range(agent.rollout_steps):
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.store_reward(reward, done)
            state = next_state
            episode_rewards += reward

            if done:
                reward_history.append(episode_rewards)
                episode_rewards = 0
                state = env.reset()

        # 수집한 rollout_steps만큼의 데이터로 업데이트
        agent.update(next_state)
        
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:]) if len(reward_history)>20 else np.mean(reward_history)
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}")
    
    env.close()

if __name__ == "__main__":
    train_ppo()

코드 해설

  • PPOActorCritic: Actor-Critic 구조로 행동 확률과 가치 모두 계산.
  • PPOAgent:
    • 일정 스텝(rollout_steps) 동안 경험 수집
    • update() 시 이전 정책에 대한 log_prob(old_logprobs)를 바탕으로 현재 정책과의 ratio 계산
    • 클리핑된 목표함수로 Actor 업데이트, Critic은 V(s)-Return 최소화
    • 여러 번(Epoch) 같은 데이터로 업데이트 가능 → 데이터 효율성 상승
  • 실제 PPO는 GAE(Generalized Advantage Estimation) 사용, batch로 나누어 mini-batch 업데이트, 파라미터 튜닝 등 최적화 필요.

마무리

이번 글에서는 PPO 알고리즘의 핵심 아이디어를 소개하고, 기본 구현 예제를 통해 안정적인 정책 업데이트 기법을 살펴보았습니다. PPO는 현재 강화학습 분야에서 가장 널리 쓰이는 알고리즘 중 하나로, 다양한 환경(Atari, MuJoCo, etc.)에서 우수한 성능을 보입니다.

앞으로는 PPO를 더 개선하거나, 고급 RL 알고리즘(SAC, TD3, MPO 등)을 다루거나, 실제 응용 사례(로보틱스, 게임 AI)로 확장하는 과정을 통해 RL 능력을 더욱 심화할 수 있습니다.

반응형