지난 글에서 가치 기반 접근과 Q함수 근사를 위한 PyTorch 신경망 구조를 마련했다면, 이번 글에서는 이를 실제로 학습시키기 위한 대표적인 딥 강화학습 알고리즘인 DQN(Deep Q-Network)의 기본 골격을 구현해 봅니다. 여기서는 경험 리플레이(Replay Buffer)와 ε-탐욕적(epsilon-greedy) 정책, 그리고 타겟 네트워크(Target Network) 개념을 소개하고, CartPole 환경에서 DQN을 간단히 훈련시키는 예제를 통해 Q함수를 실제로 업데이트하는 과정을 살펴보겠습니다.
강화학습에서 Q함수를 딥뉴럴넷으로 근사하는 것은 확장성 측면에서 유용하지만, 단순히 Q-learning을 신경망에 직접 대입하는 것만으로는 학습이 불안정합니다. 경험(transition)을 순서대로 학습하면 샘플의 상관관계가 높아져 최적화가 불안정해지고, 목표(target) 값도 매 스텝 변동되어 학습이 어렵습니다. 이를 해결하는 핵심 아이디어가 바로 DQN(Deep Q-Network) 알고리즘입니다.
DQN의 주요 요소:
- 경험 리플레이(Experience Replay):
과거의 경험 (s, a, r, s', done)을 메모리에 저장한 뒤 무작위로 샘플링해 학습에 활용.
이렇게 하면 데이터 상관성이 줄고, 안정된 학습이 가능해집니다. - 타겟 네트워크(Target Network):
Q함수를 근사하는 메인 네트워크와 별도로, 일정 주기마다 메인 네트워크 파라미터를 복사하는 타겟 네트워크를 둡니다.
타겟 네트워크는 업데이트 간격 사이에는 고정되어 있어, 목표값(target value)이 덜 흔들려 안정화에 도움을 줍니다. - ε-탐욕적(Epsilon-Greedy) 정책:
정책 개선 시 탐색(exploration)을 위해 ε의 확률로 랜덤 행동을 취하고, 나머지는 현재 Q값이 최대인 행동을 선택합니다.
학습이 진행됨에 따라 ε을 감소시켜 점차 탐색을 줄이고 활용(exploitation)을 늘립니다.
이번 글에서는 다음을 수행합니다.
- 간단한 Replay Buffer 클래스 구현
- QNetwork를 사용한 DQN 구조 구현 (메인 네트워크, 타겟 네트워크)
- CartPole 환경에서 DQN 학습 루프 구현
- 일정 에피소드마다 평균 보상 출력해 성능 개선 확인
이제 이론적 배경은 최소한으로 하고, 실제 코드 예제를 통해 DQN의 뼈대를 잡아보겠습니다.
추가 참고자료(이론):
- Nature DQN 논문 (Mnih et al., 2015): https://www.nature.com/articles/nature14236
- 유튜브: David Silver's RL 강의, DQN 관련 부분
코드 흐름 및 구성
코드는 하나의 파일(dqn_cartpole.py)로 실행 가능합니다. 전체 코드는 아래에 제시하고, 이후 중요한 부분을 풀이하겠습니다.
- QNetwork: 상태를 입력받아 각 행동에 대한 Q값을 출력하는 MLP 구조
- ReplayBuffer: 경험(transition)을 저장하고 무작위 샘플링하는 클래스
- DQNAgent:
- 메인 Q네트워크와 타겟 네트워크 관리
- ε-탐욕적 정책으로 행동 선택
- 경험을 ReplayBuffer에 저장
- 일정 스텝마다 미니배치를 샘플링해 Q 네트워크 업데이트 (Q-learning 타겟 활용)
- 일정 스텝마다 타겟 네트워크 동기화
- ε 값 감소로 점진적으로 탐색 감소
- train_dqn 함수:
- 환경을 초기화하고 에피소드를 진행
- 매 에피소드마다 DQNAgent를 이용해 상태-행동-보상을 수집하고 학습
- 20 에피소드마다 평균 보상을 출력해 학습 경향 관찰
예제 코드 (self-contained, 주석 풍부)
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque
###################################
# Q-Network 정의
###################################
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_size=64):
super(QNetwork, self).__init__()
# 간단한 3층 MLP: state_dim -> hidden -> hidden -> action_dim
self.fc1 = nn.Linear(state_dim, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, action_dim)
def forward(self, x):
# 입력 x: (batch_size, state_dim)
# 출력: (batch_size, action_dim)의 Q값 벡터
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
q_values = self.fc3(x)
return q_values
###################################
# 경험 리플레이 버퍼
###################################
class ReplayBuffer:
def __init__(self, capacity=10000):
# 고정 길이의 덱(deque)에 transition을 저장
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
# 하나의 transition 추가
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size=64):
# 버퍼에서 무작위로 batch_size개 샘플 추출
batch = random.sample(self.buffer, batch_size)
# 각각을 별도 배열로 묶어 반환
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states, dtype=np.float32),
np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(next_states, dtype=np.float32),
np.array(dones, dtype=np.float32))
def __len__(self):
return len(self.buffer)
###################################
# DQN 에이전트
###################################
class DQNAgent:
def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, tau=1000, batch_size=64):
"""
state_dim: 상태 차원
action_dim: 행동 개수
gamma: 할인율
lr: 학습률
tau: 타겟 네트워크 업데이트 주기(스텝 단위)
batch_size: 미니배치 크기
"""
self.action_dim = action_dim
self.gamma = gamma
self.lr = lr
self.tau = tau
self.batch_size = batch_size
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 메인 Q네트워크와 타겟 Q네트워크 생성
self.q_net = QNetwork(state_dim, action_dim).to(self.device)
self.target_net = QNetwork(state_dim, action_dim).to(self.device)
# 타겟 네트워크 초기화: 메인 네트워크 파라미터 복사
self.target_net.load_state_dict(self.q_net.state_dict())
# Adam 옵티마이저
self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.lr)
# 경험 리플레이 버퍼
self.replay_buffer = ReplayBuffer()
self.total_steps = 0 # 학습 과정에서의 총 스텝 수
self.epsilon = 1.0 # 초기 ε 값 (탐색 많이)
self.epsilon_decay = 0.995
self.epsilon_min = 0.01
def select_action(self, state):
"""ε-탐욕적 정책으로 행동 선택"""
if random.random() < self.epsilon:
# ε 확률로 랜덤 행동 선택
return random.randrange(self.action_dim)
else:
# Q네트워크를 이용해 가장 가치 높은 행동 선택
state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_vals = self.q_net(state_t) # shape: (1, action_dim)
action = q_vals.argmax(dim=1).item()
return action
def store_transition(self, state, action, reward, next_state, done):
"""환경 상호작용 결과를 리플레이 버퍼에 저장"""
self.replay_buffer.push(state, action, reward, next_state, done)
self.total_steps += 1
def update(self):
"""리플레이 버퍼 샘플링 후 Q네트워크 파라미터 업데이트"""
if len(self.replay_buffer) < self.batch_size:
# 버퍼에 충분한 데이터가 없으면 업데이트 건너뛰기
return
# 샘플링
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
# 파이토치 텐서 변환
states_t = torch.FloatTensor(states).to(self.device)
actions_t = torch.LongTensor(actions).to(self.device)
rewards_t = torch.FloatTensor(rewards).to(self.device)
next_states_t = torch.FloatTensor(next_states).to(self.device)
dones_t = torch.FloatTensor(dones).to(self.device)
# 현재 Q(s,a)
q_values = self.q_net(states_t) # (batch_size, action_dim)
q_values = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)
# gather 사용: 선택한 actions에 해당하는 Q값만 추출
# 타겟 Q값 계산: r + γ * max Q(s',a')
with torch.no_grad():
next_q = self.target_net(next_states_t) # 타겟 네트워크로 s' Q값 계산
max_next_q = next_q.max(dim=1)[0] # 각 샘플별 max Q(s',a')
target_q = rewards_t + (1 - dones_t) * self.gamma * max_next_q
# MSE 손실
loss = F.mse_loss(q_values, target_q)
# 역전파로 메인 네트워크 업데이트
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# 일정 스텝마다 타겟 네트워크 동기화
if self.total_steps % self.tau == 0:
self.target_net.load_state_dict(self.q_net.state_dict())
# ε 감소 (탐색 축소, 활용 증가)
self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)
def train_dqn(env_name="CartPole-v1", max_episodes=300):
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)
reward_history = []
for ep in range(max_episodes):
state = env.reset()
total_reward = 0
done = False
while not done:
# ε-탐욕적 행동 선택
action = agent.select_action(state)
next_state, reward, done, info = env.step(action)
# 경험 저장 및 업데이트 호출
agent.store_transition(state, action, reward, next_state, done)
agent.update()
state = next_state
total_reward += reward
reward_history.append(total_reward)
# 20 에피소드마다 평균 보상 출력
if (ep+1) % 20 == 0:
avg_reward = np.mean(reward_history[-20:])
print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}, Epsilon: {agent.epsilon:.2f}")
env.close()
if __name__ == "__main__":
train_dqn()
코드 해설
- QNetwork: 상태 → Q값을 출력하는 간단한 신경망. CartPole 예제로 state_dim=4, action_dim=2일 경우, 입력 4차원, 출력 2차원 Q값 벡터.
- ReplayBuffer: (s, a, r, s', done) 형식의 경험을 저장. sample() 함수로 무작위 배치 추출. 이를 통해 시계열 상관성을 줄이고, 안정적 미니배치 학습 가능.
- DQNAgent:
- select_action: ε-탐욕적 정책. 초기엔 ε=1.0으로 거의 랜덤 행동. 학습 진행하며 ε 감소 → 점차 최적 행동 선택 비율 증가.
- store_transition: 환경에서 얻은 경험을 버퍼에 적재.
- update: 버퍼에서 미니배치 뽑아 Q-learning 타겟 계산 후 Q네트워크 파라미터 업데이트. 일정 스텝마다 타겟 네트워크 동기화로 안정성 확보.
- ε 감소 로직으로 학습이 진행될수록 탐색 줄이기.
- train_dqn:
- 여러 에피소드를 돌며 DQNAgent를 통해 강화학습 진행.
- 20 에피소드마다 평균 보상을 출력해 성능 개선 추이 확인. 수십 에피소드 지나면 평균 보상이 상승하는 경향을 볼 수 있음.
학습 결과
이 코드를 실행하면 초기에는 평균 보상이 낮지만, 시간이 지날수록 CartPole을 더 오래 유지하는 행동을 배우며 평균 보상이 높아집니다. 이로써 랜덤 행동과 달리 학습된 Q함수를 통해 정책을 개선할 수 있음을 확인할 수 있습니다.
물론 이 예제는 기본 형태의 DQN으로, 최적화나 파라미터 조정이 충분치 않아 완벽한 결과는 아니지만, DQN 알고리즘이 작동하는 기초 메커니즘을 직접 경험해 보는 데 의의가 있습니다.
마무리
이번 글에서는 경험 리플레이, 타겟 네트워크, ε-탐욕적 정책이라는 DQN 핵심 개념을 구현하여, CartPole 환경에서 실제로 Q함수를 학습하는 코드를 살펴보았습니다. 이를 통해 강화학습 알고리즘이 단순 랜덤 행동에서 벗어나, 경험을 적극적으로 활용하며 성능을 향상하는 과정을 체감할 수 있습니다.
다음 글에서는 DQN을 더 개선하거나, 다른 RL 알고리즘 변형(Double DQN, Dueling DQN, Prioritized Replay 등)을 도입해 성능 향상 방법을 논의할 수 있습니다. 또 다른 환경에 적용해보면서 RL 알고리즘 구현 경험을 확장할 수 있습니다.
'개발 이야기 > PyTorch (파이토치)' 카테고리의 다른 글
[PyTorch로 시작하는 강화학습 입문] 5편: Dueling DQN 구현으로 Q함수 구조 개선하기 (0) | 2024.12.13 |
---|---|
[PyTorch로 시작하는 강화학습 입문] 4편: DQN 개선하기 – Double DQN 구현 및 추가 변형 소개 (0) | 2024.12.12 |
[PyTorch로 시작하는 강화학습 입문] 2편: 가치 기반 접근과 Q함수 개념, PyTorch 신경망으로 Q함수 근사하기 (1) | 2024.12.11 |
[PyTorch로 시작하는 강화학습 입문] 1편: 강화학습과 PyTorch 소개, 개발환경 준비, 그리고 첫 실행 예제 (1) | 2024.12.11 |
[LibTorch 입문] 8편: 전체 구조 정리 및 마무리, 그리고 다음 단계 제안 (1) | 2024.12.11 |