[PyTorch로 시작하는 강화학습 입문] 10편: 연속 행동 공간에 도전 – Soft Actor-Critic(SAC) 소개 및 기초 구현

이전까지는 CartPole처럼 왼/오 행동을 선택하는 이산적 행동 공간 문제를 다뤘습니다. 하지만 실제 응용(로봇 제어, 자율주행, 제어 시스템)은 연속적 행동(예: 모터 토크, 휠 각도)을 요구합니다. 이산적 행동 공간용 Q-learning 계열 알고리즘을 그대로 적용하기 어렵기 때문에, 연속 행동 공간에 맞는 알고리즘이 필요합니다.

SAC(Soft Actor-Critic)는 연속 행동 공간을 다루는 최신 Actor-Critic 알고리즘 중 하나로, 다음과 같은 특징을 갖습니다.

  • Off-policy Actor-Critic: 리플레이 버퍼를 사용, 데이터 효율적
  • 자동 온도 파라미터 조정: 탐사(Exploration)와 활용(Exploitation) 사이의 균형을 맞추는 엔트로피 보상(Entropy Regularization) 사용. 이로써 정책이 항상 일정 수준의 탐험성을 유지
  • 안정적이고 뛰어난 성능: DDPG, TD3 등 이전 연속 행동 알고리즘에 비해 뛰어난 안정성과 성능을 보여줌

SAC는 다음과 같은 요소를 포함합니다.

  1. Actor(정책 πθ): 상태를 받아 행동 a를 확률적(가우시안 분포)으로 샘플링
  2. Critic(Q함수 근사): 두 개의 Q네트워크(더 안정적인 학습을 위해 Double Q Trick)
  3. Entropy Regularization: 목표는 단순히 Q를 최적화하는 것이 아니라, 정책의 엔트로피를 높게 유지해 탐험성(Exploration)을 담보. 학습 과정에서 알파(α)라는 온도 파라미터를 자동 조정해 엔트로피 목표를 맞춤.

참고자료:

  • Haarnoja et al., "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor" https://arxiv.org/abs/1801.01290
  • OpenAI Spinning Up 자료, Stable Baselines3 구현 참고

이번 글에서는 무작위로 하나의 연속적 제어 환경을 예로 들겠습니다. 예: MountainCarContinuous-v0 (OpenAI Gym). 이 환경에서 SAC를 간단히 구현하는 과정을 보여주겠습니다. 실제 SAC 구현은 조금 복잡하나, 여기서는 기본 골격만 잡아 최소한의 코드를 제시합니다.

SAC 알고리즘 개요 (간략)

  1. Actor(정책 π): 상태 s에서 가우시안 분포로부터 행동 a 샘플. 행동을 환경에 적용.
  2. Critic(Q1, Q2): 두 개의 Q네트워크로 Q값 근사. 업데이트 시 min(Q1,Q2) 사용 → 과추정 완화.
  3. 온도 α 자동 조정: 엔트로피 목표보다 엔트로피가 낮으면 α 증가(탐험성↑), 높으면 α 감소.

업데이트 수식은 다음과 같음(자세한 수식은 여기서 생략):

  • Q-네트워크 업데이트: MSE 손실(r + γ( ... ) - Q(s,a))
  • 정책 업데이트: π를 업데이트할 때 클리핑 없이, 엔트로피 항을 포함하여 다음 형태의 목적함수 최소화:
    J(π) = E_s[ E_a[ α * log π(a|s) - Q(s,a) ] ]
  • α 업데이트: 엔트로피 목표 H̄를 설정하고, α 경사상승으로 목표 엔트로피에 맞추어 α 조정

예제 코드(간단화된 SAC)

아래 코드는 핵심 개념만 담은 매우 간단한 SAC 예제입니다. 실제로는 더 많은 하이퍼파라미터 튜닝, 안정화 테크닉, 정교한 초기화 등이 필요합니다.

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from collections import deque

###################################
# 연속 행동 정책 (Gaussian Policy)
###################################
class GaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=256):
        super(GaussianPolicy, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.mean_layer = nn.Linear(hidden_size, action_dim)
        self.log_std_layer = nn.Linear(hidden_size, action_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, -20, 2)  # 안정화
        std = torch.exp(log_std)
        return mean, std
    
    def sample(self, state):
        mean, std = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()  # reparameterization trick
        # 타겟 환경이 [-1,1] 범위일 경우 tanh를 사용
        action_tanh = torch.tanh(action)
        log_prob = dist.log_prob(action) - torch.log(1 - action_tanh.pow(2) + 1e-7)
        return action_tanh, log_prob.sum(dim=-1)

###################################
# Q네트워크
###################################
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=256):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim+action_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
    
    def forward(self, s, a):
        x = torch.cat([s,a], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q = self.fc3(x)
        return q

###################################
# Replay Buffer
###################################
class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, s,a,r,ns,d):
        self.buffer.append((s,a,r,ns,d))
    
    def sample(self,batch_size):
        batch = random.sample(self.buffer, batch_size)
        s,a,r,ns,d = zip(*batch)
        return np.array(s,dtype=np.float32),np.array(a,dtype=np.float32),np.array(r,dtype=np.float32),np.array(ns,dtype=np.float32),np.array(d,dtype=np.float32)
    
    def __len__(self):
        return len(self.buffer)

###################################
# SAC 에이전트
###################################
class SACAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=3e-4, tau=0.005, alpha=0.2, batch_size=64):
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.batch_size = batch_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.policy = GaussianPolicy(state_dim, action_dim).to(self.device)
        self.q1 = QNetwork(state_dim, action_dim).to(self.device)
        self.q2 = QNetwork(state_dim, action_dim).to(self.device)
        
        self.q1_target = QNetwork(state_dim, action_dim).to(self.device)
        self.q2_target = QNetwork(state_dim, action_dim).to(self.device)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        
        self.policy_opt = optim.Adam(self.policy.parameters(), lr=lr)
        self.q1_opt = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_opt = optim.Adam(self.q2.parameters(), lr=lr)
        
        self.replay_buffer = ReplayBuffer()

    def select_action(self, state, eval_mode=False):
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        if eval_mode:
            with torch.no_grad():
                mean, std = self.policy.forward(state_t)
                action = torch.tanh(mean)
            return action.cpu().numpy()[0]
        else:
            with torch.no_grad():
                action, _ = self.policy.sample(state_t)
            return action.cpu().numpy()[0]

    def store_transition(self, s,a,r,ns,d):
        self.replay_buffer.push(s,a,r,ns,d)

    def soft_update(self, net, net_t):
        for param, param_t in zip(net.parameters(), net_t.parameters()):
            param_t.data.copy_(self.tau*param.data + (1-self.tau)*param_t.data)

    def update(self):
        if len(self.replay_buffer)<self.batch_size:
            return
        s,a,r,ns,d = self.replay_buffer.sample(self.batch_size)
        
        s_t = torch.FloatTensor(s).to(self.device)
        a_t = torch.FloatTensor(a).to(self.device)
        r_t = torch.FloatTensor(r).unsqueeze(-1).to(self.device)
        ns_t = torch.FloatTensor(ns).to(self.device)
        d_t = torch.FloatTensor(d).unsqueeze(-1).to(self.device)
        
        # Next action, logprob
        with torch.no_grad():
            na, na_logp = self.policy.sample(ns_t)
            q1_target_val = self.q1_target(ns_t, na)
            q2_target_val = self.q2_target(ns_t, na)
            min_q_target = torch.min(q1_target_val,q2_target_val) - self.alpha*na_logp.unsqueeze(-1)
            y = r_t + self.gamma*(1-d_t)*min_q_target
        
        # Q1,Q2 업데이트
        q1_val = self.q1(s_t,a_t)
        q2_val = self.q2(s_t,a_t)
        q1_loss = F.mse_loss(q1_val,y)
        q2_loss = F.mse_loss(q2_val,y)
        
        self.q1_opt.zero_grad()
        q1_loss.backward()
        self.q1_opt.step()
        
        self.q2_opt.zero_grad()
        q2_loss.backward()
        self.q2_opt.step()
        
        # Policy 업데이트
        a_sample, logp = self.policy.sample(s_t)
        q1_val_new = self.q1(s_t,a_sample)
        q2_val_new = self.q2(s_t,a_sample)
        min_q_val_new = torch.min(q1_val_new,q2_val_new)
        policy_loss = (self.alpha*logp - min_q_val_new).mean()
        
        self.policy_opt.zero_grad()
        policy_loss.backward()
        self.policy_opt.step()
        
        # 타겟 네트워크 soft update
        self.soft_update(self.q1, self.q1_target)
        self.soft_update(self.q2, self.q2_target)

def train_sac(env_name="MountainCarContinuous-v0", max_episodes=300):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    agent = SACAgent(state_dim, action_dim)
    reward_history = []
    for ep in range(max_episodes):
        s = env.reset()
        total_reward = 0
        done = False
        while not done:
            a = agent.select_action(s)
            ns, r, done, info = env.step(a)
            agent.store_transition(s,a,r,ns,done)
            s = ns
            total_reward += r
            agent.update()
        
        reward_history.append(total_reward)
        if (ep+1)%20==0:
            avg_reward = np.mean(reward_history[-20:])
            print(f"Episode {ep+1}, Avg Reward(last 20): {avg_reward:.2f}")
    
    env.close()

if __name__=="__main__":
    train_sac()

코드 해설

  • GaussianPolicy: 상태→(mean,log_std)로 가우시안 정책 정의, Tanh로 동작 범위를 [-1,1]로 제한.
  • QNetwork: Q(s,a) 근사, 두 개의 Q네트워크 사용.
  • SACAgent:
    • Off-policy 데이터(ReplayBuffer) 사용
    • 업데이트 시 min(Q1,Q2) - αlogπ(a|s) 형태의 타겟 사용
    • Policy 업데이트 시 엔트로피 항(logπ) 반영
  • 여기서는 α 고정이나, 실제 SAC는 α를 learnable parameter로 설정해 자동 조정 가능(논문 참고)

마무리

이번 글에서는 SAC를 소개하고, 연속적 행동 공간에서도 안정적이고 강력한 RL 알고리즘을 구현하는 방법을 간단히 살펴봤습니다. SAC는 현재 로보틱스, 제어 문제 등 다양한 분야에서 큰 인기를 얻고 있으며, PPO 등과 함께 RL 알고리즘 선택 시 강력한 후보가 됩니다.

다음 글에서는 시리즈를 마무리하며, 지금까지 다룬 알고리즘을 정리하고, 실제 프로젝트 적용 시 고려사항, 추가 학습 자료, 향후 나아갈 방향 등을 제안하겠습니다.

반응형