강화학습에서는 에이전트가 상태(State)마다 어떤 행동(Action)을 취해야 하는지 결정하기 위해 정책(Policy)을 개선해나갑니다. 앞서 1편에서 살펴본 랜덤 정책 에이전트는 전혀 학습을 하지 않고, 그저 무작위로 행동을 선택하기 때문에 성능이 저조했습니다. 이제는 "가치(Value)" 개념을 도입하여, 각 상태-행동 쌍이 얼마나 좋은지(미래 보상을 많이 얻을 수 있는지)를 평가하는 방식으로 정책을 개선할 실마리를 잡아봅시다.
이번 글의 목표는 다음과 같습니다.
- 가치기반(Value-based) 접근 개념 정리: Q함수(Q-value)를 통해 상태-행동 쌍의 가치를 정의
- Q학습(Q-learning) 아이디어 소개: 벨만(Bellman) 방정식을 이용한 Q함수 업데이트 개념(이론적 상세는 추가 자료 참조)
- PyTorch로 Q함수를 근사하는 간단한 신경망 모델 구현 연습: 아직 완성된 DQN 알고리즘은 아니지만, Q함수를 딥뉴럴넷으로 근사할 준비를 마침
이 글에서는 이론을 필요한 핵심만 짚고, 실제 코드에서는 Q함수 추정을 위한 신경망 구조를 미리 설계해볼 것입니다. 추후 글에서 이 신경망을 경험 리플레이(Replay Buffer), 타겟 네트워크(Target Network) 등과 결합하여 DQN(Deep Q-Network)을 구현할 예정입니다.
가치(Value) 기반 접근이란?
정책을 직접 파라미터화하기보다는, "가치 함수(Value Function)" 또는 "Q함수(Q-value Function)"를 학습하는 접근을 가치기반 방법(Value-based method)라고 합니다. 여기서 Q함수는 상태 s에서 행동 a를 취했을 때, 앞으로 얻을 수 있는 예상 누적보상(expected return)을 나타냅니다.
간단히 말해, Q함수 Q(s, a)는 "이 상태 s에서 행동 a를 하면 얼마나 좋은가?"를 정량화한 것입니다. 가치기반 접근의 핵심 아이디어는 다음과 같습니다.
- Q함수를 정확히 알면, 각 상태에서 가치가 가장 높은 행동을 선택하는 정책을 쉽게 만들 수 있음
- 이론적으로 벨만 방정식(Bellman Equation)에 따라 Q함수를 업데이트하면 최적의 Q함수로 수렴
- Q함수를 표 형태로(상태, 행동 공간이 작다면) 저장할 수도 있지만, 환경이 복잡하면 상태공간이 매우 커짐 → 신경망으로 Q함수를 근사(Approximation)해서 확장성 확보
Q함수와 Q학습(Q-learning) 아이디어
Q학습(Q-learning)은 벨만 최적 방정식을 이용해 Q함수를 갱신하는 오프폴리시(Off-policy) 알고리즘입니다. 핵심 개념은 다음과 같습니다.
- 벨만 최적 방정식:
Q*(s,a) = E[ r + γ max_{a'} Q*(s', a') ]
여기서 s'은 다음 상태, r은 보상, γ는 할인율(0~1), max_{a'}는 다음 상태에서 가장 높은 Q값을 주는 행동 선택 - Q학습은 경험한 (s, a, r, s') 샘플을 이용해 다음과 같이 Q를 업데이트:
Q(s,a) ← Q(s,a) + α [ r + γ max_{a'}Q(s',a') - Q(s,a) ]
여기서 α는 학습률(Learning rate)
이론적 상세나 증명은 여기서 다루지 않지만, 핵심은 Q함수를 개선하면 더 좋은 정책(가치가 높은 행동을 고르는 정책)을 손쉽게 얻을 수 있다는 점입니다.
추가 참고 자료(이론):
- Reinforcement Learning: An Introduction (Sutton & Barto) - RL 기본 개념 및 Q학습 설명
- 유튜브: David Silver의 RL 강의 - Q-learning 개념 강의
PyTorch로 Q함수 근사용 신경망 구현하기
실제 환경에서 상태공간은 연속적이거나 매우 커서 표 형태로 Q함수를 관리하기 힘듭니다. 이때 딥러닝 모델(MLP, CNN 등)을 이용해 Q(s,a)를 근사합니다. 상태를 입력으로 받아 각 행동에 대한 Q값을 출력하는 신경망을 하나 만든다고 가정해봅시다. 예를 들어, CartPole에서는 상태가 4차원 벡터(카트 위치, 속도, 막대 각도, 각속도)이고, 행동이 2가지(왼/오)이라면, 입력 크기=4, 출력 크기=2인 MLP를 만들 수 있습니다.
아래 예제 코드는 CartPole 환경용 Q함수 근사 신경망 예시입니다. 아직 Q학습 또는 DQN을 완성하지 않고, 단지 네트워크 아키텍처를 정의하고, 예시 상태 텐서를 넣어 Q값을 추론하는 것까지만 해봅니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_size=64):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
q_values = self.fc3(x)
return q_values
if __name__ == "__main__":
# CartPole 상태 차원: 4, 행동 차원: 2
state_dim = 4
action_dim = 2
q_net = QNetwork(state_dim, action_dim)
# 예제 상태 텐서(batch_size=1)
sample_state = torch.randn((1, state_dim))
q_vals = q_net(sample_state)
print("예제 상태에 대한 Q값:", q_vals)
print("Q값 텐서 크기:", q_vals.shape)
위 코드를 q_network_example.py로 저장 후 실행하면, 임의의 상태를 입력받아 Q망이 반환한 [Q(s,a_1), Q(s,a_2)] 형태의 출력 텐서를 볼 수 있습니다. 아직 이 값들은 무의미한 랜덤 초기화 상태이지만, 나중에 Q학습 또는 DQN 알고리즘으로 학습시키면 상태에 따라 가치가 다른 행동을 식별할 수 있게 됩니다.
Q함수 신경망 구조 해석
- state_dim: 상태 벡터 크기. CartPole은 4개 값으로 표현되므로 4
- action_dim: 가능한 행동 개수. CartPole에서는 왼/오 2개
- 2개의 은닉층(hidden layer)을 갖는 심플한 MLP.
- 입력: 상태 텐서 (N, state_dim) 형태
- 출력: (N, action_dim) 형태의 Q값 벡터. 각 행동에 대한 Q값을 동시에 출력
이런 구조를 사용하면, 한 번의 전파(forward pass)로 모든 행동의 Q값을 얻을 수 있어 argmax를 쉽게 계산할 수 있습니다. 추후 알고리즘 구현 시, torch.max(q_values, dim=1)로 가장 가치 높은 행동을 손쉽게 찾을 수 있습니다.
이번 글 정리 및 다음 단계
이번 글에서는 가치기반 강화학습 접근법과 Q함수의 개념을 소개하고, PyTorch로 Q함수를 근사하는 신경망을 구현하는 기초를 다뤘습니다. 아직 Q함수를 학습하지 않았지만, 이로써 Q값 근사 기법을 위한 준비는 마쳤습니다. 다음 글부터는 경험 리플레이(Replay Buffer), 타겟 네트워크(Target Network) 등의 개념을 도입하고, DQN(Deep Q-Network) 알고리즘을 단계별로 구현해보며 실제로 Q함수를 학습시키는 과정을 살펴볼 것입니다.
Q함수를 근사하는 신경망을 이해하고, 상황에 따라 이 신경망이 어떻게 업데이트되는지 감을 잡아두면, 이후 DQN 구현 과정에서 한결 수월하게 접근할 수 있습니다.
'개발 이야기 > PyTorch (파이토치)' 카테고리의 다른 글
[PyTorch로 시작하는 강화학습 입문] 4편: DQN 개선하기 – Double DQN 구현 및 추가 변형 소개 (0) | 2024.12.12 |
---|---|
[PyTorch로 시작하는 강화학습 입문] 3편: DQN(Deep Q-Network) 기초 구현 – 경험 리플레이와 타겟 네트워크 (1) | 2024.12.11 |
[PyTorch로 시작하는 강화학습 입문] 1편: 강화학습과 PyTorch 소개, 개발환경 준비, 그리고 첫 실행 예제 (1) | 2024.12.11 |
[LibTorch 입문] 8편: 전체 구조 정리 및 마무리, 그리고 다음 단계 제안 (1) | 2024.12.11 |
[LibTorch 입문] 7편: C++/Python 통합 모델 추론 파이프라인 실습 (2) | 2024.12.11 |