개발 이야기/Rust (러스트)

러스트로 만드는 수학 및 과학 계산 라이브러리 시리즈 - 9편: 자동미분(AD)과 최적화 알고리즘, 머신러닝 기본기

nodiscard 2025. 1. 2. 17:00

이전까지의 글에서 우리는 선형대수, FFT, ODE/PDE 풀이, FEM 기초, GPU 가속, 병렬 처리 등 다양한 수학·과학 계산 기법을 러스트로 구현하고 고성능화하는 가능성을 살펴보았습니다. 이제는 **자동미분(Automatic Differentiation, AD)**과 최적화 알고리즘, 나아가 머신러닝(ML)의 기본 구조를 다루며, 러스트 기반 수학 라이브러리를 머신러닝 분야로 확장할 수 있는 기초를 마련해보겠습니다.

 

이번 글에서 다룰 주제들:

  1. 자동미분(AD): 함수의 기울기(Gradient)·야코비안(Jacobian)·헤시안(Hessian)을 기호적 변화 없이 자동으로 계산하는 기법.
  2. 최적화 알고리즘(Gradient Descent, Newton Method 등): 파라미터를 조정해 목표 함수를 최소화/최대화하는 과정.
  3. 머신러닝 기본기: 선형회귀(Linear Regression)나 로지스틱 회귀(Logistic Regression) 문제를 예로 들어 AD와 최적화 알고리즘을 결합하는 방법을 살펴봅니다.

자동미분(AD) 기초

AD는 주어진 함수 f(x)의 도함수를 계산하기 위해, 프로그램 실행 과정에서 연산 그래프(Computation Graph)를 추적하고 각 연산에 대한 미분을 체인룰로 적용하여 기울기를 구합니다. **순전파(forward-mode)**와 역전파(reverse-mode) AD가 있으며, 머신러닝에서는 대규모 다변수 함수를 효율적으로 미분하기 위해 역전파 AD가 널리 쓰입니다.

러스트에서 AD를 구현하는 한 가지 방법:

  • Traits 기반의 오버로딩: 수 x를 값뿐 아니라 도함수를 담는 구조체로 감싸고, +, -, *, / 연산 시 도함수까지 전파. (Forward-mode)
  • 컴퓨테이션 그래프 구성: 연산자 실행 시 그래프 노드 생성, 역전파 시 그래프를 순회하며 도함수 계산. (Reverse-mode)

여기서는 간단한 forward-mode AD 예제를 스케치해봅시다.

// src/autodiff.rs (간단 forward-mode AD 예)
#[derive(Copy, Clone)]
pub struct Dual {
    pub val: f64,
    pub der: f64,
}

impl Dual {
    pub fn new(val: f64, der: f64) -> Self {
        Self { val, der }
    }
}

use std::ops::{Add,Mul};

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual::new(self.val+rhs.val, self.der+rhs.der)
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // (f*g)' = f'*g + f*g'
        Dual::new(self.val*rhs.val, self.der*rhs.val + self.val*rhs.der)
    }
}

// 이런 식으로 sin, cos 등 다른 함수에 대해서도 chain rule 적용
pub fn sin_dual(x: Dual) -> Dual {
    Dual::new(x.val.sin(), x.der*x.val.cos())
}

이렇게 하면 f(x)=sin(x)*x 의 도함수를 x=1에서 구하려면 f'(1)을 let x = Dual::new(1.0,1.0); let y = sin_dual(x)*x; y.der 로 얻을 수 있습니다.

역전파 AD는 더 복잡하며, 실제 ML 프레임워크에서 DAG(Directed Acyclic Graph)를 구성하고 역전파 시 노드별 partial derivative를 누적해 나가는 구조를 구현합니다.

최적화 알고리즘

AD로 기울기를 얻으면 이를 활용해 함수 최소화/최대화를 위한 최적화 알고리즘을 적용할 수 있습니다.

예: Gradient Descent (GD)

// src/optimize.rs (개략적 예제)
pub fn gradient_descent<F,G>(mut x: f64, f: F, grad: G, lr: f64, steps: usize) -> f64
where
    F: Fn(f64)->f64,
    G: Fn(f64)->f64
{
    for _ in 0..steps {
        let g = grad(x);
        x = x - lr*g;
    }
    x
}

복잡한 고차원 문제에서는 x가 벡터, grad가 벡터를 반환하는 구조를 사용할 것이며, AD를 통해 grad를 자동으로 계산할 수 있습니다.

Newton 방법, L-BFGS, Adam 등 더 정교한 최적화 알고리즘을 구현하면 비선형 최적화, 머신러닝 모델 파라미터 훈련 등에 활용할 수 있습니다.

머신러닝 기본기: 선형회귀 예제

선형회귀: y ≈ wx + b 라고 할 때, 손실함수 L(w,b) = ∑(y_i - (wx_i+b))² 최소화. AD를 사용해 L에 대한 w,b의 기울기를 구하고, gradient descent로 w,b를 업데이트하는 과정을 구현할 수 있습니다.

// src/ml_basic.rs (개념 예제)
pub fn linear_regression_fit(xs: &[f64], ys: &[f64], lr: f64, steps: usize) -> (f64,f64) {
    let mut w = 0.0;
    let mut b = 0.0;

    for _ in 0..steps {
        let mut dw=0.0;
        let mut db=0.0;
        for (&x,&y) in xs.iter().zip(ys.iter()) {
            let pred = w*x+b;
            let err = pred - y;
            dw += 2.0*err*x;
            db += 2.0*err;
        }
        dw /= xs.len() as f64;
        db /= xs.len() as f64;

        w -= lr*dw;
        b -= lr*db;
    }

    (w,b)
}

여기서는 수작업으로 기울기를 구했지만, AD를 활용하면 모델을 변경해도 기울기 계산 코드를 다시 작성할 필요가 없습니다. 역전파 AD를 사용하면 복잡한 딥러닝 모델의 기울기를 자동으로 얻을 수 있어, 러스트로 ML 프레임워크를 구축하는 기반이 됩니다.

향후 발전 방향

  • 역전파 AD 구현: 그래프 기반으로 다양한 연산 지원, 텐서(tensor) 구조와 결합
  • ML 모델 라이브러리 구축: Dense Layer, Convolution Layer, RNN 등 연산 정의
  • GPU 가속과 결합: AD + GPU 가속을 통해 딥러닝 모델 훈련
  • 대규모 데이터 처리 파이프라인: 분산 처리, DataLoader, 데이터셋 전처리 도입
  • Symbolic 연산 + AD 결합: 일부 부분을 심볼릭으로 간소화하고 나머지를 AD로 처리

결론

이번 글에서 자동미분(AD), 최적화 기법, 머신러닝 기본 아이디어를 살펴보며, 러스트로 구축하는 수학/과학 라이브러리가 단순한 수학 연산을 넘어 머신러닝·최적화 프레임워크로 확장될 수 있음을 확인했습니다. 러스트의 안정성과 성능, 모듈성 덕분에 차세대 머신러닝 라이브러리를 러스트로 구현하는 가능성을 모색해볼 수 있습니다.

유용한 링크와 리소스

  • AD 개념: Griewank & Walther "Evaluating Derivatives", Autograd 논문, PyTorch 코드베이스 참고
  • Rust AD Crates: autograd, oxide-enzyme (Enzyme LLVM 기반 AD)
  • Rust ML Libraries: linfa (Classical ML), burn (Neural network), tch-rs (PyTorch FFI)
반응형