나무모에 미러 (일반/밝은 화면)
최근 수정 시각 : 2025-12-05 15:15:49

JAX(라이브러리)


머신러닝 라이브러리
Accord.NET · Flax · JAX · Keras · MaxText · ML.NET · MLX · NumPy · PyTorch · TensorFlow · XLA
<colcolor=#000000,#ffffff> JAX
파일:jax_logo.svg
버전 0.8.1
2025년 11월 19일 업데이트
공개일 2020년 5월 8일
파일:홈페이지 아이콘.svg | 파일:GitHub 아이콘.svg파일:GitHub 아이콘 화이트.svg


1. 개요2. 특징3. 장점4. 단점5. 샘플 코드
5.1. 기본 예제5.2. JIT 컴파일5.3. vmap(자동 벡터화)
6. 다른 딥러닝 프레임워크 비교7. 활용 사례8. 비판 및 논란9. 같이 보기


1. 개요

JAX는 고성능 수치 계산과 대규모 머신러닝을 위해 설계된, 가속기[1] 친화적인 배열[2] 계산 및 프로그램 변환을 위한 Python 라이브러리입니다.
공식 깃헙 레포 소개글의 첫 문장[원문]
JAX (Just Another XLA)는 구글이 만든 신경망 기반 기계학습 라이브러리로, 자동미분(autograd)과 XLA(Accelerated Linear Algebra)를 결합하여 CPU, GPU, TPU에서 고성능 연산을 수행할 수 있도록 설계되었다.

2. 특징

패키지 설명
jax.grad 자동 미분
jax.jit JIT 컴파일
jax.vmap 자동 벡터화
jax.pmap 여러 디바이스 병렬 처리

3. 장점

4. 단점

5. 샘플 코드

5.1. 기본 예제

#!syntax python
import jax
import jax.numpy as jnp

# 함수 정의
def f(x):
    return jnp.sin(x) * jnp.cos(x)

# 자동 미분
grad_f = jax.grad(f)

print(grad_f(1.0))   # → 0.23924997

5.2. JIT 컴파일

#!syntax python
import jax
import jax.numpy as jnp

@jax.jit
def matmul(x, y):
    return jnp.dot(x, y)

x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))

print(matmul(x, y))   # 첫 실행에서 런타임 단계 컴파일이 수행되고 나서 부터는 실행시 굉장히 빠르다.

5.3. vmap(자동 벡터화)

#!syntax python
import jax
import jax.numpy as jnp

def square(x):
    return x * x

v_square = jax.vmap(square)

print(v_square(jnp.arange(5)))
# [0, 1, 4, 9, 16]

6. 다른 딥러닝 프레임워크 비교

<rowcolor=#fff> 항목 JAX PyTorch TensorFlow
목적 연구 / TPU / 대규모 학습 산업 적용 / 연구 / Serving 산업 / 모바일 / 배포
속도 매우 빠름 (XLA + JIT) 빠름 안정적(컴파일 시)
생태계 비교적 적음(성장중) 가장 넓음 넓은 편
TPU 지원 가장 강력 약함 강함
GPU 지원 매우 좋음 최고 수준 좋음
코드 작성 방식 함수형 / NumPy 기반 동적 파이썬 스타일 정적 컴파일 중심적이며 일부 동적 코드 지원
난이도 중-상 낮음

7. 활용 사례

8. 비판 및 논란

9. 같이 보기


[1] 이 문맥에서는 GPU, TPU 등 대량의 계산을 병렬적으로 빠르게 처리하기 위해 만들어진 컴퓨터 부품을 뜻한다.[2] 행렬(수학), 쉽게 말하자면 숫자들의 묶음이라고 생각해도 좋다.[원문] JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.