반응형

JAX 란?

  • JAX란, 머신러닝의 연산을 가속화하기 위해, Google에서 발표한 컴파일러를 최적화하는 기술이다.
  • JAX는 머신러닝에서 필수적인 Autograd와 XLA(integrated with Accelerated Linear Algebra, 가속 선형 대수) 컴파일러를 통해, 머신러닝의 연산을 빠르게 실행해 준다.
  • JAX는 설치가 매우 쉽고, 기존 Python에서 구현된 Numpy를 쉽게 변환할 수 있어서, 많이 활용되고 있다. 
  • 다만 JAX는 구글의 공식 제품이 아닌, 연구 프로젝트 기 때문에, 아직은 이런저런 버그가 있을 수 있다고 한다.

 

JAX 설치 방법

  • JAX는 우선 기본적으로 Linux나 Mac 운영 체제에서 동작한다.
  • Window도 동작하기는 하지만, 실험버전으로 CPU를 활용한 jax만 지원된다. (WSL을 사용하면 GPU를 사용할 수 있긴 하다.)

[CPU 설치]

pip install --upgrade "jax[cpu]"

 

[GPU & TPU 설치]

  • GPU에서도 pypi를 통해 쉽게 설치가 가능하다. 하지만, GPU는 Linux 환경에서만 설치되는 것을 명심하자. (나의 경우에는 WSL로 진행했다.)
# CUDA 12 
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

 

JAX 기본 기능

jax.numpy

  • jax는 기본적으로 jax.numpy를 통해, numpy의 API를 그대로 호환해 준다.
  • jax.numpy와 numpy는 거의 비슷하지만, 차이가 있는데,  jax는 함수형 프로그래밍으로 설계되어 있다는 점이다.
  • 즉, numpy는 배열에 직접 접근해서, 값을 바꾸는 것이 허용되지만, jax.numpy는 데이터를 직접 조작하는 것이 허용되지 않는다. → 거의 모든 Python 가속기들의 특징인 것 같다.
  • 다만, 값을 직접 바꾸는 것은 불가능하지만, 해당 요소를 반영한 새로운 배열을 생성할 수 있다.
import jax.numpy as jnp

if __name__ == '__main__':
    data = jnp.array([1,2,3,4])
    data[0] = 5
    # ERROR

    data = data.at[0].set(5)
    # data = [5,2,3,4]

 

 

[grad]

  • JAX는 native Python 및 numpy 코드를 자동으로 미분할 수 있는 기능을 제공한다.
  • JAX의 grad 함수는 함수의 입력에 대한 gradient를 자동으로 계산해 주는 함수이다.
  • JAX의 grad 함수는 loss의 기울기를 구할 때, 매우 빠르고 쉽게 활용될 수 있다.
  • JAX의 grad는 N차 미분값까지 쉽게 구할 수 있다.
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

if __name__ == '__main__':
    grad_square = jax.grad(square)

    # Calculate Gradient
    x = jnp.array(2.0)
    grad_value = grad_square(x)

    print("Input:", x)
    print("Gradient:", grad_value)

 

 

[jit]

  • jax.jit 함수는 JAX에서 제공하는 함수를 최적화해 주는 메커니즘이다. 
  • jit 함수를 통해, 정의한 함수를 컴파일하여, 최적화된 코드로 변환하고, 이를 Cache에 저장해 둔 뒤, 호출 시, 최적화된 코드를 통해 빠르게 실행된다.
  • 최적화된 Code를 Cache에 저장해 두기 때문에, 반복 변환이나, 불필요한 변환은 피하는 것이 좋다.
  • 다만, jit은 아래와 같은 경우에는 속도 향상이 없거나, 오히려 늦어질 수 있다.
    • 변환하려는 함수 내에 제어문이 포함된 경우
    • 재귀함수 
    • 데이터를 조작하는 함수
    • 크고 복잡한 함수 → 변환을 위한 cost가 더 많이 들 수 있음
  • 다른 모듈처럼, jit 사용을 위해, 단순 decorator만 사용해 주면 된다. 하지만, 변환을 위한 cost가 더 많이 들 수 있기 때문에, 꼭 비교해 보고 사용하는 것이 좋다.
import jax
import jax.numpy as jnp

@jax.jit
def square(x):
    return x ** 2

if __name__ == '__main__':
    grad_square = jax.grad(square)

    # Calculate Gradient
    x = jnp.array(2.0)
    grad_value = grad_square(x)

    print("Input:", x)
    print("Gradient:", grad_value)

 

 

[vmap]

  • jax.vmap 함수는 함수를 Vector 화하여 mapping 하는 함수이다. 
  • vmap 함수를 통해, 배열의 각 요소에 함수를 병렬로 실행할 수 있다. (pandas의 apply와 비슷한 개념이다.)
  • jit과 vmap은 같이 사용될 수 있다. (jit을 먼저 래핑 한 후, vmap을 하거나, vmap을 래핑한 후, jit을 하거나 둘 다 가능하다.)
import jax
import jax.numpy as jnp

def dot_product(x, y):
    return jnp.dot(x, y)

if __name__ == '__main__':
    grad_square = jax.grad(dot_product)

    vectorized_dot_product = jax.vmap(dot_product)

    x = jnp.array([i for i in range(10000)])
    y = jnp.array([i for i in range(10000)])
    grad_value = dot_product(x, y)

 

 

JAX 사용후기

  • JAX는 기본적으로 multi GPU 환경이나, TPU 환경에서 유리하다. 나의 경우에는 single GPU 환경이기 때문에, JAX를 쓰면 오히려 변환에 더 오랜 시간이 걸렸다. (JAX가 분산에 최적화되었기 때문이다.)
  • JAX가 numpy를 호환한다고 하지만,  아직 torch 등의 딥러닝 프레임워크와 호환이 부족하다. 따라서, 단순 기존 코드의 최적화가 아닌, 분산 환경에서 속도를 향상시키기 위한 대대적 Refactoring이나 개발에 사용하는 것이 좋을 것 같다.
  • JAX는 현재 기준(2023.08.07)으로 CUDA 11버전까지만 지원한다. 이것도 환경을 제한하는 요소인 것 같다.
  • 그럼에도 불구하고, JAX는 딥러닝 코드를 Python 언어 내에서 최적화할 수 있는 선택지를 제공한다는 점에서 매우 유용한 것 같다.

 

 

+ Recent posts