PyTorch torch.distributed

PyTorch torch.distributed

1. 현대 딥러닝에서 분산 훈련의 필요성

현대 딥러닝 생태계에서 torch.distributed가 왜 필수적인 구성 요소인지 이해하는 것으로 본 안내서를 시작한다. 이 섹션에서는 일반적인 문제 상황에서 출발하여 PyTorch가 제공하는 구체적인 해결책까지 다룬다.

1.1 규모의 두 가지 과제: 데이터와 모델 크기

딥러닝 분야는 모델의 크기와 복잡성이 기하급수적으로 증가하는 추세를 보인다. 특히 트랜스포머(Transformer) 아키텍처를 기반으로 하는 대규모 언어 모델(LLM)은 수십억에서 수조 개의 파라미터를 가지며, 이는 단일 가속기(GPU 등)의 메모리 용량을 훨씬 초과한다.1 이러한 거대 모델을 훈련시키기 위해서는 모델 자체를 여러 GPU에 분할하여 저장하고 계산을 수행하는 **모델 병렬 처리(Model Parallelism)**가 필수적이다.

동시에, 모델의 성능을 극한으로 끌어올리기 위해 사용되는 데이터셋의 규모 또한 웹 스케일로 확장되었다. 방대한 양의 데이터를 단일 GPU로 학습시키는 것은 현실적으로 불가능에 가까운 시간이 소요될 수 있다.2 이 문제를 해결하기 위해 여러 GPU가 데이터의 일부를 나누어 동시에 처리함으로써 전체 학습 처리량을 높이고 훈련 시간을 단축하는 **데이터 병렬 처리(Data Parallelism)**가 필요하다.

결론적으로, 분산 훈련은 더 이상 선택적인 최적화 기법이 아니라, 최첨단 AI 연구와 실제 상용 서비스를 위한 근본적인 필수 기술로 자리 잡았다.5

1.2 torch.distributed 소개: PyTorch의 네이티브 병렬 처리 툴킷

torch.distributed는 여러 프로세스와 머신 클러스터에 걸쳐 계산을 병렬화하기 위한 PyTorch의 핵심 라이브러리다.1 이 패키지는 메시지 전달 시맨틱을 활용하여 각 프로세스가 다른 모든 프로세스와 데이터를 통신할 수 있도록 지원한다. 이는 단일 머신 내에서의 실행으로 제한되는 torch.multiprocessing과 달리, 다양한 통신 백엔드를 지원하며 여러 머신에 걸쳐 유연하게 확장할 수 있다는 장점을 가진다.7

torch.distributed 패키지는 다음과 같은 계층적 구조를 가진다:

  1. 저수준 통신 API (c10d): send, recv, all_reduce와 같은 기본적인 통신 연산을 제공하는 백엔드다. 이를 통해 사용자는 자신만의 복잡한 병렬 처리 로직을 구현할 수 있다.

  2. 고수준 병렬 처리 모듈: DistributedDataParallel (DDP), FullyShardedDataParallel (FSDP)와 같이 널리 사용되는 병렬 처리 패턴을 쉽게 적용할 수 있도록 추상화된 모듈을 제공한다.

  3. 실행 유틸리티: torchrun과 같은 도구를 통해 여러 노드에 걸친 분산 훈련 작업을 간편하게 시작하고 관리할 수 있다.1

이러한 계층적 구조는 사용자가 필요에 따라 적절한 수준의 추상화를 선택하여 분산 훈련을 구현할 수 있도록 돕는다. 모델 아키텍처의 복잡성과 하드웨어 성능의 발전은 서로 영향을 주고받으며 진화해왔다. torch.distributed의 등장은 이러한 진화의 직접적인 결과물이며, 대규모 AI 분야 전체를 가능하게 하는 핵심 인프라다. 따라서 이 패키지를 이해하는 것은 단순히 API를 배우는 것을 넘어, 현대 AI가 마주한 확장성 한계를 해결하는 근본적인 엔지니어링 솔루션을 이해하는 것과 같다.

2. torch.distributed의 핵심 개념

분산 PyTorch 프로그램을 효과적으로 구현하고 디버깅하기 위해서는 핵심 용어와 아키텍처에 대한 명확한 이해가 필수적이다. 이 섹션에서는 분산 환경의 기본 구성 요소를 정의한다.

2.1 프로세스 그룹: 통신의 장

**프로세스 그룹(Process Group)**은 서로 통신할 수 있는 프로세스들의 집합을 의미한다.5 분산 작업이 시작될 때, 참여하는 모든 프로세스를 포함하는 기본 그룹이 생성되며, 이를 **월드(world)**라고 부른다.7 사용자는 dist.new_group(ranks) 함수를 사용하여 전체 프로세스 중 일부만을 포함하는 새로운 하위 그룹을 생성할 수 있다. 이는 데이터 병렬 처리와 모델 병렬 처리를 결합하는 하이브리드 병렬 처리와 같이 복잡한 통신 패턴을 구현할 때 유용하다.7

2.2 프로세스 식별: World Size, Rank, Local Rank

분산 환경의 각 프로세스는 고유한 식별자를 통해 자신을 인식하고 다른 프로세스와 통신한다.

  • World Size: 분산 작업에 참여하는 전체 프로세스의 수를 의미한다. 일반적으로 훈련에 사용되는 총 GPU의 개수와 동일하다.10 예를 들어, 4개의 GPU를 가진 노드 2개를 사용한다면 world_size는 8이 된다.

  • Rank: 월드 그룹 내에서 각 프로세스에 부여되는 전역적이고 유일한 ID다. 범위는 0부터 world_size - 1까지다.10 rank=0인 프로세스는 종종 마스터(master) 역할을 맡아 로깅, 체크포인트 저장 등의 작업을 조율한다.

  • Local Rank: 단일 노드(머신) 내에서 각 프로세스에 부여되는 지역적인 ID다. 범위는 0부터 해당 노드의 프로세스 수 - 1까지다.4 이 값은 특정 노드 내의 각 프로세스에 특정 GPU를 할당하는 데 매우 중요하다(예:torch.cuda.set_device(local_rank)).

2.3 통신 백엔드: 통신의 엔진

백엔드는 실제 프로세스 간 통신 프로토콜을 구현하는 라이브러리다.6 사용자는 자신의 하드웨어와 작업 환경에 가장 적합한 백엔드를 선택해야 한다.

  • NCCL (NVIDIA Collective Communications Library): NVIDIA GPU 환경에서 다중 GPU 통신을 위한 업계 표준이다. NVLink와 같은 고속 인터커넥트를 활용하여 all_reduce와 같은 집합적 통신 연산에 대해 고도로 최적화된 성능을 제공한다.5

  • Gloo: CPU와 GPU 모두에서 작동하는 플랫폼 독립적인 백엔드다. 안정적인 기본 옵션이며, CPU 기반의 분산 작업이나 NCCL을 사용할 수 없는 환경에서 대안으로 사용된다.11

  • MPI (Message Passing Interface): 고성능 컴퓨팅(HPC) 분야의 표준 통신 인터페이스다. PyTorch에서도 사용할 수 있지만, GPU 통신을 위해서는 CUDA를 인식하는(CUDA-aware) MPI 구현이 필요하며, 종종 소스 코드로부터 직접 PyTorch를 빌드해야 하는 번거로움이 있다.11

2.4 랑데부: 분산 환경 초기화

모든 분산 프로그램은 torch.distributed.init_process_group() 함수를 호출하여 시작해야 한다. 이 함수는 모든 프로세스가 서로를 발견하고 통신 그룹을 형성하도록 동기화하는, 이른바 랑데부(rendezvous) 과정을 수행한다.6

주요 초기화 방법은 다음과 같다:

  • 환경 변수 (env://): 가장 일반적이고 권장되는 방법이다. torchrun과 같은 실행 유틸리티가 MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK와 같은 환경 변수를 자동으로 설정해주면, init_process_group 함수가 이 변수들을 읽어 초기화를 수행한다.5

  • TCP (tcp://): rank=0 프로세스의 IP 주소와 포트를 모든 프로세스에 수동으로 명시해야 한다.9

  • 공유 파일 시스템 (file://): 모든 노드에서 접근 가능한 공유 파일 시스템의 특정 파일을 사용하여 프로세스 간 정보를 교환한다. 다음 실행 시 충돌을 피하기 위해 사용된 파일을 직접 관리해야 하는 주의가 필요하다.16

2.4.1 표 1: torch.distributed 핵심 용어 요약

용어설명예시 (노드 2개 * GPU 4개)
프로세스 그룹통신이 가능한 프로세스들의 집합. 기본 그룹은 ’월드’다.8개 프로세스 전체의 집합.
World Size분산 작업에 참여하는 전체 프로세스의 수.world_size = 8
Rank0부터 world\_size - 1까지의 전역 고유 ID.랭크 0, 1, 2, 3, 4, 5, 6, 7.
Local Rank노드 내에서 부여되는 0부터 노드당 GPU 수 - 1까지의 지역 고유 ID.노드 0의 랭크 0, 1, 2, 3; 노드 1의 랭크 0, 1, 2, 3.
백엔드통신 연산을 수행하는 기본 라이브러리 (예: NCCL, Gloo).NVIDIA GPU의 경우 backend='nccl'.
랑데부프로세스들이 서로를 발견하는 초기 동기화 단계.MASTER_ADDRMASTER_PORT를 사용하는 env:// 방식.

3. 통신 프리미티브 심층 분석

이 섹션에서는 분산 통신의 기본 구성 요소인 점대점(point-to-point) 연산과 집합적(collective) 연산을 자세히 살펴본다. 이를 이해하는 것은 사용자 정의 병렬 처리 전략을 구현하거나 DDP와 같은 고수준 프레임워크를 디버깅하는 데 핵심적이다.

3.1 점대점(Point-to-Point) 통신: 세밀한 제어

점대점(P2P) 통신은 두 특정 프로세스 간의 데이터 전송을 의미한다.6 이는 특정 프로세스 쌍 사이의 정교한 데이터 흐름이 필요할 때 사용된다.

  • 블로킹(Blocking) 연산: dist.send(tensor, dst)dist.recv(tensor, src)는 통신이 완료될 때까지 코드 실행을 차단한다. 이는 로직을 이해하기는 쉽지만, 송수신 순서가 맞지 않으면 프로세스들이 서로를 무한정 기다리는 교착 상태(deadlock)에 빠질 위험이 있다.6

  • 논블로킹(Non-blocking) 연산: dist.isend()dist.irecv()는 즉시 Work 객체를 반환하고 코드 실행을 계속한다. 이를 통해 통신과 계산을 중첩시켜 성능을 향상시킬 수 있다. 단, 수신된 텐서를 사용하거나 전송한 텐서를 수정하기 전에는 반드시 req.wait()를 호출하여 실제 통신이 완료되었는지 확인해야 한다.7

  • 사용 사례: P2P 통신은 한 GPU에서 다음 GPU로 활성화(activation) 값을 전달해야 하는 모델 병렬 처리나, 복잡한 비동기 알고리즘을 구현하는 데 필수적이다.6

3.2 집합적(Collective) 통신: 동기식 병렬 처리의 엔진

집합적 통신은 그룹 내의 모든 프로세스가 동시에 참여하는 연산이다.7 이는 데이터 병렬 처리와 같이 모든 프로세스가 동기화되어야 하는 작업의 근간을 이룬다. 집합적 연산 중 데이터를 결합하는 연산(reduction)에는 어떤 수학적 연산을 적용할지 지정해야 하며, 이를 위해 ReduceOp (SUM, PRODUCT, MAX, MIN)가 사용된다.7

3.3 주요 집합적 연산의 구조

각 연산에 대해 명확한 정의, 데이터 흐름도, 그리고 코드 예시를 통해 설명한다.

  • dist.broadcast(tensor, src): 원본 프로세스(src)의 텐서를 그룹 내 다른 모든 프로세스에 복사한다. 모든 모델 복제본이 rank=0의 가중치와 동일한 초기값으로 시작하도록 보장하는 데 사용된다.7

  • dist.reduce(tensor, dst, op): 모든 프로세스로부터 텐서를 모아 op 연산을 적용하고, 최종 결과를 목적지 프로세스(dst)에만 저장한다. 마스터 프로세스에서 통계나 손실 값을 수집할 때 유용하다.7

  • dist.all_reduce(tensor, op): DDP의 핵심 연산이다. 모든 프로세스로부터 텐서를 모아 op 연산을 적용한 뒤, 최종 결과를 다시 모든 프로세스에 분배한다. 연산은 인플레이스(in-place)로 수행된다. 모든 모델 복제본의 그래디언트를 평균내어 각 옵티마이저가 동일한 가중치 업데이트를 수행하도록 하는 데 사용된다.6

  • dist.scatter(tensor, scatter_list, src): 원본 프로세스(src)가 텐서 리스트(scatter_list)를 여러 조각으로 나누어 그룹 내 각 프로세스(자신 포함)에 하나씩 분배한다. 마스터 프로세스가 전체 데이터 배치를 각 워커에게 나누어 줄 때 사용될 수 있다.7

  • dist.gather(tensor, gather_list, dst): scatter의 역연산이다. 각 프로세스가 자신의 텐서를 목적지 프로세스(dst)로 보내면, dst는 이를 gather_list에 수집한다. 모든 워커의 출력이나 예측 결과를 단일 프로세스로 모아 평가할 때 사용된다.7

  • dist.all_gather(tensor_list, tensor): gather와 유사하지만, 모든 프로세스가 다른 모든 프로세스로부터 온 텐서들의 전체 리스트를 각자 가지게 된다. 모든 워커가 다른 모든 워커의 중간 결과물을 필요로 하는 계산에 사용된다.

all_reduce 연산은 수학적으로 reducebroadcast를 수행하는 것과 동일하지만, 실제 구현은 훨씬 효율적이다. NCCL과 같은 고성능 백엔드는 링-올리듀스(Ring-Allreduce)와 같은 알고리즘을 사용하여 단일 마스터 노드의 병목 현상을 피한다.19 이 알고리즘에서는 각 프로세스가 이웃 프로세스와 링(ring) 토폴로지를 형성하여 데이터 청크를 주고받는다. 이를 통해 통신 부하가 모든 프로세스에 고르게 분산되어 GPU 수가 증가함에 따라 확장성이 훨씬 좋아진다. 이는 분산 시스템 설계의 핵심 원칙인 ’중앙 집중식 병목 현상 회피’를 보여주는 대표적인 사례로, 사용자가 수동으로 reducebroadcast를 구현하는 대신 네이티브 all_reduce를 사용해야 하는 이유가 단지 편의성 때문이 아니라 성능과 확장성 때문임을 시사한다.

3.3.1 표 2: 점대점 통신 vs. 집합적 통신

특징점대점 (P2P) 통신집합적 통신
범위두 개의 특정 프로세스 (src, dst).그룹 내의 모든 프로세스.
동기화참여하는 두 프로세스만 동기화.참여하는 모든 프로세스를 동기화 (배리어 역할).
주요 사용 사례모델 병렬 처리, 파이프라인 병렬 처리, 비동기 알고리즘.데이터 병렬 처리 (DDP), 동기식 그래디언트 평균화.
핵심 함수send, recv, isend, irecv.all_reduce, broadcast, scatter, gather.
복잡성높음; 교착 상태를 피하기 위해 송수신 순서를 신중하게 관리해야 함.낮음; 복잡한 패턴을 단일 함수 호출로 추상화.

4. 병렬 처리 패러다임

통신 프리미티브를 이해했다면, 이제 이를 적용하여 신경망을 훈련시키는 고수준 전략들을 살펴볼 차례다.

4.1 데이터 병렬 처리: 더 많은 GPU, 더 빠른 훈련

데이터 병렬 처리의 핵심 아이디어는 모델을 모든 GPU에 복제하고, 각 GPU에는 입력 데이터 배치의 다른 일부를 공급하는 것이다. 역전파 후, 옵티마이저가 가중치를 업데이트하기 전에 모든 GPU의 그래디언트를 평균내어 모든 모델이 동기화된 상태를 유지하도록 한다.1

4.1.1 과거의 접근법: torch.nn.DataParallel (DP)

DataParallel은 단일 프로세스, 다중 스레드 방식으로 작동한다. 주 스레드가 데이터를 여러 GPU로 분산(scatter)시키고, 매 순전파마다 모델을 복제하며, 출력을 다시 모아(gather) 주 GPU에서 그래디언트를 집계한다.22 그러나 이 방식은 파이썬의 전역 인터프리터 락(GIL)으로 인한 경합, 주 GPU의 병목 현상, 다중 노드 미지원 등 심각한 한계를 가진다.23 현재는 레거시로 간주되며 사용을 지양해야 한다.

4.1.2 현대의 표준: torch.nn.parallel.DistributedDataParallel (DDP)

DDP는 다중 프로세스 아키텍처를 기반으로 하여 GIL 문제를 근본적으로 해결한다.22 각 프로세스/GPU는 자신만의 옵티마이저를 가지며, 역전파 과정에서 효율적인 all_reduce 연산을 통해 그래디언트가 동기화된 후 독립적으로 가중치를 업데이트한다.5 DDP의 가장 큰 장점 중 하나는 계산과 통신의 중첩이다. Autograd 훅(hook)을 사용하여 특정 레이어의 그래디언트 계산이 완료되는 즉시 해당 그래디언트에 대한 all_reduce 통신을 시작함으로써, 전체 역전파가 끝날 때까지 기다리지 않고 유휴 시간을 최소화하여 효율성을 극대화한다.22

4.2 모델 병렬 처리: 거대 모델 훈련

모델 병렬 처리는 단일 거대 모델을 여러 GPU에 분할하는 전략이다. 모델의 다른 레이어나 구성 요소가 서로 다른 장치에 상주하게 된다.1

  • 단순 모델 병렬 처리: nn.Module의 일부를 cuda:0에, 다른 일부를 cuda:1에 배치하고, forward 메소드에서 중간 활성화 텐서를 명시적으로 장치 간에 이동시키는 방식이다.2 이 방식의 가장 큰 단점은 파이프라인 버블(bubble) 현상으로, 한 GPU가 계산하는 동안 다른 GPU들은 이전 단계의 결과가 도착하기를 기다리며 유휴 상태에 빠지게 된다.2

  • 파이프라인 병렬 처리: GPU 유휴 시간을 줄이는 진보된 모델 병렬 처리 기법이다. 미니 배치를 더 작은 마이크로 배치로 분할하고, 이를 모델 스테이지 파이프라인에 конвейер(컨베이어 벨트)처럼 연속적으로 투입한다. 이를 통해 여러 스테이지(GPU)가 서로 다른 마이크로 배치를 동시에 처리하여 효율성을 높인다.1

  • 텐서 병렬 처리: 단일 모듈(예: 거대한 nn.Linear 레이어) 내의 계산 자체를 병렬화하는 기법이다. 가중치 행렬 자체를 여러 GPU에 걸쳐 분할(shard)하고, 각 GPU에서 부분 계산을 수행한 뒤 집합적 통신을 통해 결과를 결합한다. 이는 거대한 트랜스포머 레이어의 규모를 확장하는 데 매우 중요하다.1

4.3 하이브리드 전략: 두 세계의 장점 결합

**완전 샤딩 데이터 병렬 처리(Fully Sharded Data Parallelism, FSDP)**는 데이터 병렬 처리와 모델 병렬 처리의 아이디어를 결합한 강력한 하이브리드 접근법이다. FSDP는 DDP처럼 데이터를 여러 랭크에 분할하지만, 추가적으로 모델 파라미터, 그래디언트, 옵티마이저 상태까지도 모든 랭크에 걸쳐 분할(shard)하여 저장한다.1

FSDP의 동작 방식은 다음과 같다: 각 랭크는 평소에는 모델의 일부 조각만을 메모리에 유지한다. 순전파/역전파 과정에서 특정 레이어를 계산해야 할 때만 all_gather 연산을 통해 해당 레이어의 전체 파라미터를 일시적으로 재구성하고, 계산이 끝나면 즉시 메모리에서 해제한다. 이를 통해 GPU당 최대 메모리 사용량을 극적으로 줄여, 데이터 병렬 처리의 확장성을 유지하면서도 단일 GPU 메모리에 담을 수 없는 거대 모델의 훈련을 가능하게 한다.

4.3.1 표 4: DataParallel vs. DistributedDataParallel

특징torch.nn.DataParallel (DP)torch.nn.parallel.DistributedDataParallel (DDP)
병렬 처리 모델단일 프로세스, 다중 스레드다중 프로세스
GIL 경합있음, 파이썬의 전역 인터프리터 락이 병목이 될 수 있음.없음, 각 프로세스가 독립적인 인터프리터를 가짐.
그래디언트 동기화그래디언트를 단일 주 프로세스/GPU로 모아서 처리.all_reduce를 통해 모든 프로세스에 걸쳐 분산 평균화.
병목 현상주 프로세스/GPU가 그래디언트 집계의 병목 지점.분산형 구조로 단일 장애 지점 없음.
네트워크 활용비효율적; 통신과 계산의 중첩이 어려움.고효율; Autograd 훅을 통해 통신과 계산을 중첩.
확장성낮음. 단일 노드에서만 작동. GPU 수가 늘면 성능 저하.뛰어남. 다중 노드, 다중 GPU 환경으로 확장 가능.
권장 사항레거시. 심각한 사용 사례에는 절대적으로 피해야 함.모든 데이터 병렬 훈련에 권장되는 표준.

5. DDP 구현: 단계별 가이드

이 섹션은 DDP 훈련 스크립트를 설정하고 실행하기 위한 실용적인 핸즈온 튜토리얼이다. 이전 섹션의 개념들을 완전하고 실행 가능한 예제로 통합한다.

5.1 DDP 상용구: setup()cleanup()

분산 훈련 스크립트는 일반적으로 프로세스 그룹을 초기화하고 정리하는 함수로 시작하고 끝난다.

  • setup 함수: dist.init_process_group을 호출하여 랑데부를 수행하고, 환경 변수로부터 local_rank를 받아 torch.cuda.set_device(local_rank)를 통해 현재 프로세스가 사용할 GPU를 지정한다.4

  • cleanup 함수: 훈련이 끝난 후 dist.destroy_process_group을 호출하여 사용된 리소스를 정상적으로 해제한다.4

5.2 모델 래핑: model = DDP(model, device_ids=[local_rank])

이 한 줄의 코드는 DDP의 핵심이다. 모델은 먼저 model.to(local_rank)를 통해 올바른 GPU로 이동된 후, DDP로 래핑되어야 한다.4 DDP 생성자는 내부적으로 다음 작업을 수행한다:

  1. rank=0의 모델 상태(파라미터)를 다른 모든 프로세스에 브로드캐스트하여 모든 복제본이 동일한 지점에서 훈련을 시작하도록 보장한다.22

  2. 모델 파라미터에 Autograd 훅을 등록하여, 역전파 과정에서 그래디언트가 계산될 때마다 동기화 작업이 트리거되도록 설정한다.5

5.3 데이터셋 분할: DistributedSampler

데이터 병렬 처리에서 각 GPU는 데이터의 서로 다른 부분을 처리해야 한다. 이를 보장하지 않으면 모든 프로세스가 동일한 데이터로 훈련하게 되어 자원 낭비가 발생한다.4

torch.utils.data.distributed.DistributedSampler는 데이터셋 인덱스를 world_size에 따라 자동으로 분할하여 각 랭크에 고유한 데이터 서브셋을 제공한다.4

DataLoader를 생성할 때 sampler 인자로 DistributedSampler 인스턴스를 전달해야 한다. 이때 DataLoadershuffle 인자는 반드시 False로 설정해야 한다. 데이터 셔플링은 샘플러가 자체적으로 관리하기 때문이다. 또한, 매 에포크 시작 시 sampler.set_epoch(epoch)를 호출하여 에포크마다 데이터가 다르게 셔플링되도록 보장하는 것이 중요하다.26

5.4 DDP 훈련 루프

DDP의 훈련 루프는 놀랍게도 단일 GPU 훈련 루프와 거의 동일하다. 사용자는 평소처럼 loss.backward()optimizer.step()을 호출하면 된다. 모든 복잡한 작업은 DDP 래퍼가 내부적으로 처리한다. loss.backward()가 호출되면 등록된 Autograd 훅이 트리거되어 백그라운드에서 그래디언트 all_reduce 통신이 발생한다. backward() 함수가 반환될 시점에는 이미 각 파라미터의 .grad 속성에 모든 프로세스의 그래디언트가 평균된 값이 저장되어 있다.22

DDP와 DistributedSampler의 설계는 강력한 추상화의 예시다. 프로세스 동기화, 네트워크 통신, 데이터 샤딩과 같은 엄청난 복잡성을 단 몇 줄의 API 호출 뒤에 숨김으로써, 사용자는 분산 시스템 엔지니어링이 아닌 모델 로직 자체에 집중할 수 있다. 이러한 뛰어난 개발자 경험은 PyTorch가 연구 커뮤니티에서 지배적인 위치를 차지하게 된 주된 이유 중 하나이며, 대규모 실험의 진입 장벽을 낮춤으로써 AI 분야의 혁신 속도를 가속화했다.

5.5 완전한 주석 코드 예제

다음은 위에서 설명한 모든 요소를 통합한 완전한 DDP 훈련 스크립트다.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
"""분산 환경을 초기화한다."""
os.environ = 'localhost'
os.environ = '12355'
# 프로세스 그룹 초기화
dist.init_process_group("nccl", rank=rank, world_size=world_size)
# 현재 프로세스에 GPU 할당
torch.cuda.set_device(rank)

def cleanup():
"""프로세스 그룹을 정리한다."""
dist.destroy_process_group()

class ToyModel(nn.Module):
"""간단한 선형 모델."""
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))

class ToyDataset(Dataset):
"""임의의 데이터를 생성하는 더미 데이터셋."""
def __init__(self, size=1000):
self.size = size
self.data = torch.randn(size, 10)
self.labels = torch.randn(size, 5)

def __len__(self):
return self.size

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]

def train(rank, world_size):
print(f"Rank {rank}에서 DDP 훈련 시작...")
setup(rank, world_size)

# 1. 모델 생성 및 GPU로 이동
model = ToyModel().to(rank)
# 2. DDP로 모델 래핑
ddp_model = DDP(model, device_ids=[rank])

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

# 3. DistributedSampler로 데이터셋 준비
dataset = ToyDataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler)

# 4. 훈련 루프
for epoch in range(10):
sampler.set_epoch(epoch) # 매 에포크마다 셔플링 보장
for data, labels in dataloader:
data = data.to(rank)
labels = labels.to(rank)

optimizer.zero_grad()
outputs = ddp_model(data)
loss = loss_fn(outputs, labels)
loss.backward() # 이 시점에 그래디언트 동기화 발생
optimizer.step()

if rank == 0:
print(f"Rank {rank}, Epoch {epoch}, Loss: {loss.item()}")

cleanup()

if __name__ == "__main__":
import torch.multiprocessing as mp
world_size = torch.cuda.device_count()
# mp.spawn을 사용하여 여러 프로세스를 생성하고 train 함수 실행
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

6. torchrun으로 분산 작업 실행 및 관리

올바르게 작성된 DDP 스크립트라도 정확하게 실행되지 않으면 무용지물이다. 이 섹션에서는 분산 작업을 실행하는 운영 측면을 다룬다.

6.1 launch.py에서 torchrun으로: 현대의 표준

과거에는 torch.distributed.launch 유틸리티가 사용되었지만, 현재는 torchrun (이전 torch.distributed.elastic의 일부)이 권장되는 실행기다.24

torchrun은 더 나은 오류 복원력, 탄력적 훈련(작업 중간에 워커 수 조정) 지원, 그리고 더 깔끔한 커맨드 라인 인자를 제공한다는 장점이 있다.26

6.2 단일 노드, 다중 GPU 실행

일반적인 단일 노드 실행 명령어는 다음과 같다 14:

torchrun --nproc_per_node=NUM_GPUS your_script.py --arg1...

torchrunNUM_GPUS개의 프로세스를 생성하고, 각 프로세스에 대해 LOCAL_RANK, RANK, WORLD_SIZE 환경 변수를 자동으로 설정한다.

6.3 스케일 아웃: 다중 노드, 다중 GPU 실행

다중 노드 실행은 참여하는 모든 노드에서 다음 명령어를 실행해야 한다 14:

torchrun --nnodes=TOTAL_NODES --nproc_per_node=GPUS_PER_NODE \
--node_rank=NODE_INDEX --rdzv_id=JOB_ID \
--rdzv_backend=c10d --rdzv_endpoint=MASTER_NODE_IP:PORT \
your_script.py...

랑데부(rdzv) 관련 인자들은 다음과 같은 의미를 가진다:

  • --nnodes: 작업에 참여하는 총 머신(노드)의 수.

  • --node_rank: 현재 실행 중인 머신의 랭크 (0부터 nnodes-1까지).

  • --rdzv_id: 다른 작업과의 혼선을 방지하기 위한 고유한 작업 ID.

  • --rdzv_endpoint: 모든 프로세스가 만나는 지점 역할을 하는 단일 노드(일반적으로 node_rank=0 머신)의 IP 주소와 포트.

6.3.1 표 5: 주요 torchrun 커맨드 라인 인자

인자설명사용 사례설정되는 환경 변수
--nproc_per_node각 노드에서 실행할 프로세스 수 (보통 GPU 수).단일 & 다중 노드LOCAL_RANK, WORLD_SIZE
--nnodes작업에 참여하는 총 노드(머신) 수.다중 노드WORLD_SIZE
--node_rank현재 노드의 랭크.다중 노드RANK
--rdzv_id훈련 작업을 위한 고유 ID.다중 노드-
--rdzv_backend랑데부 프로세스를 위한 백엔드 (예: c10d).다중 노드-
--rdzv_endpoint랑데부를 위한 마스터 노드의 IP:PORT.다중 노드MASTER_ADDR, MASTER_PORT

7. 고급 기법 및 운영 모범 사례

이 마지막 기술 섹션에서는 견고하고 효율적이며 메모리에 최적화된 분산 훈련을 위한 중요한 주제들을 다룬다.

7.1 상태 관리: DDP 체크포인팅 최종 가이드

DDP 환경에서 모델의 상태를 저장하고 불러오는 것은 신중하게 다루어야 한다.

  • 황금률: 한 랭크에서 저장하고, 모든 랭크에서 불러오기. DDP가 모든 모델을 동기화 상태로 유지하므로, 이들의 state_dict는 동일하다. rank=0에서만 저장하면 중복된 디스크 쓰기와 잠재적인 경쟁 상태를 피할 수 있어 효율적이고 안전하다.4

  • 저장: if rank == 0: torch.save(ddp_model.module.state_dict(), PATH) 패턴을 사용한다. DDP 래퍼 자체가 아닌 내부 모델의 상태를 저장하기 위해 ddp_model.module.state_dict()를 사용하는 점에 유의해야 한다.

  • 불러오기: 불러오기는 더 섬세한 과정이 필요하다.

  1. 모든 프로세스에서 dist.barrier()를 호출하여 rank=0이 저장을 완료할 때까지 다른 프로세스들이 기다리도록 한다. 이는 불완전한 파일을 읽으려는 시도를 방지한다.22

  2. 각 프로세스는 torch.load(PATH, map_location=f'cuda:{local_rank}')와 같이 map_location을 지정하여 체크포인트를 자신의 할당된 GPU로 직접 로드한다.

  3. 그 후 model.load_state_dict(checkpoint)를 통해 로컬 모델 인스턴스에 상태를 로드한다. map_location 인자는 각 프로세스가 자신의 GPU에 가중치를 올바르게 로드하도록 보장하는 데 매우 중요하다.22

  • 대규모 모델을 위한 torch.distributed.checkpoint: FSDP와 같이 샤딩된 거대 모델의 경우, torch.distributed.checkpoint는 랭크별로 파일을 생성하는 디렉토리 기반의 더 현대적이고 확장 가능한 체크포인팅 솔루션을 제공한다.29

7.2 활성화 체크포인팅을 통한 메모리 최적화

활성화 체크포인팅(또는 그래디언트 체크포인팅)은 메모리와 계산 시간 사이의 트레이드오프를 이용하는 기법이다. 순전파 과정에서 모든 레이어의 중간 활성화 값을 저장하는 대신, 일부만 저장하고 나머지는 역전파 과정에서 필요할 때 재계산한다. 이를 통해 GPU 메모리 사용량을 크게 줄일 수 있다.5

torch.utils.checkpoint.checkpoint를 사용하여 트랜스포머 블록과 같이 메모리 집약적인 모델의 특정 부분을 래핑함으로써 이 기법을 적용할 수 있다.30

7.3 흔한 함정과 디버깅

  • 교착 상태 (Deadlocks): 랭크 간에 집합적 통신 호출이 일치하지 않거나, P2P 송수신 순서가 잘못되었을 때 자주 발생한다.

  • 초기화 오류: MASTER_ADDR/PORT가 잘못 설정되었거나 방화벽이 통신을 차단하는 경우 발생할 수 있다.

  • CUDA 오류: torch.cuda.set_device(local_rank)를 잊으면 모든 프로세스가 GPU 0을 사용하려고 시도하여 메모리 부족(OOM) 오류가 발생할 수 있다.

  • 배치 크기와 학습률: DDP 사용 시, 전역 배치 크기는 GPU당 배치 크기 * world_size가 된다. 동일한 훈련 동역학을 유지하기 위해 종종 학습률을 전역 배치 크기에 비례하여 선형적으로 조정(linear scaling)할 필요가 있다.28

8. 분산 훈련 전략 수립

이 결론 섹션에서는 본 안내서의 내용을 종합하여 독자가 자신의 문제에 맞는 올바른 도구를 선택하는 데 도움이 되는 고수준의 의사 결정 프레임워크를 제시한다.

8.1 개념 종합

저수준 통신 프리미티브에서 고수준 병렬 처리 추상화에 이르기까지의 여정을 요약한다. 대부분의 사용 사례에서는 DDP가 최적의 솔루션이지만, 그 기본 원리를 이해하는 것이 더 복잡한 문제를 해결하는 열쇠임을 다시 한번 강조한다.

8.2 병렬 처리 의사 결정 프레임워크

다음은 사용 사례에 맞는 병렬 처리 전략을 선택하기 위한 간단한 의사 결정 트리다.

  1. 모델이 단일 GPU 메모리에 맞는가?
  • 예: 훈련 속도가 너무 느린가?

  • 그렇다면, 더 많은 GPU/노드로 확장하기 위해 DDP를 사용한다.

  • 아니오: 다음 질문으로 넘어간다.

  1. 모델이 단일 GPU에 비해 너무 큰가?
  • 예:

  • FSDP로 시작한다. 이는 메모리 절약과 구현의 용이성 사이에서 훌륭한 균형을 제공한다.

  • FSDP로도 부족하거나 추가 최적화가 필요한 경우:

  • 모델이 명확한 순차 구조를 가졌다면 파이프라인 병렬 처리를 고려한다.

  • 개별 레이어 자체가 병목이라면 텐서 병렬 처리를 탐색한다.

이 프레임워크는 안내서에서 다룬 상세한 설명을 바탕으로 독자에게 실질적이고 실행 가능한 지침을 제공한다.

9. 참고 자료

  1. PyTorch Distributed Overview — PyTorch Tutorials 2.8.0+cu128 documentation, https://docs.pytorch.org/tutorials/beginner/dist_overview.html
  2. Model Parallelism vs Data Parallelism: Examples - Analytics Yogi, https://vitalflux.com/model-parallelism-data-parallelism-differences-examples/
  3. Data Parallelism and Model Parallelism - czxttkl, https://czxttkl.com/2021/08/09/data-parallelism-and-model-parallelism/
  4. HOWTO: PyTorch Distributed Data Parallel (DDP) | Ohio Supercomputer Center, https://www.osc.edu/resources/getting_started/howto/howto_pytorch_distributed_data_parallel_ddp
  5. Distributed Parallel Training: PyTorch Multi-GPU Setup in Kaggle T4x2 - LearnOpenCV, https://learnopencv.com/distributed-parallel-training-pytorch-multi-gpu-setup/
  6. PyTorch Distributed: A Bottom-Up Perspective | by Hao | Medium, https://medium.com/@eeyuhao/pytorch-distributed-a-bottom-up-perspective-e3159ee2c2e7
  7. Writing Distributed Applications with PyTorch — PyTorch Tutorials …, https://pytorch-cn.com/tutorials/intermediate/dist_tuto.html
  8. Writing Distributed Applications with PyTorch - ShaLab, https://shalab.usc.edu/writing-distributed-applications-with-pytorch/
  9. Distributed Applications with PyTorch - GeeksforGeeks, https://www.geeksforgeeks.org/deep-learning/distributed-applications-with-pytorch/
  10. python - In distributed computing, what are world size and rank …, https://stackoverflow.com/questions/58271635/in-distributed-computing-what-are-world-size-and-rank
  11. Writing Distributed Applications with PyTorch — PyTorch Tutorials 2.8.0+cu128 documentation, https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html
  12. In the PyTorch Distributed Data Parallel (DDP) tutorial, how does setup know it’s rank?, https://codemia.io/knowledge-hub/path/in_the_pytorch_distributed_data_parallel_ddp_tutorial_how_does_setup_know_its_rank
  13. Multi node PyTorch Distributed Training Guide For People In A Hurry - Lambda, https://lambda.ai/blog/multi-node-pytorch-distributed-training-guide
  14. Using torchrun for Distributed Training Dongda’s homepage, https://dongdongbh.tech/blog/torchrun/
  15. Does torch.distributed support point-to-point communication for GPU? - Stack Overflow, https://stackoverflow.com/questions/70390019/does-torch-distributed-support-point-to-point-communication-for-gpu
  16. Distributed communication package - torch.distributed — PyTorch master documentation, https://alband.github.io/doc_view/distributed.html
  17. Understanding PyTorch’s Distributed Communication Package: Powering all distributed model training | by Arjun Agarwal | Medium, https://medium.com/@arjunagarwal899/understanding-pytorchs-distributed-communication-package-powering-all-distributed-model-training-340d6b553faf
  18. Collective Communication in Distributed Systems with PyTorch - Roboflow Blog, https://blog.roboflow.com/collective-communication-distributed-systems-pytorch/
  19. Communication is the Key to Success - GPU Puzzlers, http://www.gpupuzzlers.com/posts/collectives/
  20. Writing Distributed Applications with PyTorch, https://sebarnold.net/blog/writing_distributed_apps_pytorch_20170614/note.pdf
  21. PyTorch Distributed: Experiences on Accelerating Data Parallel Training - arXiv, https://arxiv.org/pdf/2006.15704
  22. Getting Started with Distributed Data Parallel — PyTorch Tutorials …, https://pytorch-cn.com/tutorials/intermediate/ddp_tutorial.html
  23. DataParallel vs DistributedDataParallel - distributed - PyTorch Forums, https://discuss.pytorch.org/t/dataparallel-vs-distributeddataparallel/77891
  24. Launching a distributed training run - Harold Benoit, https://haroldbenoit.com/notes/ML/Engineering/Pytorch/Launching-a-distributed-training-run
  25. Introduction to Model Parallelism - Amazon SageMaker AI - AWS Documentation, https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-intro.html
  26. The Practical Guide to Distributed Training using PyTorch — Part 2 …, https://medium.com/the-owl/the-practical-guide-to-distributed-training-using-pytorch-part-2-on-a-single-node-using-torchrun-9e794baa0410
  27. Python API: torch.utils.data.distributed.DistributedSampler Class Reference - Caffe2, https://caffe2.ai/doxygen-python/html/classtorch_1_1utils_1_1data_1_1distributed_1_1_distributed_sampler.html
  28. Properly implementing DDP in training loop with cleanup, barrier, and its expected output, https://discuss.pytorch.org/t/properly-implementing-ddp-in-training-loop-with-cleanup-barrier-and-its-expected-output/146465
  29. torch.distributed.checkpoint — PyTorch 2.8 documentation, https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
  30. PyTorch Activation Checkpointing: Complete Guide | by Hey Amit - Medium, https://medium.com/@heyamit10/pytorch-activation-checkpointing-complete-guide-58d4f3b15a3d