
🚀 Deep Dive into OpenAI Triton: Building FlashAttention & High-Performance CFD from Scratch

Author: Technical Learner (Root User) | Environment: Python 3.12 + NVIDIA GPU (A100/H100 class)

📥 Resources

The full source code for this tutorial is available for download: Download Source Code (zhangdonghao-triton.tar.gz)

Introduction

In the fields of High-Performance Computing (HPC) and Deep Learning, CUDA C++ has long been the dominant force. However, its steep learning curve has deterred many Python developers. OpenAI Triton changes the game. It allows us to write GPU Kernels using Python syntax while achieving performance comparable to or even exceeding cuBLAS.

This article documents my complete learning journey. We will start from basic memory pointers, progress to hand-writing the core LLM operator FlashAttention, and finally cross over into scientific computing to implement a CFD (Computational Fluid Dynamics) simulator that runs roughly 100x faster than a NumPy baseline.

Project Directory Overview

Here is the workspace structure after completing this practice:

bash
root@g88:/workspace/triton# ls
01_vector_add.py          # Basic: Vector Addition
02_benchmark_add.py       # Advanced: Performance Benchmark
03_matmul.py              # Core: Matrix Multiplication
04_matmul_autotune.py     # Optimization: Auto-Tuner
05_softmax.py             # Fusion: Softmax Operator
06_flash_attention.py     # Challenge: FlashAttention V2
07_cfd_lbm.py             # Application: LBM Fluid Simulation
results.html              # Performance Report
triton_cfd.mp4            # Fluid Simulation Video
vector-add-performance.csv

Chapter 1: Mindset Shift —— From Tensor to Block

In PyTorch, we are used to highly encapsulated operations like z = x + y. In Triton, we instead need a Block-Based Programming mental model:

  1. Grid: the full set of parallel program instances launched for a task.
  2. Block: each program instance processes one independent tile of the data, and all tiles execute in parallel.
  3. Pointer: memory addresses and offsets are computed by hand.

1.1 Vector Add

This is the "Hello World" of GPU programming. We need to handle pointer offsets and masking to prevent out-of-bounds access.

File: 01_vector_add.py

python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr, 
    n_elements, 
    BLOCK_SIZE: tl.constexpr # Compile-time constant, used as block size
):
    # 1. Get the current program ID (similar to CUDA blockIdx)
    pid = tl.program_id(axis=0)
    
    # 2. Calculate the data range for the current block
    # E.g., BLOCK_SIZE=1024, pid=0 -> handles [0, 1024)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    
    # 3. Create Mask
    # Crucial step: prevent processing invalid memory beyond n_elements
    mask = offsets < n_elements

    # 4. Load Data
    # Pointer arithmetic: Base address + Offset
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 5. Compute
    output = x + y

    # 6. Write back to VRAM (Store)
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    
    # Calculate how many blocks are needed (ceil)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    
    add_kernel[grid](
        x, y, output, 
        n_elements, 
        BLOCK_SIZE=1024
    )
    return output

# --- Test ---
if __name__ == "__main__":
    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    output_triton = add(x, y)
    output_torch = x + y
    if torch.allclose(output_triton, output_torch):
        print("✅ Success! Triton Vector Add matches PyTorch.")
    else:
        print("❌ Failed!")

Chapter 2: Benchmark —— Saturating VRAM Bandwidth

Vector addition is a typical memory-bound task: the metric that matters is not arithmetic throughput, but effective memory bandwidth in GB/s.

File: 02_benchmark_add.py

python
import torch
import triton
import triton.language as tl
# Import the add function defined above

# ... (Insert add_kernel and add code here) ...

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],          # x-axis: vector length
        x_vals=[2**i for i in range(12, 28, 1)], 
        x_log=True,
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'PyTorch'],
        ylabel='GB/s',
        plot_name='vector-add-performance',
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        # Assuming add function is available in current scope
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    
    # Bandwidth formula: 3x data volume (read x + read y + write out) / time
    gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)

if __name__ == '__main__':
    benchmark.run(print_data=True, show_plots=False)

Real-world Data: On high-end GPUs, we achieved ~1700 GB/s bandwidth, matching the native PyTorch operator.
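
As a quick sanity check on that figure: with the benchmark's bandwidth formula, a vector of $2^{27}$ float32 elements moves $3 \times 2^{27} \times 4 \text{ B} \approx 1.61$ GB per call, so at ~1700 GB/s a single addition takes on the order of 0.95 ms, which is the behavior expected of a purely bandwidth-limited kernel.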


Chapter 3: Matrix Multiplication (MatMul) —— Compute Intensive Task

Matrix multiplication is the cornerstone of deep learning. The challenges here are 2D Pointer Arithmetic and Tiling.

Core Technique: Using Broadcasting to generate 2D pointer grids.

  • ofs_am[:, None] (column vector) broadcast against ofs_bn[None, :] (row vector) yields a full 2D block of addresses, as the small shape demo below illustrates.
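
The following minimal snippet (plain PyTorch on the host, with made-up strides) shows the shape algebra behind that trick; the values are purely illustrative:

python
import torch

BLOCK_M, BLOCK_N = 4, 3
offs_m = torch.arange(BLOCK_M)            # (4,) row offsets
offs_n = torch.arange(BLOCK_N)            # (3,) column offsets
stride_m, stride_n = 16, 1                # hypothetical row-major strides

# (4, 1) + (1, 3) broadcasts to a (4, 3) grid of element offsets
addrs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
print(addrs)
# tensor([[ 0,  1,  2],
#         [16, 17, 18],
#         [32, 33, 34],
#         [48, 49, 50]])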

File: 04_matmul_autotune.py (With Auto-Tuning)

python
import torch
import triton
import triton.language as tl

# Autotune: Automatically find the best config for current hardware
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64,  'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid = tl.program_id(axis=0)
    
    # --- Swizzling: Improve L2 Cache Hit Rate ---
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Generate pointer offsets
    ofs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    ofs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    ofs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (ofs_am[:, None] * stride_am + ofs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (ofs_k[:, None] * stride_bk + ofs_bn[None, :] * stride_bn)

    # Initialize accumulator as float32 for precision
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Loop over K dimension
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the tail of the K dimension: loads past K return 0.0,
        # so a partial K block in the last iteration is handled correctly
        a = tl.load(a_ptrs, mask=ofs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=ofs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        
        # Core computation: Tensor Cores MatMul
        accumulator += tl.dot(a, b)
        
        # Advance pointers
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
        
    c = accumulator.to(tl.float16)
    
    # Write back results
    ofs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    ofs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * ofs_cm[:, None] + stride_cn * ofs_cn[None, :]
    c_mask = (ofs_cm[:, None] < M) & (ofs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

# (Host code omitted)
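
The host code is omitted in the listing above. For completeness, here is a minimal wrapper that is consistent with the kernel signature; it is a sketch under the usual assumptions (float16, row-major contiguous inputs of shape M×K and K×N), and the function name, test sizes, and tolerances are my own choices rather than the original omitted code:

python
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0] and a.is_cuda and b.is_cuda
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch grid; the autotuner supplies BLOCK_SIZE_M / BLOCK_SIZE_N
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

if __name__ == "__main__":
    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    # fp16 rounding makes a loose tolerance appropriate for the comparison
    print(torch.allclose(matmul(a, b), a @ b, atol=1e-1, rtol=1e-2))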

Chapter 4: FlashAttention —— The Pinnacle of Operator Fusion

This is the core technology accelerating LLM inference today. Standard attention materializes a huge $N \times N$ score matrix, which causes massive VRAM traffic. FlashAttention uses Tiling and Online Softmax so that this large intermediate matrix never needs to be written to or read back from VRAM.

Key Points:

  • Online Softmax: the running maximum and the running sum of exponentials are updated incrementally inside the loop over key blocks.
  • Re-scaling: whenever a larger maximum appears, the previously accumulated output must be rescaled by exp(old_max - new_max), as the sketch below shows.
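
Before reading the kernel, it may help to see the same update rule in ordinary PyTorch. This is a minimal reference sketch; the function name, block size, and test shapes are my own choices and are not part of the original files:

python
import torch

def online_softmax_attention(q, k, v, block_n=64):
    # q: (M, d), k/v: (N, d). Blockwise attention using the same
    # running-max / running-sum update as the Triton kernel below.
    M, d = q.shape
    sm_scale = d ** -0.5
    m = torch.full((M,), float('-inf'), device=q.device, dtype=q.dtype)
    l = torch.zeros(M, device=q.device, dtype=q.dtype)
    acc = torch.zeros(M, d, device=q.device, dtype=q.dtype)
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        s = (q @ kb.T) * sm_scale                    # scores for this key block
        m_new = torch.maximum(m, s.max(dim=1).values)
        p = torch.exp(s - m_new[:, None])            # probabilities w.r.t. the new max
        alpha = torch.exp(m - m_new)                 # correction factor for old stats
        l = alpha * l + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ vb          # rescale old output, add new block
        m = m_new
    return acc / l[:, None]

# Sanity check against the naive formulation
q, k, v = (torch.randn(128, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print(torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-5))

The Triton kernel below performs exactly this update, casting p to float16 before the tl.dot against V so the accumulation can run on Tensor Cores.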

File: 06_flash_attention.py (Final version with fixed Scaling)

python
import torch
import triton
import triton.language as tl
import math

@triton.jit
def flash_attn_kernel(
    Q, K, V, Out,
    stride_qm, stride_qk, 
    stride_kn, stride_kk,
    stride_vn, stride_vk,
    stride_om, stride_on,
    Z, H, N_CTX, 
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    sm_scale: tl.constexpr 
):
    # Grid: Each Program handles one Block of Q
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    
    # Locate Batch and Head
    q_offset = off_hz * stride_qm * N_CTX
    k_offset = off_hz * stride_kn * N_CTX
    v_offset = off_hz * stride_vn * N_CTX
    o_offset = off_hz * stride_om * N_CTX

    # Generate Pointers
    qs = Q + q_offset + start_m * BLOCK_M * stride_qm
    rng_m = tl.arange(0, BLOCK_M)
    rng_d = tl.arange(0, HEAD_DIM)
    qs_ptrs = qs + (rng_m[:, None] * stride_qm + rng_d[None, :] * stride_qk)

    # Initialize stats: m (max), l (sum), acc (output accumulator)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Load Q (Reused)
    q = tl.load(qs_ptrs)

    # Loop over K and V (K dimension)
    for start_n in range(0, tl.cdiv(N_CTX, BLOCK_N)):
        # Load K
        ks = K + k_offset + start_n * BLOCK_N * stride_kn
        ks_ptrs = ks + (rng_d[:, None] * stride_kk + tl.arange(0, BLOCK_N)[None, :] * stride_kn)
        k = tl.load(ks_ptrs)

        # Compute QK^T
        qk = tl.dot(q, k)
        qk *= sm_scale # Attention Scaling (1/sqrt(d))

        # --- Online Softmax Update ---
        m_ij = tl.max(qk, 1)          # Max of current block
        m_new = tl.maximum(m_i, m_ij) # Global max update
        
        p = tl.exp(qk - m_new[:, None])
        
        # Correction factor: alpha = exp(old_max - new_max)
        alpha = tl.exp(m_i - m_new) 
        l_new = alpha * l_i + tl.sum(p, 1)
        
        # Rescale previous Accumulator
        acc = acc * alpha[:, None] 
        
        # Load V and Accumulate
        vs = V + v_offset + start_n * BLOCK_N * stride_vn
        vs_ptrs = vs + (tl.arange(0, BLOCK_N)[:, None] * stride_vn + rng_d[None, :] * stride_vk)
        v = tl.load(vs_ptrs)
        
        acc += tl.dot(p.to(tl.float16), v)

        # Update state
        l_i = l_new
        m_i = m_new

    # Final Normalization
    acc = acc / l_i[:, None]

    # Store Output
    os = Out + o_offset + start_m * BLOCK_M * stride_om
    os_ptrs = os + (rng_m[:, None] * stride_om + rng_d[None, :] * stride_on)
    tl.store(os_ptrs, acc.to(tl.float16))

# Host Wrapper & Benchmark omitted... 
# Measured Performance: ~130 TFLOPS
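
The host wrapper and benchmark are omitted above. For reference, here is a minimal launch sketch that is consistent with the kernel's pointer arithmetic (it reuses the imports at the top of the file). It assumes contiguous (Z, H, N_CTX, HEAD_DIM) float16 tensors with N_CTX divisible by BLOCK_M and BLOCK_N; the function name flash_attn, the block sizes, and the test shapes are my assumptions, not the original file's code:

python
def flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (Z, H, N_CTX, HEAD_DIM), float16, contiguous
    Z, H, N_CTX, HEAD_DIM = q.shape
    o = torch.empty_like(q)
    BLOCK_M, BLOCK_N = 64, 64            # N_CTX must be a multiple of both
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)
    # One program per (Q block, batch*head) pair
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
    flash_attn_kernel[grid](
        q, k, v, o,
        q.stride(2), q.stride(3),
        k.stride(2), k.stride(3),
        v.stride(2), v.stride(3),
        o.stride(2), o.stride(3),
        Z, H, N_CTX,
        HEAD_DIM=HEAD_DIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        sm_scale=sm_scale,
    )
    return o

if __name__ == "__main__":
    q, k, v = (torch.randn(2, 4, 1024, 64, device='cuda', dtype=torch.float16)
               for _ in range(3))
    out = flash_attn(q, k, v)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # loose tolerance: the kernel casts p to fp16 before the second dot
    print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))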

Chapter 5: Cross-Domain Application —— High-Performance Fluid Simulation (CFD)

Triton isn't just for AI. We implemented an LBM (Lattice Boltzmann Method) fluid simulator. Compared with pure Python loops, the Triton version exploits GPU parallelism by mapping each lattice node's update to a GPU thread, reaching an update rate of roughly 1765 steps/sec.

Technical Highlights:

  • 1D Flattening: the 2D grid is flattened into a 1D array, which simplifies indexing and keeps memory accesses efficient.
  • Stride Management: the 3D distribution array (flow field) and the 2D obstacle mask have different strides, which must be handled explicitly.

File: 07_cfd_lbm.py (Optimized Stable Version)

python
import torch
import triton
import triton.language as tl
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# --- LBM Kernel (1D Version) ---
@triton.jit
def lbm_kernel(
    f_in_ptr, f_out_ptr, obstacle_ptr,
    nx, ny,
    f_stride_x, f_stride_y, f_stride_k,
    o_stride_x, o_stride_y,
    OMEGA: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < nx * ny
    
    # Coordinate Mapping
    x = offsets % nx
    y = offsets // nx

    # Read Obstacle
    obs_ptr = obstacle_ptr + x * o_stride_x + y * o_stride_y
    is_solid = tl.load(obs_ptr, mask=mask, other=0).to(tl.int1)

    # ... (Macroscopic variable calc & Streaming logic) ...
    # Key is using tl.load to pull data directly from offset addresses (Pull Scheme)

    # Collision & Bounce-back
    # Use tl.where to handle conditional branching, avoiding warp divergence
    # out = tl.where(is_solid, f_bounced, f_collision)
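
To make the pull scheme concrete, here is a small self-contained kernel that performs only the streaming step, for one lattice direction per launch, on the flattened grid. The kernel name, the one-direction-per-launch split, and the periodic boundaries are my own simplifications for illustration; the real 07_cfd_lbm.py fuses streaming with collision in a single kernel:

python
import torch
import triton
import triton.language as tl

@triton.jit
def stream_pull_kernel(
    f_in_ptr, f_out_ptr,
    nx, ny,
    stride_x, stride_y, stride_k,
    k,                                     # lattice direction index
    CX: tl.constexpr, CY: tl.constexpr,    # lattice velocity of direction k
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < nx * ny

    # 1D flattening: recover (x, y) from the flat index
    x = offsets % nx
    y = offsets // nx

    # Pull scheme: each node reads from its upstream neighbor (periodic wrap)
    src_x = (x - CX + nx) % nx
    src_y = (y - CY + ny) % ny
    src = f_in_ptr + src_x * stride_x + src_y * stride_y + k * stride_k
    dst = f_out_ptr + x * stride_x + y * stride_y + k * stride_k
    tl.store(dst, tl.load(src, mask=mask, other=0.0), mask=mask)

if __name__ == "__main__":
    nx, ny = 256, 128
    # D2Q9 lattice velocities
    velocities = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1),
                  (1, 1), (-1, 1), (-1, -1), (1, -1)]
    f_in = torch.rand(nx, ny, 9, device='cuda')
    f_out = torch.empty_like(f_in)
    grid = lambda meta: (triton.cdiv(nx * ny, meta['BLOCK_SIZE']),)
    for k, (cx, cy) in enumerate(velocities):
        stream_pull_kernel[grid](
            f_in, f_out, nx, ny,
            f_in.stride(0), f_in.stride(1), f_in.stride(2),
            k, CX=cx, CY=cy, BLOCK_SIZE=1024,
        )
    # torch.roll performs the same periodic shift per direction
    ref = torch.stack([torch.roll(f_in[:, :, i], shifts=(cx, cy), dims=(0, 1))
                       for i, (cx, cy) in enumerate(velocities)], dim=2)
    print("✅ streaming matches torch.roll:", torch.allclose(f_out, ref))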

Summary

Through this series of exercises, we have shown that OpenAI Triton is a practical bridge between the Python ecosystem and high-performance GPU computing. It can not only accelerate Transformer operators but also handle complex scientific-computing workloads.

Key Takeaways:

  1. Performance: a few dozen lines of Python can saturate the hardware's memory bandwidth.
  2. Flexibility: complex operator fusions like FlashAttention can be implemented directly.
  3. Universality: the CFD simulation shows its potential beyond AI.

If you want to master GPU programming from scratch, don't hesitate: Triton is an excellent place to start.

AI-HPC Organization