# PyTorch는 죽었다. JAX 만세

> Clean Markdown view of GeekNews topic #16369. Use the original source for factual precision when an external source URL is present.

## Metadata

- GeekNews HTML: [https://news.hada.io/topic?id=16369](https://news.hada.io/topic?id=16369)
- GeekNews Markdown: [https://news.hada.io/topic/16369.md](https://news.hada.io/topic/16369.md)
- Type: news
- Author: [xguru](https://news.hada.io/@xguru)
- Published: 2024-08-19T11:06:02+09:00
- Updated: 2024-08-19T11:06:02+09:00
- Original source: [neel04.github.io](https://neel04.github.io/my-website/blog/pytorch_rant/)
- Points: 14
- Comments: 8

## Summary

PyTorch는 동적 실행과 디버깅에 강점을 두었지만, 최근의 성능과 확장성 요구에 부응하지 못해 한계를 드러내고 있습니다. JAX는 강력한 컴파일러 스택인 XLA를 활용하여 자동 병렬화, 재현성, 함수형 API 등에서 뛰어난 성능을 발휘하며, 다양한 하드웨어에서의 이식성도 높습니다. PyTorch의 멀티 백엔드 접근 방식은 복잡성과 비호환성 문제를 야기하는 반면, JAX는 일관된 컴파일러 기반 접근 방식으로 더 나은 사용자 경험을 제공합니다. 아직까지는 PyTorch가 연구자들한테 가장 인기있지만, JAX를 살펴보고 비교해서 선택하는 것을 추천합니다.

## Topic Body

- PyTorch가 생산성 손실과 개발 시간 낭비를 초래하는 이유는 "프레임워크 자체가 나쁘기 때문이 아니라, 현재 적용되는 유스케이스에 맞게 설계되지 않았기 때문"  
  
### PyTorch의 철학  
- PyTorch의 철학은 동적이고, 디버깅하기 쉽고, 파이썬스러움  
- 반면 TensorFlow 1.x는 XLA 컴파일러를 강력히 사용해 정적이지만 성능이 좋은 프레임워크가 되려고 함  
- TensorFlow 개발자들은 커뮤니티가 1.x API를 싫어한다는 것을 깨닫고 Keras를 메인 인터페이스로 사용하기로 결정하고 XLA 컴파일러의 역할을 축소함  
- PyTorch는 뿌리를 지켰고, TensorFlow의 정적이고 지연된 접근법과 달리 torch.Tensor가 즉시 평가되는 더 역동적인 "즉시 실행" 접근법을 채택함  
- 이게 성과를 내면서 많은 연구가 PyTorch로 옮겨감   
- 2021년 GPT-3가 등장하면서 성능과 확장성이 주요 관심사가 됨  
- PyTorch는 이러한 수요에 어느 정도 잘 대응했지만, 이러한 철학을 염두에 두고 설계되지 않았기 때문에 점점 부채가 쌓이고 기반이 흔들리기 시작함  
- PyTorch 개발자들은 어떤 타협점도 원하지 않았고 두 가지 경로를 동시에 추구하기로 선택함  
  - XLA 컴파일러를 성능과 안정성이 뛰어난 기본 백엔드로 사용  
  - torch.compile 스택을 구축하여 필요한 경우 사용자가 컴파일러를 호출할 수 있는 자유를 부여  
- 장기 전략의 부재는 심각한 문제임  
- PyTorch는 컴파일러 중심의 철학(JAX와 같은)에 전념하고 싶어 하지 않지만 좋은 대안이 보이지 않음  
- 이 문제에 대한 경쟁 제품들의 해결책은 ?  
  
### JAX의 컴파일러 기반 개발  
  
- JAX는 TensorFlow의 강력한 컴파일러 스택인 XLA를 활용함  
- XLA는 강력한 컴파일러이지만, 엔드 유저에게는 모두 추상화되어 있음  
- 함수가 순수(pure)하기만 하면 @jax.jit 데코레이터를 사용해 함수를 JIT 컴파일하고 XLA에서 사용할 수 있게 만들 수 있음  
- XLA는 생성된 그래프가 정확한지 검증하고, JAX에서 샤딩을 사용한 자동 병렬화를 처리하는 GSPMD 파티셔너, 그래프 최적화, 연산자 및 커널 융합, 대기 시간 숨김 스케줄링, 비동기 통신 오버랩, triton과 같은 다른 백엔드에 대한 코드 생성 등을 모두 뒤에서 처리함  
- JAX 제한 사항을 준수하기만 하면 XLA가 자동으로 처리해 줌  
- 예를 들어 병렬화할 때 torch.distributed.barrier()와 같은 통신 프리미티브가 필요하지 않음  
- DDP 지원은 간단한 코드로 가능함  
- XLA의 접근 방식은 계산이 샤딩을 따른다는 것임. 따라서 입력 배열이 어떤 축을 따라 샤딩되면 XLA는 하위 계산에 대해 자동으로 처리함  
- "컴파일러 기반 개발" 아이디어는 Rust 컴파일러의 작동 방식과 유사함   
- PyTorch의 한계  
  - PyTorch 개발자들이 유연성과 자유의 핵심 철학을 유지하는 대신 새로운 기능을 위해 컴파일러 스택을 통합하고 의존하기로 한 선택에 불만족스러움  
  - PyTorch 2.x의 공식 로드맵에 따르면 XLA를 Torch와 완전히 통합할 장기 계획을 명확히 제시하고 있음  
  - 이는 끔찍한 아이디어임. Rust 컴파일러에 C++ 코드를 억지로 끼워 넣는 것이 Rust 자체를 사용하는 것보다 더 나은 경험이 될 것이라고 말하는 것과 같음  
  - Torch는 JAX와 달리 XLA를 중심으로 설계되지 않았음  
  - PyTorch가 XLA 기반 컴파일러 스택을 사용하기로 결정한다면, 이상적인 프레임워크는 그것을 중심으로 특별히 설계되고 구축된 것이 아닐까?   
  - PyTorch가 원하는 컴파일러 백엔드를 선택할 수 있는 "멀티 백엔드" 접근 방식을 추구하더라도 조각화 문제를 악화시키고 모든 컴파일러 스택의 제한을 존중하려고 시도하면서 API를 절대적으로 망가뜨리지 않을까?  
  - Torch/XLA를 TPU에서 사용해 본 사람은 누구나 심각한 PTSD로 고통받음  
  
### Multi-Backend는 망했음   
- PyTorch는 한 번에 모든 것을 하려고 하면서 비참하게 실패함  
- "멀티 백엔드" 설계 결정은 이 문제를 기하급수적으로 악화시킴  
- 이론적으로는 원하는 스택을 선택할 수 있는 것처럼 들리지만, 실제로는 이해하기 어려운 트레이스백과 비호환성 문제의 엉킨 혼란임  
- 백엔드 간 제약 조건과 PyTorch API의 충돌  
  - 이러한 백엔드를 작동시키는 것 자체가 어려운 것이 아니라, 이 백엔드들이 기대하는 제약 조건이 PyTorch의 유연하고 Pythonic한 API와 잘 맞지 않음  
  - API의 일관성을 유지하는 것과 백엔드의 제한을 따르는 것 사이에는 트레이드오프가 있음  
  - 결과적으로 개발자들은 단일 백엔드와 실제로 통합/커밋하는 대신 코드 생성에 더 의존하려고 함  
- PyTorch의 전략 부재  
  - PyTorch는 의미 있는 트레이드오프를 거부하기 때문에 모든 결정이 타협처럼 느껴짐  
  - 일관성도, 전반적인 전략도 없음  
  - 궁극적으로 사용자에게 많은 좌절감을 야기하고 잘 어울리지 않는 기능들의 잡동사니처럼 느껴짐  
  - 생태계를 죽이는 더 빠른 방법은 없음  
- JAX 접근 방식을 따라서는 안 되는 이유  
  - PyTorch는 JAX의 "통합 컴파일러 및 백엔드" 접근 방식을 따라서는 안 됨  
  - JAX는 XLA와 함께 작동하도록 명시적으로 설계되었기 때문  
  - PyTorch 프론트엔드를 JAX의 것으로 교체하는 것이 전략이 될 수는 없음  
  - XLA를 기반으로 JAX보다 더 나은 API를 고안하는 것은 사실상 불가능함  
  - 개발자들이 새롭고 다른 아이디어를 시도하는 것을 비난하지는 않음  
  - 그러나 PyTorch가 시간의 시험을 견디려면, 이상적인 튜토리얼 조건 밖에서 즉시 무너지는 멋진 새 기능을 제공하는 것보다 기반을 강화하는 데 더 중점을 두어야 함  
  
### PyTorch의 파편화와 JAX의 함수형 프로그래밍  
- JAX의 함수형 API  
  - JAX 함수는 순수(pure)해야 함. 즉, 전역적인 부작용이 없어야 함  
  - 수학 함수처럼 동일한 데이터가 주어지면 실행 컨텍스트에 상관없이 항상 동일한 출력을 반환해야 함  
  - 이러한 설계 철학 덕분에 JAX 함수는 구성 가능하고 서로 잘 상호 운용됨  
  - 개발 복잡성이 줄어들고, 함수는 특정 시그니처와 잘 정의된 구체적인 작업으로 정의됨  
  - 타입이 지켜지면 함수는 즉시 작동할 것이 보장됨  
  - 이는 과학 계산, 특히 딥러닝에서 필요한 작업 유형에 적합함  
- optax API 예시  
  - 함수형 접근 방식 덕분에 optax에는 "체인"이라는 것이 있음  
  - 이는 그래디언트에 순차적으로 적용되는 여러 함수를 포함함  
  - 근본적인 구성 요소는 GradientTransformation임  
  - 강력하면서도 표현력 있는 API를 만듦  
  - 예를 들어 그래디언트를 클리핑하거나, 그래디언트의 EMA를 취하거나, 옵티마이저를 결합하는 등의 작업이 매우 간단해짐  
- 함수형 설계의 장점  
  - 함수형 설계의 또 다른 멋진 결과는 vmap임  
  - 이는 'vectorized' map을 의미하며 정확히 그 기능을 설명함  
  - 모든 것을 map할 수 있고, vmap이기만 하면 XLA가 자동으로 융합하고 최적화함  
  - 함수를 작성할 때 배치 차원을 생각할 필요가 없음  
  - 모든 코드를 vmap하기만 하면 됨  
  - 이는 ein-* 작업이 덜 필요하다는 것을 의미함  
  - 2D/3D 텐서 조작을 파악하는 것이 더 직관적이고 가독성도 훨씬 좋음  
  - 개별 구성 요소를 격리하여 추론하기만 하면 되므로 잘 작동하는 복잡한 코드를 더 쉽게 작성할 수 있음  
  - 순수성 제약 조건을 존중하고 올바른 시그니처만 있으면 구성 가능성과 같은 다른 모든 이점을 누릴 수 있음  
- PyTorch 생태계의 문제점  
  - torch에서는 사용하는 스택(FSDP + 다중 노드 + torch.compile 등)에 관계없이 항상 무언가 깨질 가능성이 있음  
  - 여러 가지가 올바르게 함께 작동해야 하며, 어떤 구성 요소라도 실패하면 오전 3시까지 디버깅해야 함  
  - PyTorch가 제공하는 수십 가지 기능의 모든 조합을 테스트할 수 없기 때문에 개발 중에 발견되지 않은 버그가 항상 있을 것임  
  - 상당한 노력 없이는 잘 작동하는 코드를 작성하는 것은 불가능함  
  - torch 생태계는 매우 비대해지고 버그가 많아짐  
  - 공유 추상화가 없기 때문에 다른 "솔루션"과 인터페이스하도록 설계되지 않은 새로운 라이브러리와 프레임워크가 등장함  
  - 이는 곧 종속성과 requirements.txt의 혼란으로 빠르게 변질됨  
  - GitHub 이슈나 포럼 토론의 70-80%는 단순히 서로 다른 라이브러리에서 오류가 발생하기 때문임  
  - 이를 해결할 방법은 거의 없음  
- 해결책의 부재  
  - 이는 OOP와 설계 문제임  
  - PyTree와 같은 기본적이고 PyTorch스러운 객체가 추상화를 위한 공통 기반을 구축하는 데 도움이 되었을 것으로 생각됨  
  - 함수형 프로그래밍 패러다임을 채택할 수도 없음  
  - 그렇게 하면 JAX의 성능이 떨어지는 버전으로 수렴하면서 모든 기존 torch 코드베이스의 이전 버전과의 호환성이 깨질 것임  
  - PyTorch는 이 부분에서 완전히 망가진 상태로 보임  
  
### JAX의 재현성 우위  
- 시드 처리  
  - PyTorch의 시드 처리는 이상적이지 않음  
  - 일반적으로 여러 줄의 코드를 실행해야 함  
  - 쉽게 잊어버리거나 잘못 구성할 수 있음  
  - JAX는 명시적인 키를 만들어 무작위성이 필요한 모든 함수에 전달하도록 강제함  
  - 이 접근 방식은 RNG가 항상 정적으로 시드되기 때문에 문제를 완전히 제거함  
  - JAX에는 자체 버전의 NumPy(jax.numpy)가 있으므로 별도로 시드를 설정할 필요가 없음  
  - 이러한 작은 QoL 결정은 전체 프레임워크의 사용자 경험을 훨씬 더 좋게 만들 수 있음  
- 이식성  
  - PyTorch 코드베이스를 사용할 때 가장 큰 문제 중 하나는 이식성 부족  
  - CUDA/GPU용으로 작성된 코드베이스는 TPU, NPU, AMD GPU 등의 비 Nvidia 하드웨어에서 실행될 때 잘 작동하지 않음  
  - 1개 노드용으로 작성된 PyTorch 코드를 다중 노드로 포팅하기 어려움  
  - 다중 노드는 종종 수십 시간의 개발 시간과 상당한 코드 변경이 필요함  
  - JAX의 컴파일러 중심 접근 방식은 이 부분에서 이점이 있음  
  - XLA는 장치 백엔드 간 전환을 처리하며 최소한의 코드 변경으로 GPU/TPU/다중 노드/다중 슬라이스에서 잘 작동함  
  - 하드웨어 공급업체가 장치를 지원하기 쉽고 장치 간 전환을 더 쉽게 만듦  
  - 모든 사람이 동일한 하드웨어에 액세스할 수 있는 것은 아니므로 다양한 유형의 하드웨어에서 이식 가능한 코드베이스는 딥러닝을 초보자/중급자에게 더 접근하기 쉽게 만드는 작은 단계가 될 수 있음  
- 자동 스케일링  
  - 자체적으로 잘 자동 스케일링할 수 있는 코드베이스는 재현에 매우 도움이 됨  
  - 이상적인 경우 최소한의 코드 변경으로 네트워킹 경계에 구애받지 않고 자동으로 발생해야 함  
  - JAX는 이를 잘 수행함  
  - JAX 코드를 작성할 때 통신 기본 요소를 지정하거나 torch.distributed.barrier()를 모든 곳에 배치할 필요가 없음   
  - XLA는 사용 가능한 하드웨어를 고려하여 자동으로 이를 삽입함  
  - JAX가 감지할 수 있는 모든 장치는 네트워킹, 토폴로지, 구성 등에 관계없이 자동으로 사용됨  
  - 계산을 자동으로 동기화 및 준비하고 최적화 패스를 적용하여 커널의 비동기 실행을 최대화하고 대기 시간을 최소화함  
  - 사람이 해야 할 일은 입력 배열의 배치 차원과 같이 장치에 분산시키려는 텐서의 샤딩을 지정하는 것뿐임  
  - XLA의 "계산은 샤딩을 따른다"는 접근 방식 때문에 자동으로 나머지를 파악함  
  - 규모에 맞게 검증된 실험을 취미로 쉽게 실행하여 실험하고 잠재적으로 반복할 수 있음  
  - 이는 잊혀진 아이디어의 발견을 더 쉽게 하고, 최소한의 노력으로 더 큰 규모에서 함수로 쉽게 테스트할 수 있으므로 그러한 실험을 장려할 수 있음  
  
### JAX의 단점  
  
- 거버넌스 구조  
  - 현재 XLA는 TensorFlow 거버넌스 하에 있음  
  - PyTorch와 유사한 별도의 조직 기구를 설립하는 것에 대한 논의가 있었지만, 구체적인 노력은 많이 이루어지지 않음  
  - Google이 인기 없는 제품을 중단하는 평판 때문에 Google에 대한 신뢰도가 높지 않음  
  - JAX는 기술적으로 DeepMind 프로젝트이며 Google의 전체 AI 추진에 핵심적인 의미가 있지만, 생태계 전체에 장기적으로 큰 이점이 될 것으로 보임  
  - 별도의 거버넌스 기구가 프로젝트 개발에 지침을 제공할 것임  
  - 이는 구체적인 구조를 제공하고 Google의 악명 높은 관료주의와 분리되어 한 번에 많은 문제를 피할 수 있음  
  - JAX가 반드시 이러한 종류의 공식 구조를 필요로 하는 것은 아니지만, Google 상위 경영진의 결정에 관계없이 JAX 개발이 오랫동안 이루어질 것이라는 보장이 있으면 좋을 것임  
  - 이는 언젠가는 유지 관리되지 않을 수 있는 도구를 통합하는 데 리소스를 투입하는 것을 주저하는 기업과 대형 연구소에서 채택하는 데 분명히 도움이 될 것임  
- XLA의 오픈 소스 전환  
  - 오랜 시간 동안 XLA는 폐쇄 소스 프로젝트였음  
  - 그러나 이를 오픈 소스로 만들기 위한 노력이 이루어졌고, 현재 OpenXLA는 내부 XLA 빌드보다 훨씬 우수한 성능을 보여줌  
  - 하지만 XLA의 내부에 대한 문서는 여전히 부족함  
  - 대부분의 리소스는 라이브 토크와 가끔 논문일 뿐이며, 종종 오래되었음  
  - 예정된 기능에 대한 공개적으로 접근 가능한 로드맵이 있으면 사람들이 진행 상황을 추적하고 특히 흥미로운 것에 기여하기 쉬울 것임  
  - XLA 컴파일러 스택의 각 단계를 분석하고 세부 사항을 설명하는 Edward Yang 스타일의 미니 블로그 게시물을 통해 XLA가 무엇을 할 수 있고 할 수 없는지 실무자들이 더 잘 평가할 수 있는 방법을 제공하는 것이 좋을 것임  
  - 이는 리소스 집약적이며 다른 곳으로 더 잘 전달될 수 있다는 것을 이해하지만, 사람들은 도구를 이해할 때 더 신뢰하며, 전체 생태계에 걸쳐 긍정적인 파급 효과가 있어 모두에게 이익이 된다고 생각함  
- 생태계 통합  
  - flax는 JAX 생태계의 골칫거리임  
  - 직관적이지 않은 API, 간결한 구문을 가지고 있으며 PyTorch에서 전환하는 초보자에게는 절대적인 지옥임  
  - equinox를 사용하는 것이 좋음  
  - flax의 단점을 해결하기 위한 개발팀의 시도가 있었지만, 궁극적으로는 시간 낭비임  
  - equinox 스타일의 API를 원한다면 equinox를 사용하는 것이 좋음  
  - flax가 특별히 더 잘하는 것이 많지 않으며 equinox로 복제하기 어렵지 않음  
  - 현재 JAX 생태계의 많은 부분이 flax를 중심으로 설계되어 있음  
  - equinox는 근본적으로 PyTree와 인터페이스하기 때문에 모든 라이브러리와 상호 호환되지만, 약간의 eqx.partition과 filter가 필요함   
  - 상태 quo를 바꾸고 싶음. equinox가 모든 곳에서 일류 지원을 받아야 함  
  - 이는 논란의 여지가 있는 의견이지만, 이는 고전적인 매몰 비용 오류임  
  - equinox는 JAX 프레임워크가 항상 그래야 했던 방식으로 더 잘 작동함  
  - equinox 문서에 요약된 대로 equinox와 flax를 비교해 보면 equinox가 더 나음  
  - JAX 생태계 관리자들이 equinox의 인기를 인식하고 그에 따라 조정하는 것은 좋은 일이지만, Google과 flax 팀에서도 공식적으로 더 많은 사랑을 보여주기를 바람  
  - JAX를 시도해 보고 싶다면 equinox를 사용하는 것이 좋음  
- 날카로운 모서리  
  - API 설계 결정과 XLA 제한으로 인해 JAX에는 주의해야 할 "날카로운 모서리"가 있음  
  - 잘 작성된 문서에 이에 대해 매우 간결하게 설명되어 있음  
  - JAX를 사용하기 전에 적어도 한 번은 읽어보는 것이 좋음  
  - RTFM을 하는 것이 항상 그렇듯이 많은 시간과 에너지를 절약해 줄 것임  
  
### 결론  
- 이 블로그 게시물은 PyTorch가 실제 연구 워크로드, 특히 GPU에 가장 적합하다는 흔히 반복되는 Myth를 바로잡기 위한 것이었음. **더 이상 그렇지 않음**  
- 사실 모든 PyTorch 코드를 JAX로 포팅하는 것이 분야 전체에 엄청나게 유익할 것이라고 주장할 만큼 극단적임  
  - 자동 병렬화, 재현성, 깨끗한 함수형 API 등은 사소한 기능이 아니며 많은 연구 코드베이스에 큰 도움이 될 것임  
- 이 분야를 조금이라도 더 좋게 만들고 싶다면 코드베이스를 JAX로 다시 작성하는 것을 고려해 보세요

## Comments



### Comment 28265

- Author: xguru
- Created: 2024-08-25T10:00:34+09:00
- Points: 1

세상은 계속 흘러갑니다. ㅎㅎ  
  
[2022년 PyTorch 와 TensorFlow 비교](https://news.hada.io/topic?id=5578)

### Comment 28189

- Author: hilft
- Created: 2024-08-21T23:29:08+09:00
- Points: 1

torch랑 onnx로 버티겠습니다

### Comment 28183

- Author: flrngel
- Created: 2024-08-21T13:30:26+09:00
- Points: 1

학부생이 쓴 글.. ㄷㄷ

### Comment 28173

- Author: cosine20
- Created: 2024-08-21T10:00:53+09:00
- Points: 1

PyTorch는 Huggingface 없었으면 진짜 ㅋㅋ

### Comment 28112

- Author: lemonmint
- Created: 2024-08-19T17:01:47+09:00
- Points: 1

JAX 만세! 최근에 써봤는데 NNX API가 매우 마음에 들었습니다.

### Comment 28106

- Author: stareta1202
- Created: 2024-08-19T13:42:02+09:00
- Points: 5

JAX의 가장 큰 문제는 구글이 라는 점임. 구글은 오픈소스를 버리기로 굉장히 유명(Tflite, android things, dart, angular, bazel 등등) tensorflow도 어느 순간부터 업데이트가 잘 안되기 시작, 반면 torch는 방대한 오픈소스를 운영하는 facebook에서 시작해서 굉장히 잘 운영 및 이미 torch 재단에서 운영중임. torch의 단점은 분명히 맞는 부분이 있지만, 해당 오픈소스를 누가 지속가능하게 운영하는가 에 있어서 JAX는 이미 큰 위험을 가지고 시작하는 거 같음

### Comment 28161

- Author: dalinaum
- Created: 2024-08-20T22:29:52+09:00
- Points: 1
- Parent comment: 28106
- Depth: 1

최소한 Dart는 플러터로 한동안은 잘 살아있을 것 같군요.

### Comment 28144

- Author: ilotoki0804
- Created: 2024-08-20T10:06:52+09:00
- Points: 3
- Parent comment: 28106
- Depth: 1

페이스북은 리엑트, Django 등 그래도 자신들이 사용하는 기술 스택에 대해 의리 있게(?) 지속적으로 기여하는 것 같은데 구글은 조금만 퇴물되도 헌신짝처럼 버리는 것 같아요...
