반응형
JAX 란?
- JAX란, 머신러닝의 연산을 가속화하기 위해, Google에서 발표한 컴파일러를 최적화하는 기술이다.
- JAX는 머신러닝에서 필수적인 Autograd와 XLA(integrated with Accelerated Linear Algebra, 가속 선형 대수) 컴파일러를 통해, 머신러닝의 연산을 빠르게 실행해 준다.
- JAX는 설치가 매우 쉽고, 기존 Python에서 구현된 Numpy를 쉽게 변환할 수 있어서, 많이 활용되고 있다.
- 다만 JAX는 구글의 공식 제품이 아닌, 연구 프로젝트 기 때문에, 아직은 이런저런 버그가 있을 수 있다고 한다.
JAX 설치 방법
- JAX는 우선 기본적으로 Linux나 Mac 운영 체제에서 동작한다.
- Window도 동작하기는 하지만, 실험버전으로 CPU를 활용한 jax만 지원된다. (WSL을 사용하면 GPU를 사용할 수 있긴 하다.)
[CPU 설치]
pip install --upgrade "jax[cpu]"
[GPU & TPU 설치]
- GPU에서도 pypi를 통해 쉽게 설치가 가능하다. 하지만, GPU는 Linux 환경에서만 설치되는 것을 명심하자. (나의 경우에는 WSL로 진행했다.)
# CUDA 12
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
JAX 기본 기능
jax.numpy
- jax는 기본적으로 jax.numpy를 통해, numpy의 API를 그대로 호환해 준다.
- jax.numpy와 numpy는 거의 비슷하지만, 차이가 있는데, jax는 함수형 프로그래밍으로 설계되어 있다는 점이다.
- 즉, numpy는 배열에 직접 접근해서, 값을 바꾸는 것이 허용되지만, jax.numpy는 데이터를 직접 조작하는 것이 허용되지 않는다. → 거의 모든 Python 가속기들의 특징인 것 같다.
- 다만, 값을 직접 바꾸는 것은 불가능하지만, 해당 요소를 반영한 새로운 배열을 생성할 수 있다.
import jax.numpy as jnp
if __name__ == '__main__':
data = jnp.array([1,2,3,4])
data[0] = 5
# ERROR
data = data.at[0].set(5)
# data = [5,2,3,4]
[grad]
- JAX는 native Python 및 numpy 코드를 자동으로 미분할 수 있는 기능을 제공한다.
- JAX의 grad 함수는 함수의 입력에 대한 gradient를 자동으로 계산해 주는 함수이다.
- JAX의 grad 함수는 loss의 기울기를 구할 때, 매우 빠르고 쉽게 활용될 수 있다.
- JAX의 grad는 N차 미분값까지 쉽게 구할 수 있다.
import jax
import jax.numpy as jnp
def square(x):
return x ** 2
if __name__ == '__main__':
grad_square = jax.grad(square)
# Calculate Gradient
x = jnp.array(2.0)
grad_value = grad_square(x)
print("Input:", x)
print("Gradient:", grad_value)
[jit]
- jax.jit 함수는 JAX에서 제공하는 함수를 최적화해 주는 메커니즘이다.
- jit 함수를 통해, 정의한 함수를 컴파일하여, 최적화된 코드로 변환하고, 이를 Cache에 저장해 둔 뒤, 호출 시, 최적화된 코드를 통해 빠르게 실행된다.
- 최적화된 Code를 Cache에 저장해 두기 때문에, 반복 변환이나, 불필요한 변환은 피하는 것이 좋다.
- 다만, jit은 아래와 같은 경우에는 속도 향상이 없거나, 오히려 늦어질 수 있다.
- 변환하려는 함수 내에 제어문이 포함된 경우
- 재귀함수
- 데이터를 조작하는 함수
- 크고 복잡한 함수 → 변환을 위한 cost가 더 많이 들 수 있음
- 다른 모듈처럼, jit 사용을 위해, 단순 decorator만 사용해 주면 된다. 하지만, 변환을 위한 cost가 더 많이 들 수 있기 때문에, 꼭 비교해 보고 사용하는 것이 좋다.
import jax
import jax.numpy as jnp
@jax.jit
def square(x):
return x ** 2
if __name__ == '__main__':
grad_square = jax.grad(square)
# Calculate Gradient
x = jnp.array(2.0)
grad_value = grad_square(x)
print("Input:", x)
print("Gradient:", grad_value)
[vmap]
- jax.vmap 함수는 함수를 Vector 화하여 mapping 하는 함수이다.
- vmap 함수를 통해, 배열의 각 요소에 함수를 병렬로 실행할 수 있다. (pandas의 apply와 비슷한 개념이다.)
- jit과 vmap은 같이 사용될 수 있다. (jit을 먼저 래핑 한 후, vmap을 하거나, vmap을 래핑한 후, jit을 하거나 둘 다 가능하다.)
import jax
import jax.numpy as jnp
def dot_product(x, y):
return jnp.dot(x, y)
if __name__ == '__main__':
grad_square = jax.grad(dot_product)
vectorized_dot_product = jax.vmap(dot_product)
x = jnp.array([i for i in range(10000)])
y = jnp.array([i for i in range(10000)])
grad_value = dot_product(x, y)
JAX 사용후기
- JAX는 기본적으로 multi GPU 환경이나, TPU 환경에서 유리하다. 나의 경우에는 single GPU 환경이기 때문에, JAX를 쓰면 오히려 변환에 더 오랜 시간이 걸렸다. (JAX가 분산에 최적화되었기 때문이다.)
- JAX가 numpy를 호환한다고 하지만, 아직 torch 등의 딥러닝 프레임워크와 호환이 부족하다. 따라서, 단순 기존 코드의 최적화가 아닌, 분산 환경에서 속도를 향상시키기 위한 대대적 Refactoring이나 개발에 사용하는 것이 좋을 것 같다.
- JAX는 현재 기준(2023.08.07)으로 CUDA 11버전까지만 지원한다. 이것도 환경을 제한하는 요소인 것 같다.
- 그럼에도 불구하고, JAX는 딥러닝 코드를 Python 언어 내에서 최적화할 수 있는 선택지를 제공한다는 점에서 매우 유용한 것 같다.
'Python' 카테고리의 다른 글
Pandas 데이터 구조 & 함수 정리 (2) | 2023.11.02 |
---|---|
Pypi 사용법 & 명령어 모음 & 폐쇄망 사용법 (1) | 2023.08.16 |
Pandas 성능 향상을 위한 방법들 (2) | 2023.07.21 |
Transformer Pytorch 구현 (11) | 2023.07.15 |
Pytorch Profiler Tensorboard로 시각화 (1) | 2023.07.10 |