[PyTorch로 시작하는 강화학습 입문] 5편: Dueling DQN 구현으로 Q함수 구조 개선하기

DQN 계열 알고리즘의 핵심은 상태-행동 가치(Q-value)를 효과적으로 추정하는 것입니다. 지금까지의 네트워크는 상태를 입력받아 각 행동에 대한 Q값을 직접 출력하는 구조를 사용했습니다. 그러나 모든 행동에 대한 Q값을 별도로 추정하는 것은 비효율적일 수 있습니다. 상태 자체의 "가치(Value)"와, 그 상태에서 특정 행동을 선택함으로써 추가로 얻을 수 있는 "우위(Advantage)"를 분리하면, 공통적인 상태 가치를 학습하면서도 행동별 차이를 더 효율적으로 포착할 수 있습니다.

Dueling DQN(Dueling Network Architecture for Deep Reinforcement Learning)에서는 Q(s,a)를 다음과 같이 분해합니다.

  • Q(s,a) = V(s) + A(s,a) - 평균(A(s,a))
    여기서 V(s)는 상태의 가치, A(s,a)는 해당 상태에서 특정 행동의 상대적 우위(Advantage)입니다. 평균(A(s,a))을 빼는 이유는 Q값을 정규화하여, Advantage가 행동 사이의 상대적 차이를 잘 드러내도록 하는 목적이 있습니다.

장점:

  • 상태 가치(V)를 직접 추정하기 때문에, 어떤 상태의 가치가 명확해지고, 행동 차이가 별로 없는 경우에도 상태 자체의 유용성을 빨리 파악할 수 있습니다.
  • Advantage 구조를 통해 상태가 행동에 큰 영향을 받지 않거나 특정 행동만 유리한 경우를 효율적으로 표현할 수 있습니다.

추가 참고자료:

Dueling DQN 네트워크 구조

기존 QNetwork 대신, 상태를 입력받아 다음 두 개의 분기를 갖는 신경망을 만듭니다.

  1. Value Stream(V): 상태에서 상태 가치 V(s)를 예측하는 경로
  2. Advantage Stream(A): 상태에서 각 행동의 Advantage A(s,a)를 예측하는 경로

마지막에 V(s)와 A(s,a)를 합쳐 Q(s,a)를 산출합니다. 구현 상으로는 중간까지는 공용 레이어(Feature extractor)를 통과한 뒤, 두 개의 FC 레이어를 통해 Value와 Advantage를 구분합니다.

코드 예제 (Dueling DQN 적용)

아래 코드는 Dueling DQN을 위해 QNetwork를 DuelingQNetwork로 변경한 예제입니다. 나머지 구조는 이전 DQN과 거의 동일하며, Double DQN을 사용하려면 이전 글의 아이디어를 혼합해 적용할 수도 있습니다. 여기서는 기본 DQN 코드에 Dueling 구조만 적용하는 예로 듭니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque

###################################
# Dueling Q-Network 정의
###################################
class DuelingQNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(DuelingQNetwork, self).__init__()
        # 공통 피쳐 추출 레이어
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        # 상태 가치 V(s) 추출하는 스트림
        self.value_fc = nn.Linear(hidden_size, 1)
        # Advantage A(s,a) 추출하는 스트림
        self.adv_fc = nn.Linear(hidden_size, action_dim)
    
    def forward(self, x):
        # 공통 처리
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # Value와 Advantage 계산
        value = self.value_fc(x)            # (batch_size, 1)
        advantage = self.adv_fc(x)          # (batch_size, action_dim)
        
        # Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
        advantage_mean = advantage.mean(dim=1, keepdim=True)  # 행동 차원 평균
        q_values = value + (advantage - advantage_mean)
        return q_values

###################################
# 경험 리플레이 버퍼
###################################
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size=64):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states, dtype=np.float32),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states, dtype=np.float32),
                np.array(dones, dtype=np.float32))
    
    def __len__(self):
        return len(self.buffer)

###################################
# Dueling DQN 에이전트
###################################
class DuelingDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, tau=1000, batch_size=64):
        self.action_dim = action_dim
        self.gamma = gamma
        self.lr = lr
        self.tau = tau
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Dueling QNetwork 사용
        self.q_net = DuelingQNetwork(state_dim, action_dim).to(self.device)
        self.target_net = DuelingQNetwork(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
        
        self.replay_buffer = ReplayBuffer()
        
        self.total_steps = 0
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        else:
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_vals = self.q_net(state_t)
            return q_vals.argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)
        self.total_steps += 1

    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        
        # Q(s,a)
        q_values = self.q_net(states_t)
        q_values = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)
        
        # 타겟 Q값: r + γ * max Q(s',a')
        with torch.no_grad():
            next_q = self.target_net(next_states_t)
            max_next_q = next_q.max(dim=1)[0]
            target_q = rewards_t + (1 - dones_t) * self.gamma * max_next_q
        
        loss = F.mse_loss(q_values, target_q)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 일정 스텝마다 타겟 네트워크 업데이트
        if self.total_steps % self.tau == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        
        # ε 감소
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

def train_dueling_dqn(env_name="CartPole-v1", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = DuelingDQNAgent(state_dim, action_dim)
    
    reward_history = []
    for ep in range(max_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        reward_history.append(total_reward)
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}")

    env.close()

if __name__ == "__main__":
    train_dueling_dqn()

코드 해설

  • DuelingQNetwork 클래스:
    • 상태 입력 후 공통 레이어(두 개의 FC 레이어)로 특징 추출
    • value_fc로부터 V(s) 추출, adv_fc로부터 A(s,a) 추출
    • Q(s,a) = V(s) + A(s,a) - 평균(A(s,a)) 공식으로 최종 Q값 계산
  • 나머지 DQN 로직(DQNAgent, ReplayBuffer, 훈련 루프)는 기존 DQN과 동일
  • Dueling DQN은 성능 개선에 기여할 수 있지만, 모든 환경에서 즉각적으로 큰 개선을 보장하는 것은 아니며, 하이퍼파라미터나 다른 기법들과 조합 시 더욱 효과적

Dueling DQN의 의의

  • 상태가 분명히 유용하지만 어떤 행동을 취하든 큰 차이가 없을 때, Dueling 구조는 V(s)를 안정적으로 학습함으로써 나중에 환경 변화나 행동 차이가 발생했을 때 더 빠르게 정책 개선을 유도할 수 있습니다.
  • Advantage를 통해 행동 간 상대적 차이를 분명히 표현하면서도, 불필요하게 모든 행동에 대한 Q값을 독립적으로 학습하는 부하를 줄여줍니다.

마무리

이번 글에서는 Dueling DQN 개념을 소개하고, 기존 DQN 코드에 작은 수정만으로 이를 구현하는 방법을 살펴보았습니다. Dueling DQN은 DQN 계열 알고리즘의 성능과 효율을 향상시키는 대표적인 변형 중 하나이며, Double DQN, Prioritized Replay 등 다른 기법들과 결합하여 더욱 강력한 강화학습 에이전트를 만들 수 있습니다.

다음 글에서는 우선순위 경험 리플레이나 다른 고급 기법을 추가적으로 살펴보거나, 더 복잡한 환경에 도전하면서 RL 알고리즘을 개선·응용하는 방향으로 나아갈 수 있습니다.

반응형