From 30ms to 2ms: My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)
A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means
Triton Matrix Multiplication — Element-wise Parallel Approach
Problem Statement
Given two matrices:
- W of shape (2, 4)
- H of shape (3, 4)
Goal: compute O[i, j] = dot(W[i, :], H[j, :]) for every pair of rows, i.e. O = W @ H.T → output shape (2, 3). (That is simply the transpose of H @ W.T; this layout matches the indexing used by the kernel below.)
Since W and H share the same inner dimension (K=4), we transpose H so the multiplication works: W (2×4) × H^T (4×3) = O (2×3).
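As a sanity check on the shapes, here is the target computation in plain PyTorch (a throwaway sketch; the values are random):

```python
import torch

# Toy inputs matching the problem statement
W = torch.randn(2, 4)   # (M, K) = (2, 4)
H = torch.randn(3, 4)   # (N, K) = (3, 4)

# Reference result: O[i, j] = dot(W[i, :], H[j, :])
O_ref = W @ H.T                            # shape (2, 3)
O_ref2 = torch.einsum('ik,jk->ij', W, H)   # same thing, spelled out per index

print(O_ref.shape, torch.allclose(O_ref, O_ref2))  # torch.Size([2, 3]) True
```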
Core Insight — How I Thought About Parallelism
The key observation from the diagram:
For each value of k, column k of W and column k of H form an outer product:
```
k=0:
  W_col = [w00, w10]   H_col = [h00, h10, h20]

  Outer product → 6 values computed in parallel:

          h00    h10    h20
  w00   [0,0]  [0,1]  [0,2]   → 3 threads
  w10   [1,0]  [1,1]  [1,2]   → 3 threads

k=1: W_col = [w01, w11], H_col = [h01, h11, h21] → same 6 threads
k=2: W_col = [w02, w12], H_col = [h02, h12, h22] → same 6 threads
k=3: W_col = [w03, w13], H_col = [h03, h13, h23] → same 6 threads
```
The K loop is sequential. Within each k step, all 6 threads run in parallel:
```
K=0 → [6 threads parallel] → done
K=1 → [6 threads parallel] → done
K=2 → [6 threads parallel] → done
K=3 → [6 threads parallel] → done
```
Each thread (strictly speaking, each Triton program instance; I'll keep calling them threads) only knows its own (i, j) position and independently accumulates its result across K iterations — no thread waits for another.
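This outer-product view is easy to verify outside the GPU: summing the K outer products reproduces the full result (a small PyTorch sketch, not kernel code):

```python
import torch

W = torch.randn(2, 4)
H = torch.randn(3, 4)

# Sum over k of: column k of W (length 2) outer column k of H (length 3)
O_outer = sum(torch.outer(W[:, k], H[:, k]) for k in range(4))

print(torch.allclose(O_outer, W @ H.T))  # True
```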
Mental Model
| Concept | Meaning |
|---|---|
| `program_id(0) = i` | Which row of W (0 or 1) |
| `program_id(1) = j` | Which row of H (0, 1, or 2) |
| `grid = (M, N)` | Total threads launched = M×N = 6 |
| `stride` | How many memory positions to jump to get the next element |
| `acc` | Each thread's private accumulator across the K loop |
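The table maps one-to-one onto a plain-Python emulation of the launch, with each (i, j) pair playing the role of one program (a CPU sketch only, no Triton):

```python
import torch

W = torch.randn(2, 4)
H = torch.randn(3, 4)
M, N, K = 2, 3, 4

O = torch.zeros(M, N)
for i in range(M):              # plays the role of program_id(0)
    for j in range(N):          # plays the role of program_id(1)
        acc = 0.0               # this program's private accumulator
        for k in range(K):      # sequential K loop inside the program
            acc += W[i, k] * H[j, k]
        O[i, j] = acc

print(torch.allclose(O, W @ H.T))  # True
```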
Why Stride?
Matrices are stored as flat 1D arrays in memory:
```
W = | w00 w01 w02 w03 |  →  [w00, w01, w02, w03, w10, w11, w12, w13]
    | w10 w11 w12 w13 |       0    1    2    3    4    5    6    7

stride_row = 4  (jump 4 to go to the next row)
stride_col = 1  (jump 1 to go to the next column)
```
pid gives the dimension index. Multiplying by stride gives the actual memory address:
```
W[i, k] → W_ptr + i * stride_row + k * stride_col
```
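PyTorch exposes exactly these stride values, and they are what gets passed into the kernel below; a quick sanity check (assuming a contiguous, row-major tensor):

```python
import torch

W = torch.randn(2, 4)
print(W.stride())        # (4, 1): jump 4 for the next row, 1 for the next column

# Flat index of W[i, k] in the underlying 1D storage
i, k = 1, 2
flat = i * W.stride(0) + k * W.stride(1)
print(W.flatten()[flat] == W[i, k])   # tensor(True)
```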
Triton Kernel (v1 — Element-wise, No Tiling)
```python
import triton
import triton.language as tl
import torch


@triton.jit
def matmul_kernel(
    W_ptr, H_ptr, O_ptr,
    K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oi, stride_oj,
):
    # Each thread gets its own (i, j) — its output cell
    i = tl.program_id(0)  # row index in W (0 or 1)
    j = tl.program_id(1)  # row index in H (0, 1, or 2)

    acc = tl.zeros((), dtype=tl.float32)

    # K loop — sequential, but each thread runs its own independently
    for k in range(K):
        w = tl.load(W_ptr + i * stride_wi + k * stride_wk)
        h = tl.load(H_ptr + j * stride_hi + k * stride_hk)
        acc += w * h

    # Store result in O[i, j]
    tl.store(O_ptr + i * stride_oi + j * stride_oj, acc)


def run(W, H):
    M = W.shape[0]  # 2
    N = H.shape[0]  # 3
    K = W.shape[1]  # 4
    O = torch.zeros(M, N, device=W.device, dtype=torch.float32)

    # Grid = 2×3 = 6 threads total — one per output element
    grid = (M, N)
    matmul_kernel[grid](
        W, H, O,
        K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(2, 4, device='cuda')
    H = torch.randn(3, 4, device='cuda')
    O_triton = run(W, H)
    O_torch = W @ H.T  # reference: O[i, j] = dot(W[i], H[j]), i.e. (H @ W.T).T
    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-3))
```
Limitations of This Version
- One thread per output element → not efficient for large matrices
- No memory reuse — same data loaded multiple times
- Not using GPU’s full memory bandwidth potential
What’s Next — v2 (Tiled / Blocked Version)
Instead of 1 thread per element, we will assign 1 thread per tile (a chunk of the output matrix). This allows:
- Loading a block of W and H into shared memory once
- All threads in the block reusing that data
- Much better memory efficiency on real hardware (RTX 3050, etc.)
The K loop will also be tiled — instead of K=0,1,2,3 one by one, we’ll process BLOCK_K elements at a time using tl.dot().
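The tiled K loop is easier to see in plain PyTorch first: accumulate one BLOCK_K-wide chunk at a time, which is exactly the role tl.dot will play per chunk (a sketch with made-up sizes):

```python
import torch

M, N, K, BLOCK_K = 8, 8, 32, 8
W = torch.randn(M, K)
H = torch.randn(N, K)

acc = torch.zeros(M, N)
for k in range(0, K, BLOCK_K):
    W_chunk = W[:, k:k + BLOCK_K]   # (M, BLOCK_K)
    H_chunk = H[:, k:k + BLOCK_K]   # (N, BLOCK_K)
    acc += W_chunk @ H_chunk.T      # one tl.dot-sized step

print(torch.allclose(acc, W @ H.T))  # True
```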
v2 — Tiled / Blocked Version
Why Blocking?
In v1, every thread loads its own slice of W and H independently:
```
(imagine a large matrix, say K = 1024)

Thread(0,0): loads W row 0 — 1024 times
Thread(0,1): loads W row 0 — 1024 times AGAIN
Thread(0,2): loads W row 0 — 1024 times AGAIN

→ Same data pulled from memory over and over. Slow!
```
Block idea: One thread handles a whole tile. Load W chunk and H chunk once, compute the full tile output — way fewer memory trips.
```
v1: 1 thread → 1 element
v2: 1 thread → BLOCK_M × BLOCK_N elements (a tile)
```
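A rough load count makes the difference concrete. This back-of-the-envelope sketch assumes a hypothetical square problem, one program per output tile, and ignores caching entirely:

```python
# Hypothetical sizes, just for the arithmetic
M = N = K = 1024
BLOCK_M = BLOCK_N = 32

# v1: every output element loads a full row of W and a full row of H
v1_loads = M * N * 2 * K

# v2: each W element is re-loaded once per column of tiles,
#     each H element once per row of tiles
v2_loads = (M * K) * (N // BLOCK_N) + (N * K) * (M // BLOCK_M)

print(f"v1: {v1_loads:,} loads, v2: {v2_loads:,} loads, ratio ~{v1_loads / v2_loads:.0f}x")
```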
New Concepts
tl.arange — get a range of indices at once:
```python
# v1:
i = tl.program_id(0)                               # just one number

# v2:
pid_m = tl.program_id(0)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)   # array of indices [0,1] or [2,3]...
```
[:, None] and [None, :] — build 2D address grid:
```python
offs_m[:, None]   # shape (BLOCK_M, 1) — column
offs_k[None, :]   # shape (1, BLOCK_K) — row

# Together via broadcasting → full 2D grid of addresses
offs_m[:, None] + offs_k[None, :]   # shape (BLOCK_M, BLOCK_K)
```
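The same broadcasting trick works in plain PyTorch, which makes it easy to print the index grid it produces (made-up block sizes and a row stride of 3, just for illustration):

```python
import torch

offs_m = torch.arange(0, 4)      # pretend BLOCK_M = 4
offs_k = torch.arange(0, 3)      # pretend BLOCK_K = 3

grid = offs_m[:, None] * 3 + offs_k[None, :]   # stride 3 per row, 1 per column
print(grid)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])
```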
tl.dot — block multiply:
```python
# v1: scalar × scalar
acc += w * h

# v2: tile × tile
acc += tl.dot(w_block, h_block.T)   # (BLOCK_M, BLOCK_K) × (BLOCK_K, BLOCK_N)
```
Kernel
```python
@triton.jit
def matmul_kernel_blocked(
    W_ptr, H_ptr, O_ptr,
    M, N, K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oi, stride_oj,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)  # tile row — belongs to W
    pid_n = tl.program_id(1)  # tile col — belongs to H

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Accumulator is now a tile, not a scalar
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # K loop — now in chunks of BLOCK_K
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)

        # Load W tile: shape (BLOCK_M, BLOCK_K)
        w_ptrs = W_ptr + offs_m[:, None] * stride_wi + offs_k[None, :] * stride_wk
        w_block = tl.load(w_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)

        # Load H tile: shape (BLOCK_N, BLOCK_K)
        h_ptrs = H_ptr + offs_n[:, None] * stride_hi + offs_k[None, :] * stride_hk
        h_block = tl.load(h_ptrs, mask=(offs_n[:, None] < N) & (offs_k[None, :] < K), other=0.0)

        # Block multiply — outer product idea, now at tile level!
        acc += tl.dot(w_block, h_block.T)

    # Store entire tile
    o_ptrs = O_ptr + offs_m[:, None] * stride_oi + offs_n[None, :] * stride_oj
    tl.store(o_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def run_blocked(W, H, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32):
    M, K = W.shape
    N = H.shape[0]
    O = torch.zeros(M, N, device=W.device, dtype=torch.float32)

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel_blocked[grid](
        W, H, O,
        M, N, K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(512, 256, device='cuda')
    H = torch.randn(512, 256, device='cuda')
    O_triton = run_blocked(W, H)
    O_torch = W @ H.T  # reference, same O[i, j] = dot(W[i], H[j]) convention as v1
    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-3))
```
v1 vs v2 Summary
| | v1 (element-wise) | v2 (blocked) |
|---|---|---|
| 1 thread handles | 1 element | BLOCK_M × BLOCK_N tile |
| Accumulator | scalar 0.0 | zeros(BLOCK_M, BLOCK_N) |
| K loop step | 1 | BLOCK_K |
| Multiply op | w * h | tl.dot(w_block, h_block.T) |
| Memory trips | many (repeated loads) | few (chunk loaded once) |
| Good for | learning / small inputs | real workloads |
v3 coming soon — block pointers (H100/H200/B200 style)