28.36 PyTorch, TensorFlow, JAX에서의 텐서 자료 구조 비교

1. 비교의 의의

PyTorch, TensorFlow, JAX는 현재 가장 널리 사용되는 세 가지 대표적인 딥러닝 프레임워크이다. 이 세 시스템의 핵심 자료 구조는 모두 다차원 배열로서의 텐서이지만, 그 내부 구현, 자동 미분 결합 방식, 실행 모형, 변경 가능성(Mutability), 디바이스 추상화 등에서 본질적인 설계상의 차이를 가진다. 이러한 차이는 단순한 구문상의 차이가 아니라, 각 프레임워크가 채택하고 있는 프로그래밍 패러다임과 직접 연결되어 있다. 이 절에서는 세 프레임워크의 텐서 자료 구조를 학술적 관점에서 비교 분석한다.

2. PyTorch의 텐서 자료 구조

2.1 핵심 클래스 torch.Tensor

PyTorch의 텐서는 torch.Tensor 클래스로 구현되며, 그 내부는 두 가지 핵심 구성 요소로 분해된다. 첫째는 실제 원소가 저장되는 연속적 메모리 영역인 저장소(Storage, torch.Storage)이며, 둘째는 그 저장소를 다차원 배열로 해석하기 위한 메타데이터(형상, 스트라이드, 자료형, 디바이스, 오프셋)이다. 이러한 분리는 동일한 저장소를 공유하면서도 서로 다른 형상이나 스트라이드를 가지는 여러 텐서를 형성할 수 있게 하며, view, reshape, transpose, 슬라이싱 등의 연산이 데이터 복사 없이 이루어질 수 있는 기반이 된다.

2.2 변경 가능성과 즉시 실행

PyTorch의 텐서는 변경 가능(Mutable)한 객체이며, 인플레이스(In-place) 연산(예: add_, mul_)을 통해 자기 자신의 내용을 직접 갱신할 수 있다. 또한 PyTorch는 처음부터 즉시 실행(Eager Execution) 방식을 채택하였으므로, 사용자가 작성한 텐서 연산은 정의되는 순간 즉시 수행되어 결과 텐서를 산출한다. 이러한 설계는 일반적인 파이썬 디버깅 환경과 자연스럽게 결합되며, 모델의 동작을 단계별로 검사하기에 용이하다는 장점을 가진다.

2.3 자동 미분과 동적 계산 그래프

PyTorch의 자동 미분 시스템 autograd는 동적 계산 그래프(Dynamic Computational Graph) 또는 “Define-by-Run” 방식으로 동작한다. 학습 가능한 텐서는 requires_grad=True 속성을 가지며, 이러한 텐서를 입력으로 하는 연산은 전향 단계에서 자동으로 계산 그래프를 즉석에서 구성한다. 후행 단계에서 .backward()가 호출되면 그래프가 역방향으로 순회되며, 각 잎 텐서의 .grad 속성에 경사도가 누적된다. 이러한 동적 그래프는 데이터 의존적인 제어 흐름을 매우 자연스럽게 표현할 수 있게 한다.

2.4 디바이스 추상화

PyTorch의 텐서는 device 속성을 통해 자신이 위치한 하드웨어를 명시한다. CPU 메모리, NVIDIA GPU(CUDA), Apple Metal(MPS), 그리고 최근에는 다양한 가속기 백엔드를 지원하며, .to(device) 호출을 통해 다른 디바이스로 이동할 수 있다. 디바이스 사이의 연산은 명시적으로 이동된 텐서들 사이에서만 허용되며, 이는 사용자가 데이터의 위치를 항상 명확히 인식하도록 강제한다.

3. TensorFlow의 텐서 자료 구조

3.1 tf.Tensortf.Variable

TensorFlow의 텐서는 두 가지 주요 형태로 구분된다. 첫째는 tf.Tensor로, 일반적으로 불변(Immutable) 객체로 설계되어 한 번 생성되면 그 내용을 직접 수정할 수 없다. 둘째는 tf.Variable로, 학습 가능한 매개변수의 표현을 위해 명시적으로 도입된 변경 가능 컨테이너이다. 변수는 내부에 텐서를 보유하며, assign, assign_add, assign_sub 등의 메서드를 통해 그 내용을 갱신한다. 이러한 분리는 함수형 패러다임과 명령형 패러다임 사이의 절충을 명시적으로 표현한 것이다.

3.2 정적 그래프에서 즉시 실행으로의 전환

TensorFlow 1.x는 사용자가 먼저 그래프를 구축하고, 이후 Session 객체를 통해 입력 데이터를 흘려보내며 그래프를 실행하는 정적 그래프(Static Graph) 패러다임을 채택하였다. 그러나 사용 편의성과 디버깅 용이성에서 동적 그래프 방식이 가지는 장점이 두드러지면서, TensorFlow 2.x는 즉시 실행을 기본 동작으로 채택하였다. 동시에 정적 그래프의 최적화 이점을 유지하기 위하여 tf.function 데코레이터가 도입되었으며, 이는 파이썬 함수를 추적(Trace)하여 사후적으로 정적 그래프로 변환한 뒤 컴파일하여 실행한다.

3.3 자동 미분과 GradientTape

TensorFlow 2.x의 자동 미분 시스템은 tf.GradientTape라는 컨텍스트 관리자를 중심으로 동작한다. 사용자가 with tf.GradientTape() as tape: 블록 내에서 수행한 모든 텐서 연산은 자동으로 기록되며, 블록을 빠져나온 뒤 tape.gradient(loss, variables)를 호출하여 임의의 변수에 대한 손실의 경사도를 산출한다. 이 방식은 PyTorch의 동적 그래프와 유사하지만, 기록의 시작과 종료를 명시적인 컨텍스트로 제한한다는 점에서 약간 더 명시적인 형태를 띤다.

3.4 XLA 백엔드와 컴파일

TensorFlow는 텐서 연산 그래프를 가속 선형대수(Accelerated Linear Algebra, XLA) 컴파일러로 전달하여, 연산 융합(Operation Fusion), 메모리 재사용, 디바이스별 최적화된 커널 생성 등을 자동으로 수행할 수 있다. 이러한 컴파일은 tf.functionjit_compile=True 옵션이나 환경 설정을 통해 활성화되며, 정적 그래프 변환과 결합되어 추론 및 학습의 처리량을 결정적으로 향상시킨다.

4. JAX의 텐서 자료 구조

4.1 함수형 패러다임과 불변 텐서

JAX의 텐서, 즉 jax.Array는 본질적으로 NumPy의 ndarray와 동일한 API를 제공하지만, 결정적인 차이점은 완전한 불변성(Immutability)이다. 어떠한 텐서도 인플레이스 갱신을 허용하지 않으며, 기존 텐서의 일부를 변경하려면 항상 새로운 텐서를 산출하는 함수형 갱신 연산(예: x.at[i].set(v))을 사용해야 한다. 이러한 불변성은 JAX가 의도적으로 채택한 함수형 패러다임의 자연스러운 결과이며, 변환(Transformation) 기반 컴파일과 자동 미분의 수학적 명료성을 확보하기 위한 설계 결정이다.

4.2 변환으로서의 자동 미분

JAX의 자동 미분은 다른 프레임워크와 본질적으로 다른 방식으로 정의된다. 즉, 자동 미분은 텐서 객체의 메서드가 아니라 함수에 대한 변환(Transformation)이다. jax.grad는 임의의 순수 함수(Pure Function)를 입력으로 받아, 동일한 입력 시그니처를 가지지만 그 함수의 경사도를 반환하는 새로운 함수를 산출한다. 마찬가지로 jax.jacfwdjax.jacrev는 각각 전향 모드 야코비안과 역향 모드 야코비안을 계산하는 함수를, jax.hessian은 헤시안을 계산하는 함수를 산출한다. 이러한 함수 변환적 설계는 임의 차수의 미분과 자유로운 합성을 자연스럽게 지원한다.

4.3 변환의 합성과 추적

JAX의 또 다른 핵심 변환은 jax.jit(즉시 실행 컴파일), jax.vmap(자동 벡터화), jax.pmap(병렬 실행)이다. 이들 변환은 모두 함수를 입력으로 받아 함수를 산출하며, 자유롭게 합성될 수 있다. 예를 들어 jit(vmap(grad(f)))는 함수 f의 경사도를 자동 벡터화한 뒤 컴파일한 새로운 함수를 산출한다. 이러한 변환은 모두 추적(Tracing)이라는 공통 메커니즘에 기반하며, 추적 단계에서 함수의 호출은 실제 값이 아닌 추상적 추적자(Abstract Tracer) 객체를 따라 진행되어 함수의 계산 구조가 중간 표현(Intermediate Representation)으로 추출된다. 이 중간 표현은 XLA로 전달되어 컴파일된다.

4.4 순수 함수의 강제

JAX의 변환이 올바르게 동작하기 위해서는 입력 함수가 순수 함수여야 한다. 즉, 동일한 입력에 대해 항상 동일한 출력을 반환해야 하며, 외부 상태를 읽거나 수정해서는 안 된다. 이는 일반적인 명령형 코드에서 흔히 사용되는 가변 상태나 부수 효과(Side Effect)와 충돌할 수 있으나, 그 대가로 변환의 수학적 정합성과 컴파일러의 최적화 자유도가 확보된다. 이러한 설계는 함수형 프로그래밍의 원칙을 텐서 연산에 일관되게 적용한 결과이다.

5. 세 프레임워크의 비교 정리

세 프레임워크의 텐서 자료 구조에 관한 핵심 차이점을 다음 표로 정리한다.

항목PyTorchTensorFlow 2.xJAX
텐서 클래스torch.Tensortf.Tensor(불변), tf.Variable(가변)jax.Array(불변)
변경 가능성변경 가능, 인플레이스 연산 지원텐서는 불변, 변수는 가변완전 불변, 함수형 갱신만 허용
기본 실행 방식즉시 실행(동적 그래프)즉시 실행, tf.function으로 추적 컴파일즉시 실행 후 jit 변환을 통한 컴파일
자동 미분 인터페이스텐서의 requires_grad.backward()GradientTape 컨텍스트와 tape.gradient함수 변환 jax.grad, jax.jacrev
미분의 함수성객체 기반컨텍스트 기반함수 변환 기반
컴파일 백엔드TorchScript, Inductor, XLA(선택적)XLA(통합)XLA(필수적)
디바이스 추상화device 속성, 명시적 이동with tf.device(...) 또는 자동 배치device_put, pmap을 통한 명시적 분산
패러다임 지향성명령형, 객체 지향명령형/선언형 절충함수형

이러한 차이는 단순한 구현 세부의 문제가 아니라, 각 프레임워크가 추구하는 설계 철학의 표현이다. PyTorch는 일반적인 파이썬 객체 지향 프로그래밍에 가까운 직관적인 사용성을 우선시하며, TensorFlow는 명령형과 선언형 사이의 절충을 통해 산업적 배포와 성능을 동시에 추구하고, JAX는 함수형 패러다임과 변환의 합성을 통해 수학적 명료성과 컴파일러 친화성을 극대화한다.

6. 상호 운용성과 표준화

세 프레임워크는 서로 독립적으로 발전하였지만, 데이터 교환을 위한 상호 운용성(Interoperability) 표준이 점차 자리잡고 있다. DLPack은 서로 다른 프레임워크 사이에서 GPU 텐서를 데이터 복사 없이 공유하기 위한 공통 자료 구조 사양이며, PyTorch, TensorFlow, JAX 모두 DLPack 변환 함수를 제공한다. 또한 ONNX(Open Neural Network Exchange)는 학습된 모델을 프레임워크 사이에서 이식하기 위한 공통 모델 표현 형식으로, 텐서의 형상, 자료형, 연산 정의를 포함하는 표준 스키마를 제공한다. 이러한 표준화는 텐서 자료 구조가 프레임워크별로 상이하더라도, 그 핵심 의미론—다차원 배열로서의 텐서와 그 위의 표준 연산—은 공통적이라는 사실을 반영한다.

7. 결론

PyTorch, TensorFlow, JAX의 텐서 자료 구조는 모두 동일한 수학적 대상인 다차원 배열을 표현하지만, 변경 가능성과 자동 미분의 결합 방식, 즉시 실행과 컴파일의 균형, 디바이스 추상화의 설계 등에서 본질적으로 다른 선택을 내리고 있다. 이러한 차이는 사용자의 작업 흐름, 모델의 표현 방식, 성능 최적화 전략에 직접적인 영향을 주며, 어떤 프레임워크를 선택하는가는 단순한 도구 선택이 아니라 어떤 프로그래밍 패러다임 위에서 딥러닝 시스템을 설계할 것인가에 대한 본질적 결정이 된다.