[PyTorch로 시작하는 강화학습 입문] 3편: DQN(Deep Q-Network) 기초 구현 – 경험 리플레이와 타겟 네트워크

지난 글에서 가치 기반 접근과 Q함수 근사를 위한 PyTorch 신경망 구조를 마련했다면, 이번 글에서는 이를 실제로 학습시키기 위한 대표적인 딥 강화학습 알고리즘인 DQN(Deep Q-Network)의 기본 골격을 구현해 봅니다. 여기서는 경험 리플레이(Replay Buffer)와 ε-탐욕적(epsilon-greedy) 정책, 그리고 타겟 네트워크(Target Network) 개념을 소개하고, CartPole 환경에서 DQN을 간단히 훈련시키는 예제를 통해 Q함수를 실제로 업데이트하는 과정을 살펴보겠습니다.

 

강화학습에서 Q함수를 딥뉴럴넷으로 근사하는 것은 확장성 측면에서 유용하지만, 단순히 Q-learning을 신경망에 직접 대입하는 것만으로는 학습이 불안정합니다. 경험(transition)을 순서대로 학습하면 샘플의 상관관계가 높아져 최적화가 불안정해지고, 목표(target) 값도 매 스텝 변동되어 학습이 어렵습니다. 이를 해결하는 핵심 아이디어가 바로 DQN(Deep Q-Network) 알고리즘입니다.

DQN의 주요 요소:

  1. 경험 리플레이(Experience Replay):
    과거의 경험 (s, a, r, s', done)을 메모리에 저장한 뒤 무작위로 샘플링해 학습에 활용.
    이렇게 하면 데이터 상관성이 줄고, 안정된 학습이 가능해집니다.
  2. 타겟 네트워크(Target Network):
    Q함수를 근사하는 메인 네트워크와 별도로, 일정 주기마다 메인 네트워크 파라미터를 복사하는 타겟 네트워크를 둡니다.
    타겟 네트워크는 업데이트 간격 사이에는 고정되어 있어, 목표값(target value)이 덜 흔들려 안정화에 도움을 줍니다.
  3. ε-탐욕적(Epsilon-Greedy) 정책:
    정책 개선 시 탐색(exploration)을 위해 ε의 확률로 랜덤 행동을 취하고, 나머지는 현재 Q값이 최대인 행동을 선택합니다.
    학습이 진행됨에 따라 ε을 감소시켜 점차 탐색을 줄이고 활용(exploitation)을 늘립니다.

이번 글에서는 다음을 수행합니다.

  • 간단한 Replay Buffer 클래스 구현
  • QNetwork를 사용한 DQN 구조 구현 (메인 네트워크, 타겟 네트워크)
  • CartPole 환경에서 DQN 학습 루프 구현
  • 일정 에피소드마다 평균 보상 출력해 성능 개선 확인

이제 이론적 배경은 최소한으로 하고, 실제 코드 예제를 통해 DQN의 뼈대를 잡아보겠습니다.

추가 참고자료(이론):

코드 흐름 및 구성

코드는 하나의 파일(dqn_cartpole.py)로 실행 가능합니다. 전체 코드는 아래에 제시하고, 이후 중요한 부분을 풀이하겠습니다.

  • QNetwork: 상태를 입력받아 각 행동에 대한 Q값을 출력하는 MLP 구조
  • ReplayBuffer: 경험(transition)을 저장하고 무작위 샘플링하는 클래스
  • DQNAgent:
    • 메인 Q네트워크와 타겟 네트워크 관리
    • ε-탐욕적 정책으로 행동 선택
    • 경험을 ReplayBuffer에 저장
    • 일정 스텝마다 미니배치를 샘플링해 Q 네트워크 업데이트 (Q-learning 타겟 활용)
    • 일정 스텝마다 타겟 네트워크 동기화
    • ε 값 감소로 점진적으로 탐색 감소
  • train_dqn 함수:
    • 환경을 초기화하고 에피소드를 진행
    • 매 에피소드마다 DQNAgent를 이용해 상태-행동-보상을 수집하고 학습
    • 20 에피소드마다 평균 보상을 출력해 학습 경향 관찰

예제 코드 (self-contained, 주석 풍부)

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque

###################################
# Q-Network 정의
###################################
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(QNetwork, self).__init__()
        # 간단한 3층 MLP: state_dim -> hidden -> hidden -> action_dim
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)
    
    def forward(self, x):
        # 입력 x: (batch_size, state_dim)
        # 출력: (batch_size, action_dim)의 Q값 벡터
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)
        return q_values

###################################
# 경험 리플레이 버퍼
###################################
class ReplayBuffer:
    def __init__(self, capacity=10000):
        # 고정 길이의 덱(deque)에 transition을 저장
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        # 하나의 transition 추가
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size=64):
        # 버퍼에서 무작위로 batch_size개 샘플 추출
        batch = random.sample(self.buffer, batch_size)
        # 각각을 별도 배열로 묶어 반환
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states, dtype=np.float32),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states, dtype=np.float32),
                np.array(dones, dtype=np.float32))
    
    def __len__(self):
        return len(self.buffer)

###################################
# DQN 에이전트
###################################
class DQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, tau=1000, batch_size=64):
        """
        state_dim: 상태 차원
        action_dim: 행동 개수
        gamma: 할인율
        lr: 학습률
        tau: 타겟 네트워크 업데이트 주기(스텝 단위)
        batch_size: 미니배치 크기
        """
        self.action_dim = action_dim
        self.gamma = gamma
        self.lr = lr
        self.tau = tau
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 메인 Q네트워크와 타겟 Q네트워크 생성
        self.q_net = QNetwork(state_dim, action_dim).to(self.device)
        self.target_net = QNetwork(state_dim, action_dim).to(self.device)
        
        # 타겟 네트워크 초기화: 메인 네트워크 파라미터 복사
        self.target_net.load_state_dict(self.q_net.state_dict())
        
        # Adam 옵티마이저
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
        
        # 경험 리플레이 버퍼
        self.replay_buffer = ReplayBuffer()
        
        self.total_steps = 0  # 학습 과정에서의 총 스텝 수
        self.epsilon = 1.0    # 초기 ε 값 (탐색 많이)
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01

    def select_action(self, state):
        """ε-탐욕적 정책으로 행동 선택"""
        if random.random() < self.epsilon:
            # ε 확률로 랜덤 행동 선택
            return random.randrange(self.action_dim)
        else:
            # Q네트워크를 이용해 가장 가치 높은 행동 선택
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_vals = self.q_net(state_t)  # shape: (1, action_dim)
            action = q_vals.argmax(dim=1).item()
            return action

    def store_transition(self, state, action, reward, next_state, done):
        """환경 상호작용 결과를 리플레이 버퍼에 저장"""
        self.replay_buffer.push(state, action, reward, next_state, done)
        self.total_steps += 1

    def update(self):
        """리플레이 버퍼 샘플링 후 Q네트워크 파라미터 업데이트"""
        if len(self.replay_buffer) < self.batch_size:
            # 버퍼에 충분한 데이터가 없으면 업데이트 건너뛰기
            return
        # 샘플링
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        # 파이토치 텐서 변환
        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        
        # 현재 Q(s,a)
        q_values = self.q_net(states_t)             # (batch_size, action_dim)
        q_values = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)
        # gather 사용: 선택한 actions에 해당하는 Q값만 추출
        
        # 타겟 Q값 계산: r + γ * max Q(s',a')
        with torch.no_grad():
            next_q = self.target_net(next_states_t) # 타겟 네트워크로 s' Q값 계산
            max_next_q = next_q.max(dim=1)[0]       # 각 샘플별 max Q(s',a')
            target_q = rewards_t + (1 - dones_t) * self.gamma * max_next_q
        
        # MSE 손실
        loss = F.mse_loss(q_values, target_q)
        
        # 역전파로 메인 네트워크 업데이트
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 일정 스텝마다 타겟 네트워크 동기화
        if self.total_steps % self.tau == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())
        
        # ε 감소 (탐색 축소, 활용 증가)
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

def train_dqn(env_name="CartPole-v1", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = DQNAgent(state_dim, action_dim)
    
    reward_history = []
    for ep in range(max_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            # ε-탐욕적 행동 선택
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            
            # 경험 저장 및 업데이트 호출
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()
            
            state = next_state
            total_reward += reward
        
        reward_history.append(total_reward)
        
        # 20 에피소드마다 평균 보상 출력
        if (ep+1) % 20 == 0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}")
    
    env.close()

if __name__ == "__main__":
    train_dqn()

코드 해설

  • QNetwork: 상태 → Q값을 출력하는 간단한 신경망. CartPole 예제로 state_dim=4, action_dim=2일 경우, 입력 4차원, 출력 2차원 Q값 벡터.
  • ReplayBuffer: (s, a, r, s', done) 형식의 경험을 저장. sample() 함수로 무작위 배치 추출. 이를 통해 시계열 상관성을 줄이고, 안정적 미니배치 학습 가능.
  • DQNAgent:
    • select_action: ε-탐욕적 정책. 초기엔 ε=1.0으로 거의 랜덤 행동. 학습 진행하며 ε 감소 → 점차 최적 행동 선택 비율 증가.
    • store_transition: 환경에서 얻은 경험을 버퍼에 적재.
    • update: 버퍼에서 미니배치 뽑아 Q-learning 타겟 계산 후 Q네트워크 파라미터 업데이트. 일정 스텝마다 타겟 네트워크 동기화로 안정성 확보.
    • ε 감소 로직으로 학습이 진행될수록 탐색 줄이기.
  • train_dqn:
    • 여러 에피소드를 돌며 DQNAgent를 통해 강화학습 진행.
    • 20 에피소드마다 평균 보상을 출력해 성능 개선 추이 확인. 수십 에피소드 지나면 평균 보상이 상승하는 경향을 볼 수 있음.

학습 결과

이 코드를 실행하면 초기에는 평균 보상이 낮지만, 시간이 지날수록 CartPole을 더 오래 유지하는 행동을 배우며 평균 보상이 높아집니다. 이로써 랜덤 행동과 달리 학습된 Q함수를 통해 정책을 개선할 수 있음을 확인할 수 있습니다.

물론 이 예제는 기본 형태의 DQN으로, 최적화나 파라미터 조정이 충분치 않아 완벽한 결과는 아니지만, DQN 알고리즘이 작동하는 기초 메커니즘을 직접 경험해 보는 데 의의가 있습니다.

마무리

이번 글에서는 경험 리플레이, 타겟 네트워크, ε-탐욕적 정책이라는 DQN 핵심 개념을 구현하여, CartPole 환경에서 실제로 Q함수를 학습하는 코드를 살펴보았습니다. 이를 통해 강화학습 알고리즘이 단순 랜덤 행동에서 벗어나, 경험을 적극적으로 활용하며 성능을 향상하는 과정을 체감할 수 있습니다.

다음 글에서는 DQN을 더 개선하거나, 다른 RL 알고리즘 변형(Double DQN, Dueling DQN, Prioritized Replay 등)을 도입해 성능 향상 방법을 논의할 수 있습니다. 또 다른 환경에 적용해보면서 RL 알고리즘 구현 경험을 확장할 수 있습니다.

반응형