[PyTorch로 시작하는 강화학습 입문] 4편: DQN 개선하기 – Double DQN 구현 및 추가 변형 소개

 

이전 글에서 DQN을 구현해 CartPole 환경을 학습시켜보았습니다. DQN은 간단하고 효과적이지만, 여전히 다음과 같은 문제가 남아 있습니다.

  • Q값의 과추정(Overestimation): DQN은 최대 Q값을 직접 사용하기 때문에, 노이즈나 학습 초기 불안정으로 인해 실제보다 높은 Q값을 선택하는 경향이 있습니다. 이로 인해 정책이 왜곡될 수 있습니다.
  • 데이터 효율성과 안정성 문제: 경험 리플레이를 사용하지만, 모든 transition이 동일한 확률로 샘플링됩니다. 또, Q값 계산 시 행동에 따른 Q분포를 좀 더 효율적으로 학습할 수 있는 구조적 개선도 가능할 것입니다.

이러한 문제를 완화하기 위해 다양한 DQN 변형 알고리즘이 제안되었습니다. 그중 대표적인 두 가지를 소개하겠습니다.

  1. Double DQN (Double Q-learning):
    DQN의 Q값 과추정 문제를 줄이기 위해, 행동 선택과 Q값 평가를 분리합니다.
    기본 아이디어:
    • 행동 선택은 메인 네트워크로 하고,
    • 해당 행동의 Q값은 타겟 네트워크로 평가
      이렇게 하면 Q값 계산 과정에서 단일 네트워크를 중복 활용하는 편향을 줄일 수 있습니다.
  2. Dueling DQN:
    Q함수를 Value와 Advantage로 분해하는 듀얼 구조를 도입해, 상태 자체의 가치와 행동별 가치를 명확히 구분합니다. 이를 통해 학습 효율을 높일 수 있습니다.

또한, 우선순위 경험 리플레이(Prioritized Experience Replay), Noisy Net, Rainbow DQN 등 다양한 기법들이 제안되어 RL 성능을 더욱 개선할 수 있습니다. 이번 글에서는 Double DQN에 초점을 맞추고, DQN 코드를 어떻게 수정하면 되는지 살펴보겠습니다.

추가 참고자료:

Double DQN 아이디어

기존 DQN에서 타겟 Q값을 계산할 때, max_{a'} Q_target(s', a')를 사용했습니다. 이때 Q값을 평가하는 타겟 네트워크와, 그 중 최대 Q를 주는 행동을 선택하는 과정이 모두 타겟 네트워크에 의존합니다. Double DQN은 다음과 같이 변경합니다.

  1. 행동 선택: 메인 네트워크 Q_net를 이용해 a* = argmax_{a'} Q_net(s', a')를 찾습니다.
  2. Q값 평가: 타겟 네트워크 Q_target를 이용해 Q_target(s', a*)을 계산합니다.

이렇게 함으로써, 행동 선택과 평가에 다른 네트워크(혹은 동일 네트워크의 서로 다른 파라미터 집합)를 사용하게 되어, Q값 과추정을 완화할 수 있습니다.

코드 예제 (Double DQN 적용)

아래 코드는 이전 글의 DQN 코드(dqn_cartpole.py)를 기반으로 하며, 타겟 Q값 계산 부분만 Double DQN 방식으로 수정합니다. 큰 틀은 DQN과 동일합니다. 변경점에 주석을 자세히 달았습니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque

###################################
# Q-Network 정의 (동일)
###################################
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)
        return q_values

###################################
# 경험 리플레이 버퍼 (동일)
###################################
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size=64):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states, dtype=np.float32),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states, dtype=np.float32),
                np.array(dones, dtype=np.float32))
    
    def __len__(self):
        return len(self.buffer)

###################################
# Double DQN Agent
###################################
class DoubleDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, tau=1000, batch_size=64):
        self.action_dim = action_dim
        self.gamma = gamma
        self.lr = lr
        self.tau = tau
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.q_net = QNetwork(state_dim, action_dim).to(self.device)
        self.target_net = QNetwork(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
        
        self.replay_buffer = ReplayBuffer()
        
        self.total_steps = 0
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        else:
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_vals = self.q_net(state_t)
            return q_vals.argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)
        self.total_steps += 1

    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        
        # 현재 Q(s,a)
        q_values = self.q_net(states_t)
        q_values = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)
        
        # Double DQN 타겟 Q값 계산
        # 1) 메인 네트워크로 다음 상태에서 행동 선택
        next_q_values = self.q_net(next_states_t)              # 메인 네트워크로 s'에서의 Q값
        best_actions = next_q_values.argmax(dim=1)             # a* = argmax_a Q_net(s',a')
        
        # 2) 타겟 네트워크로 해당 행동의 Q값 평가
        with torch.no_grad():
            next_q_target = self.target_net(next_states_t)     # 타겟 네트워크로 s' Q값
            # a* 행동에 해당하는 Q값만 추출
            target_q_a_star = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
            
            # 타겟: r + γ * Q_target(s', a*)
            target_q = rewards_t + (1 - dones_t) * self.gamma * target_q_a_star
        
        loss = F.mse_loss(q_values, target_q)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 일정 스텝마다 타겟 네트워크 파라미터 동기화
        if self.total_steps % self.tau == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        
        # ε 감소
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

def train_double_dqn(env_name="CartPole-v1", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = DoubleDQNAgent(state_dim, action_dim)
    
    reward_history = []
    for ep in range(max_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        reward_history.append(total_reward)
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}")
    
    env.close()

if __name__ == "__main__":
    train_double_dqn()

코드 해설

  • 대부분의 구조는 이전 DQN과 동일하지만, 타겟 Q값 계산 부분이 변경되었습니다.
  • Double DQN 핵심 변경점:
    • best_actions = next_q_values.argmax(dim=1)
      여기서 next_q_values는 메인 네트워크로 계산한 s' 상태에서의 Q값입니다. 이를 통해 행동 a*를 선택합니다.
    • 그 후 target_q_a_star = next_q_target.gather(1, best_actions.unsqueeze(1))로 타겟 네트워크에서 이 행동 a*의 Q값을 얻습니다.
    • 이렇게 하면 행동 선택에 메인 네트워크를 사용하고, Q값 평가는 타겟 네트워크를 사용하여 과추정을 줄입니다.

Dueling DQN, Prioritized Replay 등 추가 변형

  • Dueling DQN: Q(s,a) = V(s) + A(s,a) 형태로 분해한 네트워크 구조를 사용합니다. 상태별 기본 가치 V(s)와 각 행동별 추가 Advantage A(s,a)를 분리하여, 행동 선택에 더 효율적인 표현을 제공합니다.
  • Prioritized Experience Replay: 모든 경험을 동일하게 취급하지 않고, TD오차가 큰 경험(학습에 중요한 경험)에 더 높은 샘플링 확률을 부여해 데이터 효율을 높입니다.
  • Noisy Networks, Rainbow DQN: 다양한 기법(분포형 Q학습, n-step Return, C51, QR-DQN, ...)을 조합한 Rainbow DQN 등 더 복잡한 변형들이 있습니다.

이러한 변형들을 추가적으로 구현함으로써 RL 성능과 안정성을 더 높일 수 있습니다.

마무리

이번 글에서는 Double DQN 기법을 소개하고, 기존 DQN 코드에 간단히 적용하는 방법을 살펴보았습니다. Double DQN은 Q값 과추정을 완화해 더 안정적인 학습을 가능하게 하며, 이를 토대로 듀얼링 아키텍처나 우선순위 리플레이 등 다른 기법들과 결합하면 더욱 강력한 RL 에이전트를 만들 수 있습니다.

다음 글에서는 이러한 변형들 중 하나를 더 구현해보거나, 더 복잡한 환경으로 확장해보는 과정을 거치며 강화학습 실전 감각을 키워갈 수 있습니다.

반응형