Google Jax - 고성능 머신러닝 라이브러리
(github.com)"사용하기 쉬운 것을 빠르게 만들어서 머신러닝에 적용"
- Python과 Numpy만을 결합
ㅤ→ XLA 를 이용해서 Numpy를 GPU/TPU에서 컴파일하고 실행
ㅤ→ 파이썬 함수를 API 하나로 JIT 컴파일 해서 XLA 최적화된 커널에 쉽게 넣을 수 있음
ㅤ→ 다수의 GPU/TPU 에서의 실행도 쉽게 (vmap, pmap)
- 기존 파이썬+Numpy 성능을 훨씬 뛰어넘음
DeepMind는 Jax 기반으로 전체를 리팩토링 했음
https://deepmind.com/blog/article/using-jax-to-accelerate-our-research