My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)

A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means

Triton Matrix Multiplication — Element-wise Parallel Approach

Problem Statement

Given two matrices:

  • W of shape (2, 4)
  • H of shape (3, 4)

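Goal: Compute O = W @ H.T → output shape (2, 3)

Since W and H share the same column dimension (K=4), we transpose H so the inner dimensions match: W (2×4) × H^T (4×3) = O (2×3). Each output cell is O[i, j] = Σ_k W[i, k] · H[j, k], the dot product of row i of W with row j of H. This is the transpose of H @ W.T (shape 3×2). The rest of the post indexes outputs by (row of W, row of H), so W @ H.T is the orientation the kernel actually produces.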
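A tiny pure-Python reference (made-up integer values, no GPU needed) helps pin down the indexing. Each output entry is the dot product of one row of W with one row of H. Indexing the result by (row of W, row of H) gives a 2×3 array, and that array is exactly the transpose of H @ W.T. Keeping the orientation straight matters when comparing a kernel's output against PyTorch.

```python
# Toy matrices with the shapes from the post: W is (2, 4), H is (3, 4).
W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
H = [[1, 0, 0, 1],
     [0, 1, 1, 0],
     [2, 2, 2, 2]]

def matmul(A, B):
    """Plain triple-loop matrix product A @ B."""
    return [[sum(A[r][t] * B[t][c] for t in range(len(B)))
             for c in range(len(B[0]))]
            for r in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

# H @ W.T has shape (3, 2): rows indexed by H, columns by W.
HWt = matmul(H, transpose(W))

# Indexing by (row of W, row of H) gives shape (2, 3):
# O[i][j] = sum_k W[i][k] * H[j][k]
O = [[sum(W[i][k] * H[j][k] for k in range(4)) for j in range(3)]
     for i in range(2)]

print(HWt)  # [[5, 13], [5, 13], [20, 52]]
print(O)    # [[5, 5, 20], [13, 13, 52]]
```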


Core Insight — How I Thought About Parallelism

The key observation: for each value of k, column k of W and column k of H form an outer product, and each entry of that outer product is one multiply-add for one output cell:

k=0:
  W_col = [w00, w10]       H_col = [h00, h10, h20]

  Outer product → 6 values computed in parallel:

       h00    h10    h20
  w00  [0,0]  [0,1]  [0,2]   → 3 threads
  w10  [1,0]  [1,1]  [1,2]   → 3 threads

k=1: W_col = [w01, w11],  H_col = [h01, h11, h21]  → same 6 threads
k=2: W_col = [w02, w12],  H_col = [h02, h12, h22]  → same 6 threads
k=3: W_col = [w03, w13],  H_col = [h03, h13, h23]  → same 6 threads

The K loop is sequential inside each thread, while all 6 threads run in parallel with one another.

K=0 → [6 threads parallel] → done
K=1 → [6 threads parallel] → done
K=2 → [6 threads parallel] → done
K=3 → [6 threads parallel] → done

Each thread only knows its own (i, j) position and independently accumulates its result across K iterations — no thread waits for another.
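The outer-product view can be checked with a few lines of plain Python (toy integer values, no GPU). The sequential loop over k builds the outer product of column k of W and column k of H and adds it into a 2×3 grid of private accumulators. Each cell's running sum ends up equal to the dot product of row i of W with row j of H, which is why the 6 cells can be computed independently.

```python
# Same toy shapes as the post: W is (2, 4), H is (3, 4). Values are arbitrary.
W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
H = [[1, 0, 0, 1],
     [0, 1, 1, 0],
     [2, 2, 2, 2]]
M, K = len(W), len(W[0])
N = len(H)

# One private accumulator per (i, j) "thread": M * N = 6 of them.
acc = [[0] * N for _ in range(M)]

for k in range(K):                       # sequential K loop
    w_col = [W[i][k] for i in range(M)]  # column k of W: [w0k, w1k]
    h_col = [H[j][k] for j in range(N)]  # column k of H: [h0k, h1k, h2k]
    # Outer product of the two columns: each (i, j) cell gets one multiply-add.
    # The 6 updates do not depend on each other, so a GPU can run them in parallel.
    for i in range(M):
        for j in range(N):
            acc[i][j] += w_col[i] * h_col[j]

# Summing K outer products equals computing each cell as a dot product.
dot = [[sum(W[i][k] * H[j][k] for k in range(K)) for j in range(N)]
       for i in range(M)]
print(acc == dot)  # True
```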


Mental Model

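| Concept | Meaning |
|---|---|
| program_id(0) = i | Which row of W (0 or 1) |
| program_id(1) = j | Which row of H (0, 1, or 2) |
| grid = (M, N) | Number of program instances launched: M×N = 6. Each one is a "thread" in this post's mental model; on the hardware, Triton runs each program instance as a CUDA thread block. |
| stride | How many elements to jump in memory to reach the next element along a dimension |
| acc | Each thread's private accumulator across the K loop |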

Why Stride?

Matrices are stored as flat 1D arrays in memory. A contiguous PyTorch tensor is laid out row-major:

W = | w00  w01  w02  w03 |   →  [w00, w01, w02, w03, w10, w11, w12, w13]
    | w10  w11  w12  w13 |        0    1    2    3    4    5    6    7

stride_row = 4   (jump 4 to go to next row)
stride_col = 1   (jump 1 to go to next column)

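program_id gives the index along a dimension. Multiplying each index by that dimension's stride and adding the result to the base pointer gives the element's address (Triton pointer arithmetic, like PyTorch strides, counts elements, not bytes):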

W[i, k]  →  W_ptr + i * stride_row + k * stride_col
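A quick pure-Python check of the address formula, with a flat list standing in for GPU memory (`address` is a hypothetical helper). It also shows why the kernel takes strides as arguments instead of hard-coding them: in PyTorch a contiguous (2, 4) tensor has `W.stride() == (4, 1)`, while `W.T` is a view over the same buffer with strides (1, 4), so the same formula reaches the same element without any copy.

```python
# Row-major layout of the 2x4 matrix W, stored as a flat list.
# The values are the flat offsets themselves, so results are easy to read.
W_flat = [0, 1, 2, 3, 4, 5, 6, 7]

def address(base, i, k, stride_row, stride_col):
    """Offset of element [i, k] from the base pointer, in elements (not bytes)."""
    return base + i * stride_row + k * stride_col

# Contiguous W: stride_row = 4, stride_col = 1.
print(W_flat[address(0, 1, 2, 4, 1)])  # W[1, 2] lives at offset 6 -> 6

# The same buffer viewed as W.T (shape 4x2): just swap the strides.
# W.T[k, i] == W[i, k], so for W.T stride_row = 1 and stride_col = 4.
print(W_flat[address(0, 2, 1, 1, 4)])  # W.T[2, 1] == W[1, 2] -> 6
```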

Triton Kernel (v1 — Element-wise, No Tiling)

import triton
import triton.language as tl
import torch

@triton.jit
def matmul_kernel(
    W_ptr, H_ptr, O_ptr,
    K,
    stride_wi, stride_wk,
    stride_hi, stride_hk,
    stride_oi, stride_oj,
):
    # Each program instance (a "thread" in this post's mental model) gets its own (i, j): its output cell
    i = tl.program_id(0)   # row index in W  (0 or 1)
    j = tl.program_id(1)   # row index in H  (0, 1, or 2)

    acc = tl.zeros((), dtype=tl.float32)

    # K loop — sequential, but each thread runs its own independently
    for k in range(K):
        w = tl.load(W_ptr + i * stride_wi + k * stride_wk)
        h = tl.load(H_ptr + j * stride_hi + k * stride_hk)
        acc += w * h

    # Store result in O[i, j]
    tl.store(O_ptr + i * stride_oi + j * stride_oj, acc)


def run(W, H):
    M = W.shape[0]   # 2
    N = H.shape[0]   # 3
    K = W.shape[1]   # 4

    O = torch.zeros(M, N, device=W.device, dtype=torch.float32)

    # Grid = 2×3 = 6 threads total — one per output element
    grid = (M, N)

    matmul_kernel[grid](
        W, H, O,
        K,
        W.stride(0), W.stride(1),
        H.stride(0), H.stride(1),
        O.stride(0), O.stride(1),
    )
    return O


# Test
if __name__ == "__main__":
    W = torch.randn(2, 4, device='cuda')
    H = torch.randn(3, 4, device='cuda')

    O_triton = run(W, H)
    O_torch  = W @ H.T   # O[i, j] = sum_k W[i, k] * H[j, k]: shape (2, 3), the transpose of H @ W.T

    print("Match:", torch.allclose(O_triton, O_torch, atol=1e-3))
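Without a CUDA GPU, the kernel's index arithmetic can still be checked on the CPU by emulating the launch: loop over every (pid0, pid1) in the grid and run a Python copy of the kernel body, with flat lists standing in for GPU memory and integer offsets standing in for pointers. This is a hypothetical test harness (`emulate_matmul_kernel` and `emulate_run` are made-up names), not how Triton executes kernels; it only mirrors the v1 body line by line.

```python
def emulate_matmul_kernel(pid0, pid1, W, H, O, K,
                          stride_wi, stride_wk, stride_hi, stride_hk,
                          stride_oi, stride_oj):
    i, j = pid0, pid1          # stand-ins for tl.program_id(0) / tl.program_id(1)
    acc = 0.0
    for k in range(K):         # same sequential K loop as the kernel
        w = W[i * stride_wi + k * stride_wk]
        h = H[j * stride_hi + k * stride_hk]
        acc += w * h
    O[i * stride_oi + j * stride_oj] = acc

def emulate_run(W_flat, H_flat, M, N, K):
    O_flat = [0.0] * (M * N)
    # The GPU runs these grid points in parallel; here we simply loop over them.
    for pid0 in range(M):
        for pid1 in range(N):
            emulate_matmul_kernel(pid0, pid1, W_flat, H_flat, O_flat, K,
                                  K, 1,   # W is contiguous (M, K)
                                  K, 1,   # H is contiguous (N, K)
                                  N, 1)   # O is contiguous (M, N)
    return O_flat

W_flat = [1, 2, 3, 4, 5, 6, 7, 8]               # W: 2x4
H_flat = [1, 0, 0, 1, 0, 1, 1, 0, 2, 2, 2, 2]   # H: 3x4
print(emulate_run(W_flat, H_flat, M=2, N=3, K=4))
# [5.0, 5.0, 20.0, 13.0, 13.0, 52.0]  -> read as 2x3, this is W @ H.T
```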

Limitations of This Version

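  • One program per output element, each doing purely scalar work → the GPU's wide parallel hardware is mostly wasted, which does not scale to large matrices
  • No data reuse: every row of W is loaded from global memory N times, and every row of H is loaded M times
  • Every multiply-add needs two fresh loads, so the kernel is memory-bound: the arithmetic units spend most of their time waiting on memory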
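The data-reuse problem is easy to quantify. The sketch below counts the global-memory loads the v1 kernel issues (two per multiply-add) and compares them with the number of distinct elements in W and H, assuming float32 inputs. The helper names are hypothetical, and these are load instructions issued, not measured DRAM traffic: the GPU's caches can serve some of the repeats.

```python
def naive_global_loads(M, N, K):
    """Loads issued by v1: each of the M*N programs loads K values of W and K of H."""
    return 2 * M * N * K

def distinct_elements(M, N, K):
    """Elements that actually exist in W (M x K) and H (N x K)."""
    return M * K + N * K

for M, N, K in [(2, 3, 4), (1024, 1024, 1024)]:
    loads = naive_global_loads(M, N, K)
    distinct = distinct_elements(M, N, K)
    flops = 2 * M * N * K          # one multiply and one add per (i, j, k)
    bytes_loaded = 4 * loads       # float32: 4 bytes per load
    print(f"{M}x{N}x{K}: {loads} loads for {distinct} distinct elements "
          f"({loads / distinct:.1f}x redundancy), "
          f"{flops / bytes_loaded:.2f} FLOP per byte loaded")
# 2x3x4: 48 loads for 20 distinct elements (2.4x redundancy), 0.25 FLOP per byte loaded
# 1024x1024x1024: 2147483648 loads for 2097152 distinct elements (1024.0x redundancy), 0.25 FLOP per byte loaded
```

The redundancy grows with matrix size, but the ratio of arithmetic to bytes loaded stays fixed at 0.25 FLOP per byte no matter how large the matrices get. That is far below what a GPU needs to keep its arithmetic units busy, which is exactly what "memory-bound" means here.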

What’s Next — v2 (Tiled / Blocked Version)

Instead of one program per output element, we will assign one program instance (a whole block of threads) per tile, a chunk of the output matrix. This allows:

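  • Loading a tile of W and a tile of H from global memory once per K step instead of one element at a time (Triton's compiler, not the programmer, decides when to stage tiles through shared memory)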
  • All threads in the block reusing that data
  • Much better memory efficiency on real hardware (RTX 3050, etc.)

The K loop will also be tiled — instead of K=0,1,2,3 one by one, we’ll process BLOCK_K elements at a time using tl.dot().
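The tiling plan can be modelled on the CPU before writing any Triton. The pure-Python sketch below is only a model of the idea, not Triton code: `tiled_matmul` is a hypothetical helper, the tile sizes are illustrative, and the inner block product stands in for what `tl.dot()` does on the GPU. It uses the same indexing as v1, O[i][j] = Σ_k W[i][k] · H[j][k]. Each output tile loads one BLOCK_M×BLOCK_K tile of W and one BLOCK_N×BLOCK_K tile of H per K step, then reuses them for every cell in the tile, so each element of W is loaded N / BLOCK_N times instead of N times.

```python
def tiled_matmul(W, H, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2):
    M, K, N = len(W), len(W[0]), len(H)
    O = [[0] * N for _ in range(M)]
    loads = 0
    # One "program" per output tile; on the GPU these tiles run in parallel.
    for m0 in range(0, M, BLOCK_M):
        for n0 in range(0, N, BLOCK_N):
            rows = range(m0, min(m0 + BLOCK_M, M))
            cols = range(n0, min(n0 + BLOCK_N, N))
            acc = [[0] * len(cols) for _ in rows]
            # Tiled K loop: BLOCK_K columns per step instead of one.
            for k0 in range(0, K, BLOCK_K):
                ks = range(k0, min(k0 + BLOCK_K, K))
                # Load each tile once...
                w_tile = [[W[i][k] for k in ks] for i in rows]
                h_tile = [[H[j][k] for k in ks] for j in cols]
                loads += len(rows) * len(ks) + len(cols) * len(ks)
                # ...then reuse it for every cell of the output tile.
                for a in range(len(rows)):
                    for b in range(len(cols)):
                        acc[a][b] += sum(x * y for x, y in zip(w_tile[a], h_tile[b]))
            for a, i in enumerate(rows):
                for b, j in enumerate(cols):
                    O[i][j] = acc[a][b]
    return O, loads

M, N, K = 64, 64, 64
W = [[(i + 2 * k) % 7 for k in range(K)] for i in range(M)]
H = [[(3 * j + k) % 5 for k in range(K)] for j in range(N)]
O, loads = tiled_matmul(W, H, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16)
ref = [[sum(W[i][k] * H[j][k] for k in range(K)) for j in range(N)] for i in range(M)]
print(O == ref)              # True
print(loads, 2 * M * N * K)  # 32768 524288
```

With 16×16 tiles, the loads drop by a factor of 16 compared with one program per element, while the result is unchanged.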

v2 coming soon.

This post is licensed under CC BY 4.0 by the author.