My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)
A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means
Triton Matrix Multiplication — Element-wise Parallel Approach
Problem Statement
Given two matrices:
- W of shape (2, 4)
- H of shape (3, 4)
Goal: Compute O = H @ W.T → output shape (3, 2)
Since W and H share the same column dimension (K=4), we transpose W so the multiplication works: H (3×4) × W^T (4×2) = O (3×2)
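Quick sanity check of those shapes in plain PyTorch before touching Triton (random inputs, nothing else assumed):

```python
import torch

W = torch.randn(2, 4)   # (M, K)
H = torch.randn(3, 4)   # (N, K)

O = H @ W.T             # (3, 4) @ (4, 2)
print(O.shape)          # torch.Size([3, 2])
```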
Core Insight — How I Thought About Parallelism
The key observation from the diagram:
For each value of k, the k-th column of W and the k-th column of H form an outer product:
```
k=0:
  W_col = [w00, w10]    H_col = [h00, h10, h20]
  Outer product → 6 values computed in parallel:

           h00    h10    h20
  w00    [0,0]  [0,1]  [0,2]   → 3 threads
  w10    [1,0]  [1,1]  [1,2]   → 3 threads

k=1:  W_col = [w01, w11],  H_col = [h01, h11, h21]  → same 6 threads
k=2:  W_col = [w02, w12],  H_col = [h02, h12, h22]  → same 6 threads
k=3:  W_col = [w03, w13],  H_col = [h03, h13, h23]  → same 6 threads
```
The K loop is sequential. Within each k, all 6 threads run in parallel.
```
K=0 → [6 threads parallel] → done
K=1 → [6 threads parallel] → done
K=2 → [6 threads parallel] → done
K=3 → [6 threads parallel] → done
```
Each thread only knows its own (i, j) position and independently accumulates its result across K iterations — no thread waits for another.
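Here's the same idea as a tiny PyTorch sketch (random W and H, the variable names are mine): summing the per-k outer products reproduces the full matmul, and each diagram cell [i, j] is exactly O[j, i] of the target H @ W.T.

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 4)   # (M, K)
H = torch.randn(3, 4)   # (N, K)

# Accumulate one outer product per k, laid out like the diagram: W rows × H rows
acc = torch.zeros(2, 3)
for k in range(4):
    acc += torch.outer(W[:, k], H[:, k])   # the 6 cells updated "in parallel"

# Each diagram cell [i, j] holds O[j, i] of the target output
print(torch.allclose(acc.T, H @ W.T))      # True
```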
Mental Model
| Concept | Meaning |
|---|---|
| `program_id(0) = i` | Which row of W (0 or 1) |
| `program_id(1) = j` | Which row of H (0, 1, or 2) |
| `grid = (M, N)` | Total threads launched = M×N = 6 |
| `stride` | How many memory positions to jump to get the next element |
| `acc` | Each thread’s private accumulator across the K loop |
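As a sanity check on this mental model, here is a plain PyTorch reference of my own (no Triton) that mimics it literally: one "thread" per (i, j), each running its own K loop.

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 4)
H = torch.randn(3, 4)
M, N, K = W.shape[0], H.shape[0], W.shape[1]

O = torch.zeros(N, M)                  # target shape (3, 2)
for i in range(M):                     # plays the role of program_id(0)
    for j in range(N):                 # plays the role of program_id(1)
        acc = 0.0
        for k in range(K):             # sequential K loop inside each "thread"
            acc += W[i, k].item() * H[j, k].item()
        O[j, i] = acc                  # this thread's output cell in O = H @ W.T

print(torch.allclose(O, H @ W.T))      # True
```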
Why Stride?
Matrices are stored as 1D arrays in memory:
```
W = | w00 w01 w02 w03 |  →  [w00, w01, w02, w03, w10, w11, w12, w13]
    | w10 w11 w12 w13 |       0    1    2    3    4    5    6    7

stride_row = 4  (jump 4 to go to the next row)
stride_col = 1  (jump 1 to go to the next column)
```
The program ID gives the logical index along each dimension; multiplying it by the stride gives the actual memory offset:
```
W[i, k] → W_ptr + i * stride_row + k * stride_col
```
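You can see the same arithmetic from PyTorch (contiguous tensor assumed; the example values are mine):

```python
import torch

W = torch.arange(8, dtype=torch.float32).reshape(2, 4)   # contiguous (2, 4)
print(W.stride())        # (4, 1): stride_row = 4, stride_col = 1

# Flat offset of W[1, 2] using the same formula as the kernel
i, k = 1, 2
offset = i * W.stride(0) + k * W.stride(1)               # 1*4 + 2*1 = 6
print(W.flatten()[offset] == W[i, k])                    # tensor(True)
```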
Triton Kernel (v1 — Element-wise, No Tiling)
```python
import triton
import triton.language as tl
import torch


@triton.jit
def matmul_kernel(
    W_ptr, H_ptr, O_ptr,
    K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oj, stride_oi,
):
    # Each thread gets its own (i, j) — its output cell
    i = tl.program_id(0)  # row index in W (0 or 1)
    j = tl.program_id(1)  # row index in H (0, 1, or 2)

    acc = tl.zeros((), dtype=tl.float32)

    # K loop — sequential, but each thread runs its own independently
    for k in range(K):
        w = tl.load(W_ptr + i * stride_wi + k * stride_wk)
        h = tl.load(H_ptr + j * stride_hi + k * stride_hk)
        acc += w * h

    # O = H @ W.T has shape (N, M), so thread (i, j) writes O[j, i]
    tl.store(O_ptr + j * stride_oj + i * stride_oi, acc)


def run(W, H):
    M = W.shape[0]  # 2
    N = H.shape[0]  # 3
    K = W.shape[1]  # 4
    O = torch.zeros(N, M, device=W.device, dtype=torch.float32)  # (3, 2)

    # Grid = 2×3 = 6 threads total — one per output element
    grid = (M, N)
    matmul_kernel[grid](
        W, H, O,
        K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(2, 4, device='cuda')
    H = torch.randn(3, 4, device='cuda')

    O_triton = run(W, H)
    O_torch = H @ W.T
    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-3))
```
Limitations of This Version
- One thread per output element → not efficient for large matrices
- No memory reuse — same data loaded multiple times (see the rough count after this list)
- Not using GPU’s full memory bandwidth potential
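To put a number on the reuse problem, here's a back-of-the-envelope count of global loads for a hypothetical square problem (the sizes are mine, and caching is ignored):

```python
# Illustrative sizes, not the toy 2x4 / 3x4 case
M = N = K = 1024

loads_v1 = 2 * M * N * K        # each of the M*N threads loads K elements of W and K of H
loads_min = M * K + N * K       # ideal: read every input element exactly once

print(loads_v1 / loads_min)     # 1024.0, i.e. roughly 1000x more traffic than necessary
```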
What’s Next — v2 (Tiled / Blocked Version)
Instead of 1 thread per element, we will assign 1 thread per tile (a chunk of the output matrix). This allows:
- Loading a block of W and H into shared memory once
- All threads in the block reusing that data
- Much better memory efficiency on real hardware (RTX 3050, etc.)
The K loop will also be tiled — instead of K=0,1,2,3 one by one, we’ll process BLOCK_K elements at a time using tl.dot().
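As a preview of the idea (plain PyTorch, not the actual v2 kernel; BLOCK_K = 2 is just for illustration), here's what processing the K dimension in blocks looks like. Inside the real kernel, the `h_blk @ w_blk.T` step is the part tl.dot() will handle on loaded tiles:

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 4)
H = torch.randn(3, 4)
BLOCK_K = 2                                    # illustrative tile width along K

O = torch.zeros(3, 2)
for k0 in range(0, W.shape[1], BLOCK_K):
    w_blk = W[:, k0:k0 + BLOCK_K]              # (M, BLOCK_K) tile of W
    h_blk = H[:, k0:k0 + BLOCK_K]              # (N, BLOCK_K) tile of H
    O += h_blk @ w_blk.T                       # one block update per K-tile (tl.dot's job)

print(torch.allclose(O, H @ W.T))              # True
```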
v2 coming soon.