14P by xguru 31일전 | favorite | 댓글 8개
  • PyTorch가 생산성 손실과 개발 시간 낭비를 초래하는 이유는 "프레임워크 자체가 나쁘기 때문이 아니라, 현재 적용되는 유스케이스에 맞게 설계되지 않았기 때문"

PyTorch의 철학

  • PyTorch의 철학은 동적이고, 디버깅하기 쉽고, 파이썬스러움
  • 반면 TensorFlow 1.x는 XLA 컴파일러를 강력히 사용해 정적이지만 성능이 좋은 프레임워크가 되려고 함
  • TensorFlow 개발자들은 커뮤니티가 1.x API를 싫어한다는 것을 깨닫고 Keras를 메인 인터페이스로 사용하기로 결정하고 XLA 컴파일러의 역할을 축소함
  • PyTorch는 뿌리를 지켰고, TensorFlow의 정적이고 지연된 접근법과 달리 torch.Tensor가 즉시 평가되는 더 역동적인 "즉시 실행" 접근법을 채택함
  • 이게 성과를 내면서 많은 연구가 PyTorch로 옮겨감
  • 2021년 GPT-3가 등장하면서 성능과 확장성이 주요 관심사가 됨
  • PyTorch는 이러한 수요에 어느 정도 잘 대응했지만, 이러한 철학을 염두에 두고 설계되지 않았기 때문에 점점 부채가 쌓이고 기반이 흔들리기 시작함
  • PyTorch 개발자들은 어떤 타협점도 원하지 않았고 두 가지 경로를 동시에 추구하기로 선택함
    • XLA 컴파일러를 성능과 안정성이 뛰어난 기본 백엔드로 사용
    • torch.compile 스택을 구축하여 필요한 경우 사용자가 컴파일러를 호출할 수 있는 자유를 부여
  • 장기 전략의 부재는 심각한 문제임
  • PyTorch는 컴파일러 중심의 철학(JAX와 같은)에 전념하고 싶어 하지 않지만 좋은 대안이 보이지 않음
  • 이 문제에 대한 경쟁 제품들의 해결책은 ?

JAX의 컴파일러 기반 개발

  • JAX는 TensorFlow의 강력한 컴파일러 스택인 XLA를 활용함
  • XLA는 강력한 컴파일러이지만, 엔드 유저에게는 모두 추상화되어 있음
  • 함수가 순수(pure)하기만 하면 @jax.jit 데코레이터를 사용해 함수를 JIT 컴파일하고 XLA에서 사용할 수 있게 만들 수 있음
  • XLA는 생성된 그래프가 정확한지 검증하고, JAX에서 샤딩을 사용한 자동 병렬화를 처리하는 GSPMD 파티셔너, 그래프 최적화, 연산자 및 커널 융합, 대기 시간 숨김 스케줄링, 비동기 통신 오버랩, triton과 같은 다른 백엔드에 대한 코드 생성 등을 모두 뒤에서 처리함
  • JAX 제한 사항을 준수하기만 하면 XLA가 자동으로 처리해 줌
  • 예를 들어 병렬화할 때 torch.distributed.barrier()와 같은 통신 프리미티브가 필요하지 않음
  • DDP 지원은 간단한 코드로 가능함
  • XLA의 접근 방식은 계산이 샤딩을 따른다는 것임. 따라서 입력 배열이 어떤 축을 따라 샤딩되면 XLA는 하위 계산에 대해 자동으로 처리함
  • "컴파일러 기반 개발" 아이디어는 Rust 컴파일러의 작동 방식과 유사함
  • PyTorch의 한계
    • PyTorch 개발자들이 유연성과 자유의 핵심 철학을 유지하는 대신 새로운 기능을 위해 컴파일러 스택을 통합하고 의존하기로 한 선택에 불만족스러움
    • PyTorch 2.x의 공식 로드맵에 따르면 XLA를 Torch와 완전히 통합할 장기 계획을 명확히 제시하고 있음
    • 이는 끔찍한 아이디어임. Rust 컴파일러에 C++ 코드를 억지로 끼워 넣는 것이 Rust 자체를 사용하는 것보다 더 나은 경험이 될 것이라고 말하는 것과 같음
    • Torch는 JAX와 달리 XLA를 중심으로 설계되지 않았음
    • PyTorch가 XLA 기반 컴파일러 스택을 사용하기로 결정한다면, 이상적인 프레임워크는 그것을 중심으로 특별히 설계되고 구축된 것이 아닐까?
    • PyTorch가 원하는 컴파일러 백엔드를 선택할 수 있는 "멀티 백엔드" 접근 방식을 추구하더라도 조각화 문제를 악화시키고 모든 컴파일러 스택의 제한을 존중하려고 시도하면서 API를 절대적으로 망가뜨리지 않을까?
    • Torch/XLA를 TPU에서 사용해 본 사람은 누구나 심각한 PTSD로 고통받음

Multi-Backend는 망했음

  • PyTorch는 한 번에 모든 것을 하려고 하면서 비참하게 실패함
  • "멀티 백엔드" 설계 결정은 이 문제를 기하급수적으로 악화시킴
  • 이론적으로는 원하는 스택을 선택할 수 있는 것처럼 들리지만, 실제로는 이해하기 어려운 트레이스백과 비호환성 문제의 엉킨 혼란임
  • 백엔드 간 제약 조건과 PyTorch API의 충돌
    • 이러한 백엔드를 작동시키는 것 자체가 어려운 것이 아니라, 이 백엔드들이 기대하는 제약 조건이 PyTorch의 유연하고 Pythonic한 API와 잘 맞지 않음
    • API의 일관성을 유지하는 것과 백엔드의 제한을 따르는 것 사이에는 트레이드오프가 있음
    • 결과적으로 개발자들은 단일 백엔드와 실제로 통합/커밋하는 대신 코드 생성에 더 의존하려고 함
  • PyTorch의 전략 부재
    • PyTorch는 의미 있는 트레이드오프를 거부하기 때문에 모든 결정이 타협처럼 느껴짐
    • 일관성도, 전반적인 전략도 없음
    • 궁극적으로 사용자에게 많은 좌절감을 야기하고 잘 어울리지 않는 기능들의 잡동사니처럼 느껴짐
    • 생태계를 죽이는 더 빠른 방법은 없음
  • JAX 접근 방식을 따라서는 안 되는 이유
    • PyTorch는 JAX의 "통합 컴파일러 및 백엔드" 접근 방식을 따라서는 안 됨
    • JAX는 XLA와 함께 작동하도록 명시적으로 설계되었기 때문
    • PyTorch 프론트엔드를 JAX의 것으로 교체하는 것이 전략이 될 수는 없음
    • XLA를 기반으로 JAX보다 더 나은 API를 고안하는 것은 사실상 불가능함
    • 개발자들이 새롭고 다른 아이디어를 시도하는 것을 비난하지는 않음
    • 그러나 PyTorch가 시간의 시험을 견디려면, 이상적인 튜토리얼 조건 밖에서 즉시 무너지는 멋진 새 기능을 제공하는 것보다 기반을 강화하는 데 더 중점을 두어야 함

PyTorch의 파편화와 JAX의 함수형 프로그래밍

  • JAX의 함수형 API
    • JAX 함수는 순수(pure)해야 함. 즉, 전역적인 부작용이 없어야 함
    • 수학 함수처럼 동일한 데이터가 주어지면 실행 컨텍스트에 상관없이 항상 동일한 출력을 반환해야 함
    • 이러한 설계 철학 덕분에 JAX 함수는 구성 가능하고 서로 잘 상호 운용됨
    • 개발 복잡성이 줄어들고, 함수는 특정 시그니처와 잘 정의된 구체적인 작업으로 정의됨
    • 타입이 지켜지면 함수는 즉시 작동할 것이 보장됨
    • 이는 과학 계산, 특히 딥러닝에서 필요한 작업 유형에 적합함
  • optax API 예시
    • 함수형 접근 방식 덕분에 optax에는 "체인"이라는 것이 있음
    • 이는 그래디언트에 순차적으로 적용되는 여러 함수를 포함함
    • 근본적인 구성 요소는 GradientTransformation임
    • 강력하면서도 표현력 있는 API를 만듦
    • 예를 들어 그래디언트를 클리핑하거나, 그래디언트의 EMA를 취하거나, 옵티마이저를 결합하는 등의 작업이 매우 간단해짐
  • 함수형 설계의 장점
    • 함수형 설계의 또 다른 멋진 결과는 vmap임
    • 이는 'vectorized' map을 의미하며 정확히 그 기능을 설명함
    • 모든 것을 map할 수 있고, vmap이기만 하면 XLA가 자동으로 융합하고 최적화함
    • 함수를 작성할 때 배치 차원을 생각할 필요가 없음
    • 모든 코드를 vmap하기만 하면 됨
    • 이는 ein-* 작업이 덜 필요하다는 것을 의미함
    • 2D/3D 텐서 조작을 파악하는 것이 더 직관적이고 가독성도 훨씬 좋음
    • 개별 구성 요소를 격리하여 추론하기만 하면 되므로 잘 작동하는 복잡한 코드를 더 쉽게 작성할 수 있음
    • 순수성 제약 조건을 존중하고 올바른 시그니처만 있으면 구성 가능성과 같은 다른 모든 이점을 누릴 수 있음
  • PyTorch 생태계의 문제점
    • torch에서는 사용하는 스택(FSDP + 다중 노드 + torch.compile 등)에 관계없이 항상 무언가 깨질 가능성이 있음
    • 여러 가지가 올바르게 함께 작동해야 하며, 어떤 구성 요소라도 실패하면 오전 3시까지 디버깅해야 함
    • PyTorch가 제공하는 수십 가지 기능의 모든 조합을 테스트할 수 없기 때문에 개발 중에 발견되지 않은 버그가 항상 있을 것임
    • 상당한 노력 없이는 잘 작동하는 코드를 작성하는 것은 불가능함
    • torch 생태계는 매우 비대해지고 버그가 많아짐
    • 공유 추상화가 없기 때문에 다른 "솔루션"과 인터페이스하도록 설계되지 않은 새로운 라이브러리와 프레임워크가 등장함
    • 이는 곧 종속성과 requirements.txt의 혼란으로 빠르게 변질됨
    • GitHub 이슈나 포럼 토론의 70-80%는 단순히 서로 다른 라이브러리에서 오류가 발생하기 때문임
    • 이를 해결할 방법은 거의 없음
  • 해결책의 부재
    • 이는 OOP와 설계 문제임
    • PyTree와 같은 기본적이고 PyTorch스러운 객체가 추상화를 위한 공통 기반을 구축하는 데 도움이 되었을 것으로 생각됨
    • 함수형 프로그래밍 패러다임을 채택할 수도 없음
    • 그렇게 하면 JAX의 성능이 떨어지는 버전으로 수렴하면서 모든 기존 torch 코드베이스의 이전 버전과의 호환성이 깨질 것임
    • PyTorch는 이 부분에서 완전히 망가진 상태로 보임

JAX의 재현성 우위

  • 시드 처리
    • PyTorch의 시드 처리는 이상적이지 않음
    • 일반적으로 여러 줄의 코드를 실행해야 함
    • 쉽게 잊어버리거나 잘못 구성할 수 있음
    • JAX는 명시적인 키를 만들어 무작위성이 필요한 모든 함수에 전달하도록 강제함
    • 이 접근 방식은 RNG가 항상 정적으로 시드되기 때문에 문제를 완전히 제거함
    • JAX에는 자체 버전의 NumPy(jax.numpy)가 있으므로 별도로 시드를 설정할 필요가 없음
    • 이러한 작은 QoL 결정은 전체 프레임워크의 사용자 경험을 훨씬 더 좋게 만들 수 있음
  • 이식성
    • PyTorch 코드베이스를 사용할 때 가장 큰 문제 중 하나는 이식성 부족
    • CUDA/GPU용으로 작성된 코드베이스는 TPU, NPU, AMD GPU 등의 비 Nvidia 하드웨어에서 실행될 때 잘 작동하지 않음
    • 1개 노드용으로 작성된 PyTorch 코드를 다중 노드로 포팅하기 어려움
    • 다중 노드는 종종 수십 시간의 개발 시간과 상당한 코드 변경이 필요함
    • JAX의 컴파일러 중심 접근 방식은 이 부분에서 이점이 있음
    • XLA는 장치 백엔드 간 전환을 처리하며 최소한의 코드 변경으로 GPU/TPU/다중 노드/다중 슬라이스에서 잘 작동함
    • 하드웨어 공급업체가 장치를 지원하기 쉽고 장치 간 전환을 더 쉽게 만듦
    • 모든 사람이 동일한 하드웨어에 액세스할 수 있는 것은 아니므로 다양한 유형의 하드웨어에서 이식 가능한 코드베이스는 딥러닝을 초보자/중급자에게 더 접근하기 쉽게 만드는 작은 단계가 될 수 있음
  • 자동 스케일링
    • 자체적으로 잘 자동 스케일링할 수 있는 코드베이스는 재현에 매우 도움이 됨
    • 이상적인 경우 최소한의 코드 변경으로 네트워킹 경계에 구애받지 않고 자동으로 발생해야 함
    • JAX는 이를 잘 수행함
    • JAX 코드를 작성할 때 통신 기본 요소를 지정하거나 torch.distributed.barrier()를 모든 곳에 배치할 필요가 없음
    • XLA는 사용 가능한 하드웨어를 고려하여 자동으로 이를 삽입함
    • JAX가 감지할 수 있는 모든 장치는 네트워킹, 토폴로지, 구성 등에 관계없이 자동으로 사용됨
    • 계산을 자동으로 동기화 및 준비하고 최적화 패스를 적용하여 커널의 비동기 실행을 최대화하고 대기 시간을 최소화함
    • 사람이 해야 할 일은 입력 배열의 배치 차원과 같이 장치에 분산시키려는 텐서의 샤딩을 지정하는 것뿐임
    • XLA의 "계산은 샤딩을 따른다"는 접근 방식 때문에 자동으로 나머지를 파악함
    • 규모에 맞게 검증된 실험을 취미로 쉽게 실행하여 실험하고 잠재적으로 반복할 수 있음
    • 이는 잊혀진 아이디어의 발견을 더 쉽게 하고, 최소한의 노력으로 더 큰 규모에서 함수로 쉽게 테스트할 수 있으므로 그러한 실험을 장려할 수 있음

JAX의 단점

  • 거버넌스 구조
    • 현재 XLA는 TensorFlow 거버넌스 하에 있음
    • PyTorch와 유사한 별도의 조직 기구를 설립하는 것에 대한 논의가 있었지만, 구체적인 노력은 많이 이루어지지 않음
    • Google이 인기 없는 제품을 중단하는 평판 때문에 Google에 대한 신뢰도가 높지 않음
    • JAX는 기술적으로 DeepMind 프로젝트이며 Google의 전체 AI 추진에 핵심적인 의미가 있지만, 생태계 전체에 장기적으로 큰 이점이 될 것으로 보임
    • 별도의 거버넌스 기구가 프로젝트 개발에 지침을 제공할 것임
    • 이는 구체적인 구조를 제공하고 Google의 악명 높은 관료주의와 분리되어 한 번에 많은 문제를 피할 수 있음
    • JAX가 반드시 이러한 종류의 공식 구조를 필요로 하는 것은 아니지만, Google 상위 경영진의 결정에 관계없이 JAX 개발이 오랫동안 이루어질 것이라는 보장이 있으면 좋을 것임
    • 이는 언젠가는 유지 관리되지 않을 수 있는 도구를 통합하는 데 리소스를 투입하는 것을 주저하는 기업과 대형 연구소에서 채택하는 데 분명히 도움이 될 것임
  • XLA의 오픈 소스 전환
    • 오랜 시간 동안 XLA는 폐쇄 소스 프로젝트였음
    • 그러나 이를 오픈 소스로 만들기 위한 노력이 이루어졌고, 현재 OpenXLA는 내부 XLA 빌드보다 훨씬 우수한 성능을 보여줌
    • 하지만 XLA의 내부에 대한 문서는 여전히 부족함
    • 대부분의 리소스는 라이브 토크와 가끔 논문일 뿐이며, 종종 오래되었음
    • 예정된 기능에 대한 공개적으로 접근 가능한 로드맵이 있으면 사람들이 진행 상황을 추적하고 특히 흥미로운 것에 기여하기 쉬울 것임
    • XLA 컴파일러 스택의 각 단계를 분석하고 세부 사항을 설명하는 Edward Yang 스타일의 미니 블로그 게시물을 통해 XLA가 무엇을 할 수 있고 할 수 없는지 실무자들이 더 잘 평가할 수 있는 방법을 제공하는 것이 좋을 것임
    • 이는 리소스 집약적이며 다른 곳으로 더 잘 전달될 수 있다는 것을 이해하지만, 사람들은 도구를 이해할 때 더 신뢰하며, 전체 생태계에 걸쳐 긍정적인 파급 효과가 있어 모두에게 이익이 된다고 생각함
  • 생태계 통합
    • flax는 JAX 생태계의 골칫거리임
    • 직관적이지 않은 API, 간결한 구문을 가지고 있으며 PyTorch에서 전환하는 초보자에게는 절대적인 지옥임
    • equinox를 사용하는 것이 좋음
    • flax의 단점을 해결하기 위한 개발팀의 시도가 있었지만, 궁극적으로는 시간 낭비임
    • equinox 스타일의 API를 원한다면 equinox를 사용하는 것이 좋음
    • flax가 특별히 더 잘하는 것이 많지 않으며 equinox로 복제하기 어렵지 않음
    • 현재 JAX 생태계의 많은 부분이 flax를 중심으로 설계되어 있음
    • equinox는 근본적으로 PyTree와 인터페이스하기 때문에 모든 라이브러리와 상호 호환되지만, 약간의 eqx.partition과 filter가 필요함
    • 상태 quo를 바꾸고 싶음. equinox가 모든 곳에서 일류 지원을 받아야 함
    • 이는 논란의 여지가 있는 의견이지만, 이는 고전적인 매몰 비용 오류임
    • equinox는 JAX 프레임워크가 항상 그래야 했던 방식으로 더 잘 작동함
    • equinox 문서에 요약된 대로 equinox와 flax를 비교해 보면 equinox가 더 나음
    • JAX 생태계 관리자들이 equinox의 인기를 인식하고 그에 따라 조정하는 것은 좋은 일이지만, Google과 flax 팀에서도 공식적으로 더 많은 사랑을 보여주기를 바람
    • JAX를 시도해 보고 싶다면 equinox를 사용하는 것이 좋음
  • 날카로운 모서리
    • API 설계 결정과 XLA 제한으로 인해 JAX에는 주의해야 할 "날카로운 모서리"가 있음
    • 잘 작성된 문서에 이에 대해 매우 간결하게 설명되어 있음
    • JAX를 사용하기 전에 적어도 한 번은 읽어보는 것이 좋음
    • RTFM을 하는 것이 항상 그렇듯이 많은 시간과 에너지를 절약해 줄 것임

결론

  • 이 블로그 게시물은 PyTorch가 실제 연구 워크로드, 특히 GPU에 가장 적합하다는 흔히 반복되는 Myth를 바로잡기 위한 것이었음. 더 이상 그렇지 않음
  • 사실 모든 PyTorch 코드를 JAX로 포팅하는 것이 분야 전체에 엄청나게 유익할 것이라고 주장할 만큼 극단적임
    • 자동 병렬화, 재현성, 깨끗한 함수형 API 등은 사소한 기능이 아니며 많은 연구 코드베이스에 큰 도움이 될 것임
  • 이 분야를 조금이라도 더 좋게 만들고 싶다면 코드베이스를 JAX로 다시 작성하는 것을 고려해 보세요

JAX의 가장 큰 문제는 구글이 라는 점임. 구글은 오픈소스를 버리기로 굉장히 유명(Tflite, android things, dart, angular, bazel 등등) tensorflow도 어느 순간부터 업데이트가 잘 안되기 시작, 반면 torch는 방대한 오픈소스를 운영하는 facebook에서 시작해서 굉장히 잘 운영 및 이미 torch 재단에서 운영중임. torch의 단점은 분명히 맞는 부분이 있지만, 해당 오픈소스를 누가 지속가능하게 운영하는가 에 있어서 JAX는 이미 큰 위험을 가지고 시작하는 거 같음

페이스북은 리엑트, Django 등 그래도 자신들이 사용하는 기술 스택에 대해 의리 있게(?) 지속적으로 기여하는 것 같은데 구글은 조금만 퇴물되도 헌신짝처럼 버리는 것 같아요...

최소한 Dart는 플러터로 한동안은 잘 살아있을 것 같군요.

세상은 계속 흘러갑니다. ㅎㅎ

2022년 PyTorch 와 TensorFlow 비교

torch랑 onnx로 버티겠습니다

학부생이 쓴 글.. ㄷㄷ

PyTorch는 Huggingface 없었으면 진짜 ㅋㅋ

JAX 만세! 최근에 써봤는데 NNX API가 매우 마음에 들었습니다.