🚀 Deep Dive into OpenAI Triton: Building FlashAttention & High-Performance CFD from Scratch
Author: Technical Learner (Root User) | Environment: Python 3.12 + NVIDIA GPU (A100/H100 class)
📥 Resources
The full source code for this tutorial is available for download: Download Source Code (zhangdonghao-triton.tar.gz)
Introduction
In the fields of High-Performance Computing (HPC) and Deep Learning, CUDA C++ has long been the dominant force. However, its steep learning curve has deterred many Python developers. OpenAI Triton changes the game. It allows us to write GPU Kernels using Python syntax while achieving performance comparable to or even exceeding cuBLAS.
This article documents my complete learning journey. We will start from basic memory pointers, progress to hand-writing the core LLM operator FlashAttention, and finally cross domains to implement a CFD (Computational Fluid Dynamics) simulator that runs roughly 100x faster than a NumPy implementation.
Project Directory Overview
Here is the workspace structure after completing this practice:
root@g88:/workspace/triton# ls
01_vector_add.py # Basic: Vector Addition
02_benchmark_add.py # Advanced: Performance Benchmark
03_matmul.py # Core: Matrix Multiplication
04_matmul_autotune.py # Optimization: Auto-Tuner
05_softmax.py # Fusion: Softmax Operator
06_flash_attention.py # Challenge: FlashAttention V2
07_cfd_lbm.py # Application: LBM Fluid Simulation
results.html # Performance Report
triton_cfd.mp4 # Fluid Simulation Video
vector-add-performance.csv
Chapter 1: Mindset Shift —— From Tensor to Block
In PyTorch, we are used to highly encapsulated operations like z = x + y. But in Triton, we need to build a Block-Based Programming mental model.
- Grid: The entire computation task.
- Block: The task is sliced into independent blocks, executed in parallel.
- Pointer: We need to manually calculate memory address offsets (the plain-Python sketch after this list shows how the pieces fit together).
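To make the block model concrete before touching Triton syntax, here is a tiny plain-Python sketch (illustrative only) of how a grid of program IDs covers a vector; the names pid, block_start, offsets and mask deliberately mirror the kernel in the next section.
import math

n_elements = 10                                     # total amount of work
BLOCK_SIZE = 4                                      # one block handles 4 elements
num_blocks = math.ceil(n_elements / BLOCK_SIZE)     # the "grid": 3 blocks here

for pid in range(num_blocks):                       # on the GPU these run in parallel
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    mask = [o < n_elements for o in offsets]        # the last block is only partially valid
    print(pid, offsets, mask)
# pid=2 prints offsets [8, 9, 10, 11] with mask [True, True, False, False]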
1.1 Vector Add
This is the "Hello World" of GPU programming. We need to handle pointer offsets and masking to prevent out-of-bounds access.
File: 01_vector_add.py
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr # Compile-time constant, used as block size
):
# 1. Get the current program ID (similar to CUDA blockIdx)
pid = tl.program_id(axis=0)
# 2. Calculate the data range for the current block
# E.g., BLOCK_SIZE=1024, pid=0 -> handles [0, 1024)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. Create Mask
# Crucial step: prevent processing invalid memory beyond n_elements
mask = offsets < n_elements
# 4. Load Data
# Pointer arithmetic: Base address + Offset
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# 5. Compute
output = x + y
# 6. Write back to VRAM (Store)
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
n_elements = output.numel()
# Calculate how many blocks are needed (ceil)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
add_kernel[grid](
x, y, output,
n_elements,
BLOCK_SIZE=1024
)
return output
# --- Test ---
if __name__ == "__main__":
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_triton = add(x, y)
output_torch = x + y
if torch.allclose(output_triton, output_torch):
print("✅ Success! Triton Vector Add matches PyTorch.")
else:
        print("❌ Failed!")
Chapter 2: Benchmark —— Saturating VRAM Bandwidth
Vector addition is a typical memory-bound task: the metric that matters is not compute throughput but sustained memory bandwidth in GB/s.
File: 02_benchmark_add.py
import torch
import triton
import triton.language as tl
# Reuse the add_kernel and add() definitions from 01_vector_add.py
# ... (paste them here; a module name starting with a digit cannot be imported with a plain import statement) ...
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # x-axis: vector length
x_vals=[2**i for i in range(12, 28, 1)],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'PyTorch'],
ylabel='GB/s',
plot_name='vector-add-performance',
)
)
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
if provider == 'triton':
# Assuming add function is available in current scope
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
# Bandwidth formula: 3x data volume (read x + read y + write out) / time
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)
if __name__ == '__main__':
    benchmark.run(print_data=True, show_plots=False)
Real-world Data: On high-end GPUs, we achieved ~1700 GB/s bandwidth, matching the native PyTorch operator.
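The results.html and vector-add-performance.csv files in the project listing are presumably produced by passing a save_path to the same call; with recent Triton versions that looks roughly like this:
if __name__ == '__main__':
    # Also write the plot, the CSV, and an HTML report into the current directory
    benchmark.run(print_data=True, show_plots=False, save_path='.')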
Chapter 3: Matrix Multiplication (MatMul) —— Compute Intensive Task
Matrix multiplication is the cornerstone of deep learning. The challenges here are 2D Pointer Arithmetic and Tiling.
Core Technique: Using Broadcasting to generate 2D pointer grids.
For example, ofs_am[:, None] (column vector) + ofs_k[None, :] (row vector) broadcasts into the 2D block of addresses for a tile of A.
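The same broadcasting trick can be checked on the host with plain PyTorch tensors; this small sketch (illustrative only, hypothetical tile and stride sizes) shows a column of row offsets and a row of column offsets combining into a full 2D grid of element offsets:
import torch

BLOCK_M, BLOCK_N = 4, 3
stride_m, stride_n = 8, 1              # e.g. a row-major matrix with 8 columns

ofs_m = torch.arange(BLOCK_M)          # row indices of the tile
ofs_n = torch.arange(BLOCK_N)          # column indices of the tile

# (BLOCK_M, 1) + (1, BLOCK_N) broadcasts to (BLOCK_M, BLOCK_N)
addrs = ofs_m[:, None] * stride_m + ofs_n[None, :] * stride_n
print(addrs)
# tensor([[ 0,  1,  2],
#         [ 8,  9, 10],
#         [16, 17, 18],
#         [24, 25, 26]])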
File: 04_matmul_autotune.py (With Auto-Tuning)
import torch
import triton
import triton.language as tl
# Autotune: Automatically find the best config for current hardware
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
pid = tl.program_id(axis=0)
# --- Swizzling: Improve L2 Cache Hit Rate ---
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# Generate pointer offsets
ofs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
ofs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
ofs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (ofs_am[:, None] * stride_am + ofs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (ofs_k[:, None] * stride_bk + ofs_bn[None, :] * stride_bn)
# Initialize accumulator as float32 for precision
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Loop over K dimension
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out the K-dimension tail so the last iteration stays in bounds when K is not a multiple of BLOCK_SIZE_K
a = tl.load(a_ptrs, mask=ofs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=ofs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# Core computation: Tensor Cores MatMul
accumulator += tl.dot(a, b)
# Advance pointers
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
# Write back results
ofs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
ofs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * ofs_cm[:, None] + stride_cn * ofs_cn[None, :]
c_mask = (ofs_cm[:, None] < M) & (ofs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# (Host code omitted)
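The omitted host code is the usual launcher boilerplate; a minimal sketch matching the kernel signature above might look like the following (the BLOCK_SIZE_* and GROUP_SIZE_M values are supplied by the autotuner, so only the grid refers to them):
def matmul(a: torch.Tensor, b: torch.Tensor):
    # a: (M, K), b: (K, N), both float16 tensors on the GPU
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch grid; the autotuner fills in BLOCK_SIZE_M / BLOCK_SIZE_N via `meta`
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

# Quick correctness check (hypothetical sizes):
# a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
# b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
# torch.testing.assert_close(matmul(a, b), a @ b, atol=1e-2, rtol=1e-2)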
Chapter 4: FlashAttention —— The Pinnacle of Operator Fusion
This is the core technology accelerating LLM inference today. Standard Attention materializes a huge $N \times N$ score matrix, causing massive VRAM overhead. FlashAttention uses Tiling and Online Softmax so that this intermediate matrix never has to be written to or read back from VRAM.
Key Points:
- Online Softmax: Dynamically update max values and sums within the loop.
- Re-scaling: When a new max value is found, previously accumulated results must be rescaled (the update rule is written out after this list).
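Concretely, for each query row $i$ and each new K/V block (with $S_{ij} = \text{sm\_scale} \cdot q_i \cdot k_j$ over the keys $j$ of that block), the kernel below performs the update

$$m_i^{new} = \max\left(m_i, \max_j S_{ij}\right), \qquad \alpha_i = e^{m_i - m_i^{new}}$$
$$\ell_i^{new} = \alpha_i \ell_i + \sum_j e^{S_{ij} - m_i^{new}}, \qquad o_i^{new} = \alpha_i o_i + \sum_j e^{S_{ij} - m_i^{new}} v_j$$

and normalizes once at the end: $o_i \leftarrow o_i / \ell_i$.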
File: 06_flash_attention.py (Final version with fixed Scaling)
import torch
import triton
import triton.language as tl
import math
@triton.jit
def flash_attn_kernel(
Q, K, V, Out,
stride_qm, stride_qk,
stride_kn, stride_kk,
stride_vn, stride_vk,
stride_om, stride_on,
Z, H, N_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
sm_scale: tl.constexpr
):
# Grid: Each Program handles one Block of Q
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
    # Locate Batch and Head (assumes each (batch, head) slice is contiguous, so its start is off_hz * N_CTX * seq-stride)
q_offset = off_hz * stride_qm * N_CTX
k_offset = off_hz * stride_kn * N_CTX
v_offset = off_hz * stride_vn * N_CTX
o_offset = off_hz * stride_om * N_CTX
# Generate Pointers
qs = Q + q_offset + start_m * BLOCK_M * stride_qm
rng_m = tl.arange(0, BLOCK_M)
rng_d = tl.arange(0, HEAD_DIM)
qs_ptrs = qs + (rng_m[:, None] * stride_qm + rng_d[None, :] * stride_qk)
# Initialize stats: m (max), l (sum), acc (output accumulator)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # Load Q once and reuse it for every K/V block (no mask here, so N_CTX must be a multiple of BLOCK_M)
q = tl.load(qs_ptrs)
# Loop over K and V (K dimension)
for start_n in range(0, tl.cdiv(N_CTX, BLOCK_N)):
# Load K
ks = K + k_offset + start_n * BLOCK_N * stride_kn
ks_ptrs = ks + (rng_d[:, None] * stride_kk + tl.arange(0, BLOCK_N)[None, :] * stride_kn)
k = tl.load(ks_ptrs)
# Compute QK^T
qk = tl.dot(q, k)
qk *= sm_scale # Attention Scaling (1/sqrt(d))
# --- Online Softmax Update ---
m_ij = tl.max(qk, 1) # Max of current block
m_new = tl.maximum(m_i, m_ij) # Global max update
p = tl.exp(qk - m_new[:, None])
# Correction factor: alpha = exp(old_max - new_max)
alpha = tl.exp(m_i - m_new)
l_new = alpha * l_i + tl.sum(p, 1)
# Rescale previous Accumulator
acc = acc * alpha[:, None]
# Load V and Accumulate
vs = V + v_offset + start_n * BLOCK_N * stride_vn
vs_ptrs = vs + (tl.arange(0, BLOCK_N)[:, None] * stride_vn + rng_d[None, :] * stride_vk)
v = tl.load(vs_ptrs)
acc += tl.dot(p.to(tl.float16), v)
# Update state
l_i = l_new
m_i = m_new
# Final Normalization
acc = acc / l_i[:, None]
# Store Output
os = Out + o_offset + start_m * BLOCK_M * stride_om
os_ptrs = os + (rng_m[:, None] * stride_om + rng_d[None, :] * stride_on)
tl.store(os_ptrs, acc.to(tl.float16))
# Host Wrapper & Benchmark omitted...
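The wrapper has to match the pointer arithmetic above, which folds the batch/head offset into the sequence stride; a minimal sketch under those assumptions (contiguous float16 tensors of shape (Z, H, N_CTX, HEAD_DIM), with N_CTX divisible by BLOCK_M and BLOCK_N) could look like this:
def flash_attn(q, k, v, BLOCK_M=64, BLOCK_N=64):
    # q, k, v: contiguous float16 tensors of shape (Z, H, N_CTX, HEAD_DIM)
    Z, H, N_CTX, HEAD_DIM = q.shape
    o = torch.empty_like(q)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)
    # axis 0: one program per BLOCK_M query rows; axis 1: one program per (batch, head)
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
    flash_attn_kernel[grid](
        q, k, v, o,
        q.stride(2), q.stride(3),
        k.stride(2), k.stride(3),
        v.stride(2), v.stride(3),
        o.stride(2), o.stride(3),
        Z, H, N_CTX,
        HEAD_DIM=HEAD_DIM,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        sm_scale=sm_scale,
    )
    return o

# Sanity check against PyTorch's reference attention (hypothetical sizes):
# q = k = v = torch.randn(1, 4, 1024, 64, device='cuda', dtype=torch.float16)
# ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# torch.testing.assert_close(flash_attn(q, k, v), ref, atol=1e-2, rtol=1e-2)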
Measured performance: ~130 TFLOPS.
Chapter 5: Cross-Domain Application —— High-Performance Fluid Simulation (CFD)
Triton isn't just for AI. We implemented an LBM (Lattice Boltzmann Method) fluid simulator. Compared to pure Python loops, the Triton version leverages GPU parallelism, mapping each lattice cell to a GPU thread and sustaining roughly 1765 simulation steps per second.
Technical Highlights:
- 1D Flattening: Flattening the 2D grid into a 1D array so each program handles a contiguous block of cells, which keeps the indexing simple and the memory accesses coalesced.
- Stride Management: Passing explicit strides so the kernel can address both the 3D distribution field and the 2D obstacle mask.
File: 07_cfd_lbm.py (Optimized Stable Version)
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# --- LBM Kernel (1D Version) ---
@triton.jit
def lbm_kernel(
f_in_ptr, f_out_ptr, obstacle_ptr,
nx, ny,
f_stride_x, f_stride_y, f_stride_k,
o_stride_x, o_stride_y,
OMEGA: tl.constexpr,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < nx * ny
# Coordinate Mapping
x = offsets % nx
y = offsets // nx
# Read Obstacle
obs_ptr = obstacle_ptr + x * o_stride_x + y * o_stride_y
is_solid = tl.load(obs_ptr, mask=mask, other=0).to(tl.int1)
# ... (Macroscopic variable calc & Streaming logic) ...
# Key is using tl.load to pull data directly from offset addresses (Pull Scheme)
# Collision & Bounce-back
# Use tl.where to handle conditional branching, avoiding warp divergence
    # out = tl.where(is_solid, f_bounced, f_collision)
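Since the streaming and collision bodies are elided above, the following stand-alone sketch (hypothetical names, a single lattice direction, contiguous (ny, nx) layout) illustrates the two ideas the comments refer to: pulling the upstream neighbour's value with a shifted tl.load, and using tl.where instead of a branch on solid cells.
import torch
import triton
import triton.language as tl

@triton.jit
def pull_stream_kernel(
    f_in_ptr, f_out_ptr, obstacle_ptr,
    nx, ny,
    cx: tl.constexpr, cy: tl.constexpr,   # lattice velocity of this single direction
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < nx * ny
    # Same 1D flattening as above: x is the fastest-varying coordinate
    x = offsets % nx
    y = offsets // nx
    # Pull scheme: the value arriving at (x, y) is read from the upstream cell (x-cx, y-cy),
    # with periodic wrap-around at the domain boundary
    src_x = (x - cx + nx) % nx
    src_y = (y - cy + ny) % ny
    pulled = tl.load(f_in_ptr + src_y * nx + src_x, mask=mask, other=0.0)
    # Branch-free bounce-back: on solid cells keep the local value instead of the pulled one
    # (a stand-in for reflecting the opposite-direction population in a full LBM step)
    is_solid = tl.load(obstacle_ptr + y * nx + x, mask=mask, other=0) != 0
    local = tl.load(f_in_ptr + y * nx + x, mask=mask, other=0.0)
    tl.store(f_out_ptr + y * nx + x, tl.where(is_solid, local, pulled), mask=mask)

# Hypothetical usage for one direction, e.g. (cx, cy) = (1, 0):
nx, ny = 256, 128
f_in = torch.rand(ny, nx, device='cuda')
f_out = torch.empty_like(f_in)
obstacle = torch.zeros(ny, nx, device='cuda', dtype=torch.int32)
grid = (triton.cdiv(nx * ny, 1024),)
pull_stream_kernel[grid](f_in, f_out, obstacle, nx, ny, cx=1, cy=0, BLOCK_SIZE=1024)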
Summary
Through this series of exercises, we have shown that OpenAI Triton is a practical bridge between the Python ecosystem and high-performance GPU computing. It can optimize Transformer operators and also handle complex scientific computing workloads.
Key Takeaways:
- Performance: Approach peak memory bandwidth with just a few lines of Python.
- Flexibility: Easily implement complex operator fusions like FlashAttention.
- Universality: CFD simulation proves its potential beyond AI.
If you want to master GPU programming from scratch, don't hesitate: Triton is an excellent starting point.
