본 글은 Triton Official Doc (https://triton-lang.org/main/getting-started/tutorials/index.html) 과
CUDA MODE github (https://github.com/cuda-mode)를 참조하였습니다.
예제 colab
0. 소개
Deep Learning을 공부하면서 CUDA를 공부해보면 좋을 것 같다는 생각을 하고 있던 와중 Triton이라는 library를 발견했다.
CUDA를 직접적으로 공부하기엔 너무 low-level이고 C언어 베이스기 때문에 선뜻 손이 안가던 와중에 CUDA보다는 간결하게 GPU 커널을 작성할 수 있고, Pytorch와 연동성이 높기 때문에 배울거면 Triton을 배우는게 낫겠다 싶어 공부를 시작했다.
구글링을 해보니 한국어로 Triton를 소개하는 블로그가 적은거 같아 공부를 하면서 기록을 남겨두려고 한다.
Youtube와 CUDA MODE라는 github, 그리고 Official Doc을 참고하면서 공부를 진행했다.
1. Triton 이란?
Triton은 GPU 프로그래밍을 더 편리하게 작성할 수 있는 패키지이다. Python을 베이스로 한 코드를 작성하면 그것이 ptx코드(CUDA 코드가 컴파일 되는 형식)로 컴파일이 된다.
쉽게 말하자면, CUDA는 수동 자동차고 Triton은 오토 자동차라고 생각하면 쉽다.
CUDA는 모든 것을 완전히 제어할 수 있고, 성능을 최대치로 끌어 올릴 수는 있지만, 복잡하고 디버깅이 어렵다.
Triton은 모든 것을 제어할 수는 없지만, 쉽게 괜찮은 성능을 뽑아낼 수 있고, 코드작성이 쉽고 디버깅이 쉽다.
먼저, 한가지 알아두어야 할게 Triton이 디버깅이 쉽다고 했는데, CUDA와는 다르게 Triton은 CPU와 GPU동시에 실행이 가능하다.
os.environ["TRITON_INTERPRET"] = '1'로 설정을 해두면, print와 같은 function을 활용해서 어떤식으로 코드가 진행되는지 확인이 가능하다.
그렇기 때문에, TRITON_INTERPRET을 켜두고 먼저 코드를 작성한뒤에 실행시키는걸 추천한다. (이따 보면 알겠지만, 코드가 살짝 복잡하다. 추가적으로 triton을 import 하기전에 설정을 해두어야 한다고 한다.)
import os
os.environ['TRITON_INTERPRET'] = '1'
import torch
import triton
import triton.language as tl
2. Triton의 기본 개념
Triton은 기본적으로 block-wise로 실행이 된다. 예를 들자면, x, y라는 벡터가 있고 사이즈가 8이라고 가정하고, block size가 4라면,
우리는 2개의 block을 가지고 프로그램이 실행이 된다.
pid: 0 → z[0:4] = x[0:4] + y[0:4]
pid: 1 → z[4:8] = x[4:8] + y[4:8]
보통 block_size를 결정할때 cdiv 함수를 많이 이용한다. ceiling division 함수인데, 만약 cdiv(6, 4) = 2가 되는 함수이다. (1.5 지만 반올림 하여 2)
만약에 벡터 사이즈가 6이고 블록사이즈가 4라면 2개의 블록으로 실행이 되는데 남는 메모리가 발생한다.
ex)
x = [1,2,3,4,5,6]
y = [1,2,3,4,5,6]
pid: 0 → z = [1, 2, 3, 4] + [1, 2, 3, 4]
pid: 1 → z = [5, 6, None, None] + [5, 6, None, None]
라면, pid 1에서는 Out-of-bounds 에러가 난다. 이것을 방지하기 위해서 masking을 하여 에러를 방지해야 한다.
그리고, Triton을 실행시키려면, @triton.jit으로 감싸진 Kernel 코드와 그것을 실행시킬 foward pass function을 작성해야한다.
코랩에서 본다면,
@triton.jit
def add_kernel(x_ptr, # 첫번째 input vector의 pointer
y_ptr, # 두번째 input vector의 pointer
output_ptr, # output vector의 pointer
n_elements, # 벡터들의 사이즈
BLOCK_SIZE: tl.constexpr, # 블록사이즈
):
이 부분이 kernel code 부분이고,
def add(x: torch.Tensor, y : torch.Tensor) -> torch.Tensor:
다음이 Forward pass function이다.
3. constexpr
C++에 있는 개념인데 컴파일시에 먼저 결정되는 상수라고 한다. 예를 들어,
int factorial(int n) {
return (n <= 1) ? 1 : n * factorial(n - 1);
}
int main() {
std::array<int, factorial(5)> arr; // 120 크기의 배열 생성 -> 에러!
return 0;
}
라는 코드가 있다고 가정하자. arr에 우리는 120이 들어가야 하는걸 알지만 이걸 그대로 실행하면 에러가 난다. 배열크기는 컴파일 시간에 알려져야 하는데 factorial(5)를 연산하고 나서야 배정이 되기 때문이다. 이걸 해결하기 위해선 복잡하게 template 써서 뭐 해야 한다는데 constexpr를 사용하면 간단하게 해결이 된다.
constexpr int factorial(int n) {
return (n <= 1) ? 1 : n * factorial(n - 1);
}
int main() {
std::array<int, factorial(5)> arr; // 120 크기의 배열 생성
// factorial(5)가 컴파일 시간에 계산되어 120이 됨
return 0;
}
Triton도 마찬가지로 block_size를 constexpr 형태로 미리 지정을 해줘야 에러가 나지 않는다. 그리고 컴파일 시간에 block_size가 결정이 되기 때문에 밑에 코드가 자연스럽게 돌아간다.
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
더 자세한 코드 설명은 공유한 colab을 통해서 보면 된다.