From 30ms to 2ms: My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)

A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means

Triton Matrix Multiplication — Element-wise Parallel Approach

Problem Statement

Given two matrices:

  • W of shape (2, 4)
  • H of shape (3, 4)

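Goal: Compute O = W @ H.T → output shape (2, 3)

Since W and H share the same column dimension (K=4), we transpose H so the multiplication works: W (2×4) × H^T (4×3) = O (2×3). Each O[i, j] is the dot product of row i of W with row j of H, so O is simply the transpose of H @ W.T.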


Core Insight — How I Thought About Parallelism

The key observation from the diagram:

For each value of k, the column of W and column of H form an outer product:

k=0:
  W_col = [w00, w10]       H_col = [h00, h10, h20]

  Outer product → 6 values computed in parallel:

       h00    h10    h20
  w00  [0,0]  [0,1]  [0,2]   → 3 threads
  w10  [1,0]  [1,1]  [1,2]   → 3 threads

k=1: W_col = [w01, w11],  H_col = [h01, h11, h21]  → same 6 threads
k=2: W_col = [w02, w12],  H_col = [h02, h12, h22]  → same 6 threads
k=3: W_col = [w03, w13],  H_col = [h03, h13, h23]  → same 6 threads

K loop is sequential. Within each K, all 6 threads run in parallel.

K=0 → [6 threads parallel] → done
K=1 → [6 threads parallel] → done
K=2 → [6 threads parallel] → done
K=3 → [6 threads parallel] → done

Each thread only knows its own (i, j) position and independently accumulates its result across K iterations — no thread waits for another.
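
To see why adding up one outer product per k gives the same answer as taking row-by-row dot products, here is a minimal plain-Python sketch. The values in W and H and the helper names outer_product_sum and row_dot_products are made up for illustration; the loops over i and j stand in for the six parallel threads.

```python
# Illustrative values only; W is (2, 4) and H is (3, 4) as in the problem statement.
W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
H = [[1, 0, 1, 0],
     [0, 1, 0, 1],
     [2, 2, 2, 2]]

def outer_product_sum(W, H):
    """Accumulate one outer product per k: O[i][j] += W[i][k] * H[j][k]."""
    M, N, K = len(W), len(H), len(W[0])
    O = [[0] * N for _ in range(M)]
    for k in range(K):              # sequential K loop
        for i in range(M):          # on the GPU, every (i, j) pair
            for j in range(N):      # is handled by its own thread
                O[i][j] += W[i][k] * H[j][k]
    return O

def row_dot_products(W, H):
    """Same result, computed as dot(W row i, H row j) for each output cell."""
    return [[sum(w * h for w, h in zip(w_row, h_row)) for h_row in H]
            for w_row in W]

print(outer_product_sum(W, H))   # [[4, 6, 20], [12, 14, 52]]
print(row_dot_products(W, H))    # [[4, 6, 20], [12, 14, 52]]
```

Both views touch exactly the same products; they only differ in the order in which the loops are nested.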


Mental Model

Concept              Meaning
program_id(0) = i    Which row of W (0 or 1)
program_id(1) = j    Which row of H (0, 1, or 2)
grid = (M, N)        Total program instances launched = M×N = 6 (the "threads" in this post)
stride               How many memory positions to jump to reach the next element
acc                  Each thread’s private accumulator across the K loop

Why Stride?

Matrices are stored as 1D in memory:

W = | w00  w01  w02  w03 |   →  [w00, w01, w02, w03, w10, w11, w12, w13]
    | w10  w11  w12  w13 |        0    1    2    3    4    5    6    7

stride_row = 4   (jump 4 to go to next row)
stride_col = 1   (jump 1 to go to next column)

pid gives the dimension index. Multiplying by stride gives the actual memory address:

W[i, k]  →  W_ptr + i * stride_row + k * stride_col
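
The address formula can be tried out on a plain Python list standing in for GPU memory. This is only a sketch: flat_index is a hypothetical helper, and list indices play the role of pointer offsets.

```python
# A (2, 4) row-major matrix flattened into 1D "memory".
W_flat = ["w00", "w01", "w02", "w03", "w10", "w11", "w12", "w13"]
stride_row, stride_col = 4, 1

def flat_index(i, k, stride_i, stride_k):
    """Offset of element [i, k] from the start of the buffer."""
    return i * stride_i + k * stride_k

print(W_flat[flat_index(1, 2, stride_row, stride_col)])   # w12

# Swapping the strides reads the same memory as the transposed matrix:
# element [k, i] of W.T is element [i, k] of W, with no data copied.
print(W_flat[flat_index(2, 1, stride_col, stride_row)])   # w12
```

This is why it helps to pass strides as kernel arguments instead of hard-coding a layout: the same indexing code can read a contiguous tensor or a transposed view of it.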

Triton Kernel (v1 — Element-wise, No Tiling)

import triton
import triton.language as tl
import torch

@triton.jit
def matmul_kernel(
    W_ptr, H_ptr, O_ptr,
    K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oi, stride_oj,
):
    # Each thread gets its own (i, j) — its output cell
    i = tl.program_id(0)   # row index in W  (0 or 1)
    j = tl.program_id(1)   # row index in H  (0, 1, or 2)

    acc = tl.zeros((), dtype=tl.float32)

    # K loop — sequential, but each thread runs its own independently
    for k in range(K):
        w = tl.load(W_ptr + i * stride_wi + k * stride_wk)
        h = tl.load(H_ptr + j * stride_hi + k * stride_hk)
        acc += w * h

    # Store result in O[i, j]
    tl.store(O_ptr + i * stride_oi + j * stride_oj, acc)


def run(W, H):
    M = W.shape[0]   # 2
    N = H.shape[0]   # 3
    K = W.shape[1]   # 4

    O = torch.zeros(M, N, device=W.device, dtype=torch.float32)

    # Grid = 2×3 = 6 threads total — one per output element
    grid = (M, N)

    matmul_kernel[grid](
        W, H, O,
        K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(2, 4, device='cuda')
    H = torch.randn(3, 4, device='cuda')

    O_triton = run(W, H)
    O_torch  = W @ H.T   # kernel fills O[i, j] = dot(W[i], H[j]), shape (M, N) = (2, 3)

    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-3))
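
For readers without a GPU at hand, the launch model can be imitated on the CPU. The sketch below is a stand-in, not Triton: fake_kernel and launch are hypothetical names, flat Python lists replace device memory, and the grid loop runs serially where the GPU would run the instances in parallel.

```python
def fake_kernel(pid, W, H, O, K, stride_w, stride_h, stride_o):
    """CPU stand-in for matmul_kernel; `pid` plays the role of tl.program_id."""
    i, j = pid
    acc = 0.0
    for k in range(K):
        w = W[i * stride_w[0] + k * stride_w[1]]   # like tl.load on W
        h = H[j * stride_h[0] + k * stride_h[1]]   # like tl.load on H
        acc += w * h
    O[i * stride_o[0] + j * stride_o[1]] = acc     # like tl.store

def launch(grid, *args):
    """Run one kernel instance per grid cell (serially here)."""
    for i in range(grid[0]):
        for j in range(grid[1]):
            fake_kernel((i, j), *args)

M, N, K = 2, 3, 4
W = [float(x) for x in range(M * K)]   # row-major (2, 4): 0..7
H = [float(x) for x in range(N * K)]   # row-major (3, 4): 0..11
O = [0.0] * (M * N)                    # row-major (2, 3)

launch((M, N), W, H, O, K, (K, 1), (K, 1), (N, 1))
print(O)   # [14.0, 38.0, 62.0, 38.0, 126.0, 214.0]
```

The output buffer is laid out as M×N = 2×3, matching torch.zeros(M, N) in run: the first index walks the rows of W and the second walks the rows of H.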

Limitations of This Version

  • One thread per output element → not efficient for large matrices
  • No memory reuse — same data loaded multiple times
  • Not using GPU’s full memory bandwidth potential

What’s Next — v2 (Tiled / Blocked Version)

Instead of one element per program, each Triton program (a group of cooperating threads) will own one tile, a chunk of the output matrix. This allows:

  • Loading a block of W and H into shared memory once
  • All threads in the block reusing that data
  • Much better memory efficiency on real hardware (RTX 3050, etc.)

The K loop will also be tiled — instead of K=0,1,2,3 one by one, we’ll process BLOCK_K elements at a time using tl.dot().

v2 follows below.


v2 — Tiled / Blocked Version

Why Blocking?

In v1, every thread loads its own slice of W and H independently:

Thread(0,0): loads all K elements of W row 0 (e.g. 1024 loads when K = 1024)
Thread(0,1): loads the same K elements of W row 0 AGAIN
Thread(0,2): loads the same K elements of W row 0 AGAIN
→ Same data pulled from memory over and over. Slow!

Block idea: One thread handles a whole tile. Load W chunk and H chunk once, compute the full tile output — way fewer memory trips.

v1: 1 thread → 1 element
v2: 1 thread → BLOCK_M × BLOCK_N elements (a tile)
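
A rough count makes the difference concrete. The sketch below is a back-of-the-envelope model, not a measurement: it assumes every load goes to global memory (caches are ignored) and that the matrix sizes divide evenly by the block sizes. The function names are made up for this example.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b

def v1_counts(M, N, K):
    programs = M * N                   # one per output element
    loads = programs * 2 * K           # each reads K values of W and K of H
    return programs, loads

def v2_counts(M, N, K, BLOCK_M, BLOCK_N):
    programs = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)   # one per output tile
    loads = programs * (BLOCK_M + BLOCK_N) * K       # BLOCK_M rows of W + BLOCK_N rows of H
    return programs, loads

M, N, K = 512, 512, 256
p1, l1 = v1_counts(M, N, K)
p2, l2 = v2_counts(M, N, K, 32, 32)
print(p1, l1)      # 262144 134217728
print(p2, l2)      # 256 4194304
print(l1 // l2)    # 32
```

Under these assumptions, 32×32 tiles cut the number of element loads by a factor of 32, because every loaded value is reused across a whole row or column of the tile.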

New Concepts

tl.arange — get a range of indices at once:

# v1:
i = tl.program_id(0)              # just one number

# v2:
pid_m = tl.program_id(0)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)   # array of indices [0,1] or [2,3]...

[:, None] and [None, :] — build 2D address grid:

offs_m[:, None]   # shape (BLOCK_M, 1) — column
offs_k[None, :]   # shape (1, BLOCK_K) — row

# Together via broadcasting → full 2D grid of addresses
offs_m[:, None] + offs_k[None, :]   # shape (BLOCK_M, BLOCK_K)
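
Written out as explicit loops, the broadcast produces one offset per (row, column) pair. This is a plain-Python sketch of the arithmetic only: offset_grid is a hypothetical helper, and the strides 4 and 1 are taken from the row-major W example in the stride section.

```python
def offset_grid(offs_m, offs_k, stride_m, stride_k):
    """Emulates offs_m[:, None] * stride_m + offs_k[None, :] * stride_k."""
    return [[m * stride_m + k * stride_k for k in offs_k] for m in offs_m]

offs_m = [0, 1]          # like tl.arange(0, BLOCK_M) with BLOCK_M = 2
offs_k = [0, 1, 2, 3]    # like tl.arange(0, BLOCK_K) with BLOCK_K = 4

grid = offset_grid(offs_m, offs_k, stride_m=4, stride_k=1)
for row in grid:
    print(row)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
```

Adding this grid to the base pointer gives a (BLOCK_M, BLOCK_K) block of addresses, which is what lets a single tl.load fetch a whole tile.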

tl.dot — block multiply:

# v1: scalar × scalar
acc += w * h

# v2: tile × tile
acc += tl.dot(w_block, h_block.T)   # (BLOCK_M, BLOCK_K) × (BLOCK_K, BLOCK_N)

Kernel

@triton.jit
def matmul_kernel_blocked(
    W_ptr, H_ptr, O_ptr,
    M, N, K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oi, stride_oj,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)   # tile row — belongs to W
    pid_n = tl.program_id(1)   # tile col — belongs to H

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Accumulator is now a tile, not a scalar
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # K loop — now in chunks of BLOCK_K
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)

        # Load W tile: shape (BLOCK_M, BLOCK_K)
        w_ptrs = W_ptr + offs_m[:, None] * stride_wi + offs_k[None, :] * stride_wk
        w_block = tl.load(w_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)

        # Load H tile: shape (BLOCK_N, BLOCK_K)
        h_ptrs = H_ptr + offs_n[:, None] * stride_hi + offs_k[None, :] * stride_hk
        h_block = tl.load(h_ptrs, mask=(offs_n[:, None] < N) & (offs_k[None, :] < K), other=0.0)

        # Block multiply — outer product idea, now at tile level!
        acc += tl.dot(w_block, h_block.T)

    # Store entire tile
    o_ptrs = O_ptr + offs_m[:, None] * stride_oi + offs_n[None, :] * stride_oj
    tl.store(o_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def run_blocked(W, H, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32):
    M, K = W.shape
    N    = H.shape[0]
    O    = torch.zeros(M, N, device=W.device, dtype=torch.float32)

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_kernel_blocked[grid](
        W, H, O,
        M, N, K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(512, 256, device='cuda')
    H = torch.randn(512, 256, device='cuda')

    O_triton = run_blocked(W, H)
    O_torch  = W @ H.T   # same orientation as the kernel: O[i, j] = dot(W[i], H[j])

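    # tl.dot may run float32 inputs in TF32 on recent NVIDIA GPUs,
    # so the comparison uses a looser tolerance than exact fp32 would need.
    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-1, rtol=1e-2))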
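
The tiling and masking logic can also be checked without a GPU. The sketch below is a CPU emulation in plain Python, not the Triton kernel itself: blocked_matmul and naive_matmul are hypothetical names, nested lists replace tensors, and the conditional expressions play the role of the mask= and other=0.0 arguments. The shapes are deliberately not multiples of the block sizes so that the edge tiles are exercised.

```python
def naive_matmul(W, H):
    """Reference: O[i][j] = dot(W row i, H row j)."""
    return [[sum(w * h for w, h in zip(wr, hr)) for hr in H] for wr in W]

def blocked_matmul(W, H, BLOCK_M, BLOCK_N, BLOCK_K):
    M, K, N = len(W), len(W[0]), len(H)
    O = [[0.0] * N for _ in range(M)]
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    for pid_m in range(grid_m):                  # one "program" per tile
        for pid_n in range(grid_n):
            offs_m = [pid_m * BLOCK_M + a for a in range(BLOCK_M)]
            offs_n = [pid_n * BLOCK_N + b for b in range(BLOCK_N)]
            acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]
            for k0 in range(0, K, BLOCK_K):
                offs_k = [k0 + c for c in range(BLOCK_K)]
                # Masked loads: out-of-range positions read as 0.0
                w_blk = [[W[m][k] if m < M and k < K else 0.0 for k in offs_k]
                         for m in offs_m]
                h_blk = [[H[n][k] if n < N and k < K else 0.0 for k in offs_k]
                         for n in offs_n]
                # Tile product, the job tl.dot does in the kernel
                for a in range(BLOCK_M):
                    for b in range(BLOCK_N):
                        acc[a][b] += sum(w_blk[a][c] * h_blk[b][c]
                                         for c in range(BLOCK_K))
            # Masked store: skip positions outside the output
            for a, m in enumerate(offs_m):
                for b, n in enumerate(offs_n):
                    if m < M and n < N:
                        O[m][n] = acc[a][b]
    return O

W = [[float(i + k) for k in range(7)] for i in range(5)]       # (5, 7)
H = [[float(j * k + 1) for k in range(7)] for j in range(3)]   # (3, 7)
print(blocked_matmul(W, H, 2, 2, 4) == naive_matmul(W, H))     # True
```

Padding with zeros is safe here because a zero contributes nothing to a sum of products, so the edge tiles produce the same values as the unpadded computation.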

v1 vs v2 Summary

                    v1 (element-wise)          v2 (blocked)
1 thread handles    1 element                  BLOCK_M × BLOCK_N tile
Accumulator         scalar 0.0                 zeros(BLOCK_M, BLOCK_N)
K loop step         1                          BLOCK_K
Multiply op         w * h                      tl.dot(w_block, h_block.T)
Memory trips        many (repeated loads)      few (chunk loaded once)
Good for            learning / small inputs    real workloads

v3 coming soon — block pointers (H100/H200/B200 style)

This post is licensed under CC BY 4.0 by the author.