[PyTorch로 시작하는 강화학습 입문] 6편: 우선순위 경험 리플레이(Prioritized Experience Replay)로 샘플링 효율 개선

 

기존 DQN에서는 모든 경험을 동일한 확률로 샘플링합니다. 그러나 강화학습에서는 특정 경험(transition)이 학습 초기에는 별로 도움이 안 되지만, 나중에 정책이 개선되면서 가치가 달라지거나, 에이전트가 특정 상황에서 큰 TD 오차(Temporal-Difference Error)를 낼 경우 그 경험이 정책 개선에 더 크게 기여할 수 있습니다.

우선순위 경험 리플레이(PER)의 핵심 아이디어는 TD 오차가 큰(즉, 현재 네트워크의 예측과 실제 타겟 간 차이가 큰) 경험을 더 자주 샘플링하는 것입니다. 이를 통해 에이전트는 정책 개선에 유용한 경험을 빠르게 재학습하고, 경험 데이터 활용 효율을 높일 수 있습니다.

참고자료:

Prioritized Replay 개념 정리

PER에서는 각 경험에 우선순위(priority)를 할당하고, 우선순위가 높은 경험을 샘플링할 확률을 높입니다. 우선순위는 주로 TD 오차의 절댓값 |δ|에 기반하며, TD 오차가 클수록 그 경험을 더 자주 샘플링합니다.

일반적으로 우선순위 p_i = (|δ_i| + ε)^\alpha 형태로 정의하며,

  • ε는 TD 오차가 0이라도 우선순위가 0이 되지 않게 하는 작은 상수
  • α는 우선순위 적용 정도를 조절하는 하이퍼파라미터(0이면 기존과 동일하게 균등 샘플링)

또한, 우선순위 기반 샘플링은 샘플링 편향을 일으킬 수 있으므로, 중요도 가중치(IS weight)를 사용해 학습 시 업데이트를 보정하는 방법을 적용합니다. 이것은 β라는 파라미터를 사용해 점진적으로 중요도 샘플링 보정을 강화합니다.

정리하면, PER에서는:

  1. 경험을 저장할 때 TD 오차에 따른 우선순위를 할당
  2. 우선순위를 기반으로 경험을 샘플링 (우선순위 높은 경험 더 많이 샘플링)
  3. 샘플링된 경험을 학습할 때 IS weight로 업데이트 보정
  4. 학습 후 TD 오차 재계산, 우선순위 갱신

구현 개요

우선순위 경험 리플레이를 구현하려면 기존 ReplayBuffer를 다음과 같이 확장해야 합니다.

  • Segment Tree나 Fenwick Tree(비트 트리) 등을 이용해 우선순위 합을 빠르게 관리하고, 샘플링할 경험을 O(log N)으로 찾는 방법이 일반적입니다. 여기서는 단순화를 위해 Segment Tree 기반 구현 아이디어를 간략히 보여주겠습니다.
  • 각 경험에 대한 우선순위를 저장하고, 우선순위 합을 트리 구조로 관리하면, 랜덤 숫자를 하나 뽑아서 우선순위 합을 기준으로 어느 경험을 선택해야 할지 효율적으로 결정할 수 있습니다.
  • 학습 시 TD 오차를 이용해 우선순위를 업데이트하고, 중요도 가중치(IS weight)를 계산해 손실 계산 시 반영합니다.

여기서는 코드 길이와 복잡성을 고려해, 핵심 아이디어만 담은 예제를 간략히 제시하겠습니다. 실제 프로젝트에서는 Segment Tree 구조를 좀 더 정교하게 구현해야 합니다.

간단한 PER 예제 코드 (기본 아이디어)

아래 예제는 Segment Tree를 단순화한 형태로 구현한 PrioritizedReplayBuffer 클래스를 보여줍니다. 여기서는 Double DQN 코드 기반으로 수정 가능하며, 완전한 최적화보다는 개념 이해에 초점을 둡니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque

###################################
# Segment Tree 유사 구조 (단순화 예)
###################################
class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity)  # 간단 구현: 리프 노드에 우선순위 저장, 상위 노드에 합 저장
        self.size = 0
        self.ptr = 0
    
    def add(self, priority):
        # ptr 위치에 우선순위 저장
        idx = self.ptr + self.capacity
        self.tree[idx] = priority
        # 위로 올라가며 합 갱신
        self._propagate(idx)
        
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
    
    def _propagate(self, idx):
        # 부모 노드로 올라가며 합 갱신
        parent = idx // 2
        while parent >= 1:
            left = parent * 2
            right = left + 1
            self.tree[parent] = self.tree[left] + self.tree[right]
            parent = parent // 2
    
    def total(self):
        return self.tree[1]  # root에 전체 합
    
    def get(self, s):
        # 우선순위 합 s에 해당하는 경험 찾기
        idx = 1
        while idx < self.capacity:
            left = idx * 2
            right = left + 1
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = right
        # idx가 leaf 노드
        return idx - self.capacity, self.tree[idx]

###################################
# Prioritized Replay Buffer
###################################
class PrioritizedReplayBuffer:
    def __init__(self, capacity=10000, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        self.alpha = alpha
        self.tree = SumTree(capacity)
        self.data = []
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 0
        self.epsilon = 1e-5
    
    def push(self, state, action, reward, next_state, done):
        # 초기 우선순위는 최대값(큰 값)으로 설정 (새로운 경험 우선 학습)
        max_priority = self.tree.tree[self.tree.capacity:].max()
        if max_priority == 0:
            max_priority = 1.0
        if len(self.data) < self.capacity:
            self.data.append((state, action, reward, next_state, done))
        else:
            self.data[self.tree.ptr] = (state, action, reward, next_state, done)
        self.tree.add(max_priority)
    
    def sample(self, batch_size):
        self.frame += 1
        beta = self.beta_start + (1.0 - self.beta_start) * min(1.0, self.frame/self.beta_frames)
        
        batch = []
        idxs = []
        segment = self.tree.total() / batch_size
        priorities = []
        
        for i in range(batch_size):
            s = random.random() * segment
            idx, p = self.tree.get(s)
            batch.append(self.data[idx])
            idxs.append(idx)
            priorities.append(p)
        
        priorities = np.array(priorities, dtype=np.float32)
        prob = priorities / self.tree.total()
        weights = (len(self.data) * prob) ** (-beta)
        weights = weights / weights.max()
        
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states, dtype=np.float32),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states, dtype=np.float32),
                np.array(dones, dtype=np.float32),
                idxs, weights)
    
    def update_priorities(self, idxs, td_errors):
        # TD 오차 기반으로 우선순위 갱신
        for i, td_error in zip(idxs, td_errors):
            p = (abs(td_error) + self.epsilon) ** self.alpha
            leaf_idx = i + self.tree.capacity
            self.tree.tree[leaf_idx] = p
            self.tree._propagate(leaf_idx)

###################################
# DQN 에이전트 (PER 적용)
###################################
# 여기서는 DQNAgent를 수정하고, buffer를 PrioritizedReplayBuffer로 교체, 
# update 단계에서 td_error를 계산해 priorities 갱신
class PERDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, tau=1000, batch_size=64):
        self.action_dim = action_dim
        self.gamma = gamma
        self.lr = lr
        self.tau = tau
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.q_net = nn.Sequential(
            nn.Linear(state_dim,64), nn.ReLU(),
            nn.Linear(64,64), nn.ReLU(),
            nn.Linear(64,action_dim)
        ).to(self.device)
        
        self.target_net = nn.Sequential(
            nn.Linear(state_dim,64), nn.ReLU(),
            nn.Linear(64,64), nn.ReLU(),
            nn.Linear(64,action_dim)
        ).to(self.device)
        
        self.target_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
        
        self.replay_buffer = PrioritizedReplayBuffer()
        
        self.total_steps = 0
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        else:
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_vals = self.q_net(state_t)
            return q_vals.argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)
        self.total_steps += 1

    def update(self):
        if len(self.replay_buffer.data) < self.batch_size:
            return
        
        # PER에서 sample 시 idx와 weights 반환
        states, actions, rewards, next_states, dones, idxs, weights = self.replay_buffer.sample(self.batch_size)
        
        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        weights_t = torch.FloatTensor(weights).to(self.device)
        
        q_values = self.q_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        
        with torch.no_grad():
            max_next_q = self.target_net(next_states_t).max(dim=1)[0]
            target_q = rewards_t + (1 - dones_t)*self.gamma*max_next_q
        
        td_errors = target_q - q_values
        loss = (weights_t * td_errors.pow(2)).mean()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # td_errors 기반으로 priority 갱신
        td_errors_np = td_errors.detach().cpu().numpy()
        self.replay_buffer.update_priorities(idxs, td_errors_np)
        
        if self.total_steps % self.tau == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

def train_per_dqn(env_name="CartPole-v1", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = PERDQNAgent(state_dim, action_dim)
    
    reward_history = []
    for ep in range(max_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        reward_history.append(total_reward)
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}")

    env.close()

if __name__ == "__main__":
    train_per_dqn()

코드 해설

  • SumTree와 PrioritizedReplayBuffer 구현은 간략화했으며, 실제로는 Segment Tree 관리나 퍼포먼스 최적화를 더 정교하게 해야 합니다.
  • PrioritizedReplayBuffer는 push 시 최대 우선순위로 초기화, sample 시 우선순위 기반으로 경험 선택, 학습 후 update_priorities로 우선순위 갱신.
  • PERDQNAgent의 update 함수에서 TD 오차를 계산하고, 이를 update_priorities에 전달해 우선순위 갱신.
  • 중요도 가중치(IS weight)를 통해 샘플링 편향을 보정하기 위해 weights를 손실 계산 시 곱해줌.

마무리

우선순위 경험 리플레이(PER)는 강화학습 알고리즘의 데이터 효율을 높여, 더 빠른 수렴과 안정적인 성능 개선을 돕습니다. 이 글에서는 간단한 예제를 통해 기본 원리를 설명했으나, 실제로는 Segment Tree 관리나 매개변수(α, β) 조정 등 세부 구현과 튜닝이 필요합니다.

다음 글에서는 이러한 기법들을 종합하여 RL 성능을 극대화하거나, 다른 유형의 RL 알고리즘(정책기반, 액터-크리틱 등)으로 범위를 확장하는 방법을 살펴볼 수 있습니다.

반응형