27.22 행렬 곱셈의 계산 복잡도와 효율적 알고리즘
1. 표준 행렬 곱셈의 계산 복잡도
두 행렬 \mathbf{A} \in \mathbb{R}^{m \times p}와 \mathbf{B} \in \mathbb{R}^{p \times n}의 곱 \mathbf{C} = \mathbf{A}\mathbf{B} \in \mathbb{R}^{m \times n}을 구하는 표준 알고리즘을 분석하자. 결과 행렬의 (i,j) 원소는 다음과 같이 계산된다.
c_{ij} = \sum_{k=1}^{p} a_{ik} b_{kj}
이 내적 연산에는 p번의 곱셈과 (p-1)번의 덧셈이 필요하다. 결과 행렬의 원소 수가 m \times n개이므로, 총 연산 횟수는 다음과 같다.
T(m, p, n) = m \cdot n \cdot (2p - 1) \approx 2mnp
정사각 행렬의 경우 m = p = n이면 T(n) = 2n^3 - n^2이므로, 표준 행렬 곱셈의 시간 복잡도는 O(n^3)이다. 이는 입력 크기 O(n^2)에 비해 초선형(superlinear)이며, 대규모 행렬에 대해서는 상당한 연산 부담이 된다.
구체적인 수치로 살펴보면, n = 1000인 정사각 행렬 곱셈에는 약 2 \times 10^9회의 부동소수점 연산(FLOP)이 필요하고, n = 10000이면 약 2 \times 10^{12} FLOP이 필요하다. 딥러닝에서 대규모 모델의 가중치 행렬 크기가 수만에서 수십만에 이르는 점을 고려하면, 행렬 곱셈의 효율화는 실용적으로 매우 중요한 과제이다.
2. 슈트라센 알고리즘
1969년 Volker Strassen은 논문 “Gaussian Elimination is Not Optimal“에서 표준 O(n^3) 복잡도를 최초로 개선하는 알고리즘을 제시하였다. 슈트라센 알고리즘(Strassen algorithm)의 핵심 아이디어는 2 \times 2 블록 행렬 곱셈에서 8번의 곱셈을 7번으로 줄이는 것이다.
두 2 \times 2 행렬의 곱을 다음과 같이 분할하자.
\begin{pmatrix} \mathbf{C}_{11} & \mathbf{C}_{12} \\ \mathbf{C}_{21} & \mathbf{C}_{22} \end{pmatrix} = \begin{pmatrix} \mathbf{A}_{11} & \mathbf{A}_{12} \\ \mathbf{A}_{21} & \mathbf{A}_{22} \end{pmatrix} \begin{pmatrix} \mathbf{B}_{11} & \mathbf{B}_{12} \\ \mathbf{B}_{21} & \mathbf{B}_{22} \end{pmatrix}
슈트라센은 7개의 보조 행렬 곱을 정의한다.
\begin{aligned} \mathbf{M}_1 &= (\mathbf{A}_{11} + \mathbf{A}_{22})(\mathbf{B}_{11} + \mathbf{B}_{22}) \\ \mathbf{M}_2 &= (\mathbf{A}_{21} + \mathbf{A}_{22})\mathbf{B}_{11} \\ \mathbf{M}_3 &= \mathbf{A}_{11}(\mathbf{B}_{12} - \mathbf{B}_{22}) \\ \mathbf{M}_4 &= \mathbf{A}_{22}(\mathbf{B}_{21} - \mathbf{B}_{11}) \\ \mathbf{M}_5 &= (\mathbf{A}_{11} + \mathbf{A}_{12})\mathbf{B}_{22} \\ \mathbf{M}_6 &= (\mathbf{A}_{21} - \mathbf{A}_{11})(\mathbf{B}_{11} + \mathbf{B}_{12}) \\ \mathbf{M}_7 &= (\mathbf{A}_{12} - \mathbf{A}_{22})(\mathbf{B}_{21} + \mathbf{B}_{22}) \end{aligned}
이로부터 결과 블록들을 덧셈만으로 구한다.
\begin{aligned} \mathbf{C}_{11} &= \mathbf{M}_1 + \mathbf{M}_4 - \mathbf{M}_5 + \mathbf{M}_7 \\ \mathbf{C}_{12} &= \mathbf{M}_3 + \mathbf{M}_5 \\ \mathbf{C}_{21} &= \mathbf{M}_2 + \mathbf{M}_4 \\ \mathbf{C}_{22} &= \mathbf{M}_1 - \mathbf{M}_2 + \mathbf{M}_3 + \mathbf{M}_6 \end{aligned}
이 과정을 재귀적으로 적용하면 전체 복잡도는 마스터 정리(master theorem)에 의해 다음과 같이 결정된다.
T(n) = 7T\left(\frac{n}{2}\right) + O(n^2) \implies T(n) = O(n^{\log_2 7}) \approx O(n^{2.807})
이는 표준 알고리즘 대비 이론적으로 유의미한 개선이다. 다만 슈트라센 알고리즘은 덧셈 횟수의 증가, 수치적 불안정성, 캐시 효율성 저하 등의 실용적 단점이 있어, 실제로는 행렬 크기가 충분히 클 때(일반적으로 n > 500 이상)에만 표준 알고리즘보다 우위를 보인다.
3. 이론적 하한과 최적 알고리즘 탐색
행렬 곱셈의 시간 복잡도 하한을 O(n^\omega)로 표기할 때, 지수 \omega를 행렬 곱셈 지수(matrix multiplication exponent)라 한다. 이론적으로 \omega \geq 2임은 자명하다. 결과 행렬의 n^2개 원소를 모두 계산해야 하기 때문이다.
슈트라센 이후 지속적인 연구를 통해 \omega의 상한이 점진적으로 낮아져 왔다.
| 연구자 | 연도 | 지수 상한 |
|---|---|---|
| Strassen | 1969 | 2.807 |
| Pan | 1980 | 2.796 |
| Coppersmith-Winograd | 1990 | 2.376 |
| Stothers | 2010 | 2.3737 |
| Williams | 2012 | 2.3729 |
| Alman-Vassilevska Williams | 2021 | 2.3728 |
| Duan-Wu-Zhou | 2023 | 2.371866 |
현재까지 \omega = 2를 달성하는 알고리즘은 발견되지 않았으며, \omega = 2가 실현 가능한지조차 미해결 문제로 남아 있다. 다만 이론적 알고리즘들은 대부분 상수 계수가 매우 커서 실용적인 크기의 행렬에는 적용하기 어렵다는 한계가 있다.
4. 실용적 고속화 기법
실제 고성능 행렬 곱셈 구현에서는 이론적 알고리즘보다 하드웨어 특성을 활용한 최적화가 더 큰 성능 향상을 가져온다.
블록 분할과 캐시 최적화: 행렬을 작은 블록(tile)으로 분할하여 각 블록이 CPU 캐시에 적재되도록 하면, 메모리 접근 지연을 최소화할 수 있다. n \times n 행렬 곱셈에서 블록 크기를 b로 설정하면, 캐시 미스(cache miss) 횟수가 O(n^3 / (b\sqrt{M}))으로 감소한다. 여기서 M은 캐시 크기이다.
BLAS와 LAPACK: 고성능 선형대수 라이브러리인 BLAS(Basic Linear Algebra Subprograms)의 Level-3 루틴 GEMM(General Matrix Multiply)은 행렬 곱셈의 사실상 표준 구현이다. Intel MKL, OpenBLAS, ATLAS 등의 최적화된 BLAS 구현은 특정 하드웨어 아키텍처에 맞추어 레지스터 활용, SIMD 벡터 명령어, 루프 언롤링(loop unrolling) 등을 적용하여 이론적 최대 성능(peak FLOPS)에 근접한 처리량을 달성한다.
GPU 병렬화: 행렬 곱셈은 본질적으로 높은 수준의 데이터 병렬성을 지니므로, GPU 가속에 적합하다. NVIDIA의 cuBLAS 라이브러리는 수천 개의 CUDA 코어를 활용하여 대규모 행렬 곱셈을 병렬 처리한다. Tensor Core를 탑재한 최신 GPU는 혼합 정밀도(mixed precision) 행렬 곱셈을 지원하여, FP16 입력과 FP32 누적을 통해 처리량을 배가시킨다.
5. 딥러닝에서의 행렬 곱셈 최적화
딥러닝 학습과 추론에서 행렬 곱셈은 전체 연산 시간의 대부분을 차지한다. 트랜스포머(Transformer) 기반 모델의 어텐션(attention) 메커니즘에서는 쿼리, 키, 값 행렬 간의 곱셈이 시퀀스 길이의 제곱에 비례하는 복잡도를 가진다. 시퀀스 길이 L, 은닉 차원 d에 대해 어텐션 연산의 복잡도는 O(L^2 d)이다.
이를 개선하기 위한 다양한 근사 기법이 연구되어 왔다. 희소 어텐션(sparse attention)은 어텐션 행렬의 특정 패턴만 계산하여 복잡도를 O(L\sqrt{L} \cdot d)로 줄인다. 저순위 근사(low-rank approximation)는 어텐션 행렬을 낮은 순위의 행렬 곱으로 분해하여 O(Lrd) (r \ll L) 복잡도를 달성한다. FlashAttention은 알고리즘 수준에서 메모리 접근 패턴을 최적화하여 IO 복잡도를 O(L^2 d^2 / M)으로 줄이는 접근법이다. 여기서 M은 SRAM 크기이다.
양자화(quantization)를 통한 저정밀도 행렬 곱셈도 중요한 최적화 방법이다. 32비트 부동소수점 대신 8비트 정수나 4비트 정수로 가중치를 표현하면, 메모리 사용량과 연산 처리량 모두에서 이점을 얻을 수 있다. INT8 행렬 곱셈은 FP32 대비 이론적으로 4배의 처리량 향상이 가능하다.