
From 30ms to 2ms: My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)

A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means


Hey there! A few months ago, I thought I understood GPUs. I mean, I’d written some CUDA code, knew what threads and blocks were, and could throw around terms like “parallelism” at parties.

Then I tried to optimize a Multi-Head Latent Attention kernel, and everything I thought I knew… shattered.

This is the story of how a naive 313ms implementation became a 21ms kernel—not through magic, but through slowly, painfully learning what GPUs actually care about. Spoiler: it’s not about doing more math faster. It’s about not making the GPU wait for memory.

If you’ve ever stared at a “memory-bound” profiler report and thought “but… my GPU has thousands of cores!”, grab some coffee. This one’s for you.


The Beginning: When Your Kernel is Slower Than… Everything

Setting: My RTX 3050 laptop (yeah, we’re starting humble—a $300 GPU from 2022)
Task: Implement Multi-Head Latent Attention (MLA) from the DeepSeek-V3 paper
First attempt: 313.55 milliseconds per forward pass
PyTorch baseline: 10.41 milliseconds

Wait, what?

My “optimized” GPU kernel was 30x slower than just using PyTorch’s regular torch.matmul.

I remember sitting there, confused. The GPU has 2,560 CUDA cores. Why was it crawling?

The Moment That Changed Everything

I fired up NVIDIA’s Nsight Compute profiler (after googling “how to profile CUDA kernels” for the third time), and saw this:

Memory Throughput: 95% (Saturated) 🔴
Compute Throughput: <5% (Starved) 🔴
DRAM Excessive Reads: >60% 🔴

My brain: “What does that even mean?”

The profiler: “Your GPU is literally just sitting there waiting for memory. Those 2,560 cores? They’re idle 95% of the time, bored, twiddling their metaphorical thumbs.”

That was my first real lesson: A fast GPU with slow memory access is like hiring a Formula 1 driver to deliver pizza in rush hour traffic.


Detour: What Even IS Multi-Head Latent Attention?

Before we dive deeper, let me explain what I was trying to build (because honestly, understanding the problem helped me fix it later).

Traditional attention in transformers is memory-hungry. For every token, you store massive Key and Value matrices—like keeping a photo album where every picture is 4096 pixels wide.

MLA’s clever trick: compress first, decompress later.

The Standard (Memory-Exploding) Way:

  1. Down-project hidden states to compressed representations: h → c_KV (2048 dims → 512 dims)
  2. Up-project back to full size: c_KV → K, V (512 dims → 4096 dims) ← THIS KILLS MEMORY
  3. Apply RoPE (rotary position embeddings)
  4. Run attention

The problem? Step 2 writes GIGABYTES of intermediate matrices to global memory.

Think of it like this:

  • CPU approach: “I’ll unpack this compressed ZIP file, save it to disk, then read it back”
  • What we SHOULD do: “Just read from the ZIP directly when I need something”

That insight—that we could compute attention without materializing the decompressed tensors—was breakthrough #1.
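
One way to see why the decompressed tensor never needs to exist, in a toy PyTorch sketch (single projection, no heads, no RoPE; sizes are this post's toy dims, not DeepSeek-V3's real ones):

import torch

T, D, d_c, d_k = 128, 2048, 512, 4096        # tokens, hidden, compressed, decompressed
h      = torch.randn(T, D)
W_down = torch.randn(D, d_c)                 # down-projection
W_up   = torch.randn(d_c, d_k)               # up-projection (the memory-exploding one)
q      = torch.randn(T, d_k)

# The materializing way: decompress K into a big (T, 4096) tensor first
c_KV  = h @ W_down                           # (T, 512)  -- small, cheap to keep around
K_big = c_KV @ W_up                          # (T, 4096) -- the intermediate we want to avoid
scores_materialized = q @ K_big.T

# "Read from the ZIP directly": fold W_up into the query side instead,
# and score straight against the compressed representation
scores_compressed = (q @ W_up.T) @ c_KV.T

print(torch.allclose(scores_materialized, scores_compressed, rtol=1e-3, atol=1e-3))

Same attention scores, but the (T, 4096) intermediate never touches global memory.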


My “Naive” Implementation: The One Where Everything Goes Wrong

My first kernel was embarrassingly simple. One thread per token:

import triton
import triton.language as tl

@triton.jit
def naive_kernel(h_ptr, W_ptr, output_ptr, D: tl.constexpr, d_c: tl.constexpr):
    token_id = tl.program_id(0)  # "I'm thread #42!"

    # Load this token's hidden state (all D elements)
    offs_d = tl.arange(0, D)
    h = tl.load(h_ptr + token_id * D + offs_d)

    # Walk the ENTIRE D x d_c weight matrix, one column at a time (whoops)
    for col in range(d_c):
        w = tl.load(W_ptr + offs_d * d_c + col)              # one full column of W
        tl.store(output_ptr + token_id * d_c + col, tl.sum(h * w))

Grid size: 65,536 threads (one per token)
What each thread did: Read the entire 2048×512 weight matrix from global memory

Picture this: 65,536 workers all trying to read the same giant book simultaneously. No sharing. No coordination. Just… chaos at the memory controller.

The GPU’s memory bandwidth (192 GB/s on my RTX 3050) was like a single-lane bridge with 65,000 cars trying to cross at once.
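
A back-of-envelope makes the single-lane-bridge picture quantitative (assuming fp16 weights and, pessimistically, zero cache reuse):

tokens = 65_536
weight_bytes = 2048 * 512 * 2          # fp16 weight matrix ≈ 2.1 MB
bandwidth = 192e9                      # RTX 3050 laptop GPU, bytes per second

total_traffic = tokens * weight_bytes  # ≈ 137 GB if nothing is ever cached
print(total_traffic / bandwidth)       # ≈ 0.72 s -- the L2 cache is the only reason it "only" took 313 ms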

The Profiler Didn’t Lie

L2 Cache Hit Rate: 30%
Average Memory Latency: 450 cycles
Warp Stall Reasons: "Memory Dependency" (68%)

Translation: “Your threads are spending 450 clock cycles just… waiting. For. Memory.”

That hurt to see.


Breakthrough #1: The “Block Tiling” Epiphany

I was stuck. Then I remembered something from a random blog post about matrix multiplication:

“Don’t process one element at a time. Process blocks.”

It clicked while I was sketching on paper (yes, actual paper—sometimes the best debugging tool is a notebook).

Old way: Each thread reads weight matrix independently
New way: A block of threads shares one copy of the weight matrix

The Analogy That Made It Click

Imagine a classroom of 64 students (threads) taking an exam. Each question requires looking up a formula in a textbook.

Bad approach: Each student has their own textbook (65,536 books total)
Smart approach: The textbook is on the wall—all 64 students share it

In GPU terms:

  • “Their own textbook” = global memory (HBM, slow, far away)
  • “On the wall” = shared memory (SRAM, fast, on-chip)

The Code That Changed Everything

@triton.jit
def tiled_kernel(h_ptr, W_ptr, c_KV_ptr, D, d_c,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Process 64 tokens at once (not just 1!); boundary masks omitted for clarity
    block_id = tl.program_id(0)
    token_ids = block_id * BLOCK_M + tl.arange(0, BLOCK_M)  # BLOCK_M = 64

    # March across the output one BLOCK_N-wide column tile at a time
    for n in range(0, d_c, BLOCK_N):
        col_ids = n + tl.arange(0, BLOCK_N)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, D, BLOCK_K):
            offs_k = k + tl.arange(0, BLOCK_K)
            # Cooperative loads: ONE copy of each tile per block, staged through shared memory
            h_tile = tl.load(h_ptr + token_ids[:, None] * D + offs_k[None, :])
            w_tile = tl.load(W_ptr + offs_k[:, None] * d_c + col_ids[None, :])
            acc += tl.dot(h_tile, w_tile)  # all 64 token rows reuse the same weight tile
        tl.store(c_KV_ptr + token_ids[:, None] * d_c + col_ids[None, :], acc)

Before: 65,536 threads → 65,536 independent weight reads
After: 1,024 blocks → 1,024 weight reads (each shared by 64 threads)

Memory traffic reduction: 64x
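
For context, here's roughly how a kernel like that gets launched. The shapes match the post's toy dimensions, but the exact signature is an illustrative sketch rather than the repo's real API:

import torch
import triton

M, D, d_c = 65_536, 2048, 512                       # tokens, hidden dim, compressed dim
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64

h    = torch.randn(M, D, device="cuda", dtype=torch.float16)
W    = torch.randn(D, d_c, device="cuda", dtype=torch.float16)
c_KV = torch.empty(M, d_c, device="cuda", dtype=torch.float32)

grid = (triton.cdiv(M, BLOCK_M),)                   # 65,536 / 64 = 1,024 blocks
tiled_kernel[grid](h, W, c_KV, D, d_c,
                   BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)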

The Numbers That Made Me Smile

After implementing block tiling:

Metric            | Before    | After    | Change
Kernel Time       | 313.55 ms | 21.09 ms | 15x faster 🚀
L2 Cache Hit Rate | 30%       | 82%      | +52 pts
Memory Stalls     | 68%       | 23%      | -45 pts

For the first time, my kernel wasn’t completely embarrassing.


Breakthrough #2: The RoPE Fusion Puzzle

Just when I thought I was done, I hit the next wall: RoPE (Rotary Position Embeddings).

RoPE is this clever trick where you rotate key/query vectors based on their position in the sequence. The math looks like this:

# Rotate pairs of elements
k_rotated[0] = k[0] * cos(θ) - k[1] * sin(θ)
k_rotated[1] = k[0] * sin(θ) + k[1] * cos(θ)

My initial implementation:

  1. Kernel 1: Compute k = h @ W_KR → write to memory
  2. Kernel 2: Read k from memory → apply RoPE → write back

This is like going to the bank to withdraw cash, walking home to count it, then walking back to deposit it. Pure inefficiency.

The “Why Not Both?” Moment

I realized: what if I just… fused them?

New pipeline:

  1. Compute k = h @ W_KR in registers
  2. Apply RoPE rotation while still in registers
  3. Write final result to memory

Eliminated: One full round-trip to global memory
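
A plain PyTorch reference for what the fused kernel should produce looks roughly like this (the even/odd interleaving convention and the names are illustrative, not the repo's test code). Note that the reference has to materialize k in memory, which is exactly the round-trip the fused kernel skips:

import torch

def fused_kr_rope_reference(h, W_KR, cos, sin):
    """Down-project then rotate: what the fused kernel produces in a single pass."""
    k = h @ W_KR                                  # the intermediate the fused kernel keeps in registers
    k_even, k_odd = k[..., 0::2], k[..., 1::2]    # interleaved pairs; cos/sin broadcast over them
    out = torch.empty_like(k)
    out[..., 0::2] = k_even * cos - k_odd * sin
    out[..., 1::2] = k_even * sin + k_odd * cos
    return out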

But Triton Had Other Plans…

Here’s where things got weird. I tried the obvious approach:

k_even = k[:, 0::2]  # Get even indices (0, 2, 4, ...)
k_odd = k[:, 1::2]   # Get odd indices (1, 3, 5, ...)

# Apply RoPE rotation
rotated_even = k_even * cos - k_odd * sin
rotated_odd = k_even * sin + k_odd * cos

Triton’s response: CompileError: Cannot slice with step != 1

I spent an entire evening trying to convince the compiler to let me slice tensors. No luck.

The Pointer Arithmetic Hack

Then, at 2 AM (why do all breakthroughs happen at 2 AM?), I realized: don’t slice after loading. Load with the slice built-in.

Instead of fighting the compiler at the compute stage, I did slicing at the memory load stage using strided pointers:

# Load even rows: row 0, 2, 4, 6...
even_ptr = tl.make_block_ptr(
    base=W_KR_ptr,
    shape=(d_c // 2, D),
    strides=(row_stride * 2, col_stride),  # Double the row stride!
    offsets=(0, k * BLOCK_K),
    block_shape=(BLOCK_N, BLOCK_K),
    order=(1, 0),                          # row-major layout
)

# Load odd rows: row 1, 3, 5, 7...
odd_ptr = tl.make_block_ptr(
    base=W_KR_ptr + row_stride,  # Start one row later
    shape=(d_c // 2, D),
    strides=(row_stride * 2, col_stride),  # Still double stride
    offsets=(0, k * BLOCK_K),
    block_shape=(BLOCK_N, BLOCK_K),
    order=(1, 0),
)

w_even = tl.load(even_ptr)
w_odd = tl.load(odd_ptr)

The trick: By setting stride to 2 * normal_stride, the GPU’s memory controller automatically skips every other row. No compute-stage slicing needed!

This is like telling a librarian “get me every even-numbered book” instead of asking for all books and filtering yourself.
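
If the stride trick feels too clever to trust, a quick PyTorch check on a tiny made-up matrix shows that doubling the row stride (plus a one-row base offset for the odd half) is exactly the 0::2 / 1::2 slicing Triton wouldn't let me write:

import torch

W = torch.arange(8 * 4).reshape(8, 4)   # tiny stand-in for the real weight matrix

# What the doubled-stride block pointers read:
even_rows = torch.as_strided(W,     size=(4, 4), stride=(W.stride(0) * 2, W.stride(1)))
odd_rows  = torch.as_strided(W[1:], size=(4, 4), stride=(W.stride(0) * 2, W.stride(1)))

assert torch.equal(even_rows, W[0::2])
assert torch.equal(odd_rows,  W[1::2])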

The Result

Fused RoPE kernel: 22.30ms (vs 21.09ms without RoPE)
Separate kernels would’ve been: ~35ms

Saved ~13ms just by keeping data in registers!


The H100 Plot Twist: When “Fast Hardware” Isn’t Enough

After weeks of optimization on my RTX 3050, I got access to an H100 through Modal.

My expectation: “This has 3.35TB/s memory bandwidth (17x my laptop!). My kernel will SCREAM.”

Reality: It was slower than PyTorch. Again.

PyTorch (cuBLAS): 1.40 ms
My Kernel: 2.37 ms

Wait, WHAT? After all that work?

The Humbling Reality of Auto-Tuning

The problem: I’d hardcoded block sizes optimized for my RTX 3050 (Ampere architecture). The H100 is Hopper architecture with:

  • Bigger caches
  • More SMs (Streaming Multiprocessors)
  • Tensor Memory Accelerator (TMA)
  • Different optimal tile sizes

My kernel was like bringing a bicycle to a Formula 1 race—technically it works, but you’re not using the hardware’s potential.

Enter: The Autotuner

Triton has this magical thing called @triton.autotune that tries different configurations and picks the fastest:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4),
        # ... 12 more configs
    ],
    key=['M', 'D', 'd_c'],  # Autoselect based on problem size
)
@triton.jit
def flash_mla_kernel(...):
    # Same kernel code

The autotuner runs every configuration during the first call, times them all, and caches the winner.
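
One launch-side detail worth knowing: since the winning block size isn't known until the autotuner has run, the grid is passed as a callable that receives the chosen meta-parameters. Something like this (the argument list is illustrative):

# The grid callback runs after the autotuner has picked a config,
# so it can size the launch from the winning BLOCK_M.
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']),)
flash_mla_kernel[grid](h_flat, W_DKV, c_KV, M, D, d_c)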

Results after autotuning:

Kernel Time: 2.37ms → 1.54ms (35% faster!)
Peak Memory: 368 MB → 288 MB (21.7% savings)

Still not beating cuBLAS (1.40ms), but way closer. And we’re using less memory, which means:

  • Bigger batch sizes possible
  • Longer sequences fit
  • Better multi-tenancy in production

The Memory Win That Mattered

While chasing speed, I accidentally discovered something valuable:

80 MB saved = 28% larger batch size

In production LLM serving, this is HUGE. It’s the difference between serving 100 requests/second vs 128 requests/second on the same hardware.

Sometimes “almost as fast but more memory-efficient” is more valuable than “slightly faster.”


The Bugs That Almost Broke Me

Bug #1: The Invisible tl

@triton.jit
def my_kernel(...):
    x = tl.arange(0, 64)  # NameError: 'tl' is not defined

Wait, what? I imported it at the top!

The problem: Triton’s JIT compiler runs in a separate namespace. It couldn’t see my imports.

The solution: Force it into the global namespace before compilation:

import triton.language as tl

globals()['tl'] = tl  # Make it visible to JIT

@triton.jit
def my_kernel(...):
    x = tl.arange(0, 64)  # Now it works!

This took me 3 hours to debug. Three. Hours.

Bug #2: The Stride Catastrophe

RuntimeError: CUDA error: an illegal memory access was encountered

The kernel was crashing randomly on the H100 but worked fine on my laptop.

After adding bounds checks, print statements, and borderline crying, I found it:

# ❌ WRONG
h.stride(0)  # Returns stride of 3D tensor

# ✅ CORRECT
h_flat = h.view(B * T, D)  # Flatten first
h_flat.stride(0)  # Now returns correct 2D stride

I was passing 3D tensor strides to a kernel expecting 2D data. The H100 has stricter memory alignment, so it caught the bug my RTX 3050 was silently ignoring.

Lesson learned: ALWAYS flatten your tensors before passing to kernels. Always.
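
Concretely, the difference between the two strides is easy to see on a toy tensor (shapes made up):

import torch

B, T, D = 4, 1024, 2048
h = torch.randn(B, T, D)

h_flat = h.reshape(B * T, D)        # .reshape also handles non-contiguous views (it copies if needed)
print(h.stride(0))                  # 2,097,152 (T * D) -- what I was wrongly passing
print(h_flat.stride(0))             # 2,048     (D)     -- the row stride the 2D kernel expects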

Bug #3: Shared Memory Exhaustion

OutOfResourcesError: out of resource: shared memory

Some autotuner configs tried to allocate 280KB of shared memory per block. The H100 only has 227KB per SM.
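
A rough way to see why those configs blow the budget (this ignores padding and any extra buffers the compiler adds, so treat it as a lower bound): each software-pipelining stage has to buffer one input tile and one weight tile in shared memory.

def smem_lower_bound_kb(block_m, block_n, block_k, num_stages, bytes_per_elem=2):
    # Each pipeline stage buffers a (BLOCK_M x BLOCK_K) tile and a (BLOCK_K x BLOCK_N) tile
    return num_stages * (block_m * block_k + block_k * block_n) * bytes_per_elem / 1024

print(smem_lower_bound_kb(256, 256, 64, num_stages=5))  # 320.0 KB -> over the ~227 KB/SM limit
print(smem_lower_bound_kb(128, 128, 64, num_stages=4))  # 128.0 KB -> fits fine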

The fix: Let the autotuner fail gracefully:

@triton.autotune(
    configs=[...],
    key=['M', 'D', 'd_c'],
    reset_to_zero=['c_KV_ptr'],
    restore_value=['h_ptr'],
    warmup=25,
    rep=100
)

The autotuner automatically skips configs that fail to compile. Crisis averted.


The Numbers: Before and After

Development Journey (RTX 3050)

Stage          | Implementation | Time (ms) | Improvement
Start          | Naive PyTorch  | 10.41     | Baseline
Attempt 1      | Naive Triton   | 313.55    | 30x slower (ouch)
Optimization 1 | Block Tiling   | 21.09     | 14.8x over naive Triton 🎉
Optimization 2 | Fused RoPE     | 22.30     | (small regression)
Final Tuning   | Optimized RoPE | 21.09     | 0.5x vs cuBLAS
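
For reference, numbers like these come from repeated, synchronized kernel launches; a minimal harness in the spirit of benchmark_mla.py looks roughly like this (run_mla_kernel is a hypothetical wrapper around the Triton kernel, not the repo's actual function name):

import torch
from triton.testing import do_bench

M, D, d_c = 65_536, 2048, 512
h = torch.randn(M, D, device="cuda", dtype=torch.float16)
W = torch.randn(D, d_c, device="cuda", dtype=torch.float16)

# do_bench handles warmup, repetitions, and CUDA synchronization
torch_ms  = do_bench(lambda: h @ W)
triton_ms = do_bench(lambda: run_mla_kernel(h, W))  # hypothetical wrapper

print(f"PyTorch: {torch_ms:.2f} ms | Triton: {triton_ms:.2f} ms")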

Production (H100)

Metric       | PyTorch  | Triton (Optimized) | Improvement
Latency      | 1.397 ms | 1.540 ms           | 0.91x speed
Peak Memory  | 368 MB   | 288 MB             | -21.7% 🎯
Memory Saved | -        | 80 MB              | +28% batch capacity

What This Means in Production

With 80MB savings per forward pass:

  • Before: Max batch size 64
  • After: Max batch size 82 (+28%)
  • Throughput: 100 req/s → 128 req/s
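
The batch-size arithmetic behind those numbers, spelled out (assuming memory use scales roughly linearly with batch size, which is approximately true here):

pytorch_peak_mb, triton_peak_mb = 368, 288
saved_mb = pytorch_peak_mb - triton_peak_mb                      # 80 MB per forward pass

old_batch = 64
new_batch = round(old_batch * pytorch_peak_mb / triton_peak_mb)  # 64 * 368 / 288 ≈ 82
print(saved_mb, new_batch, new_batch / old_batch)                # 80, 82, ~1.28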

What I Learned (The Real Takeaways)

1. Memory is Everything

On modern GPUs, compute is cheap. Memory movement is expensive.

A kernel that does 2x more math but reads memory once will beat a kernel that does 1x math but reads memory three times.

The profiler is your best friend. If “Memory Throughput” is >80% and “Compute Throughput” is <40%, you have a memory problem, not a compute problem.

2. Block Tiling is Not Optional

For any real workload, you must use tiling to:

  • Amortize memory reads across multiple outputs
  • Maximize cache reuse
  • Keep the GPU busy while waiting for memory

The pattern is always:

Load data tile into shared memory (once)
→ Compute many results using that tile (many times)
→ Repeat

3. Fusion Beats Pipelining

Two separate kernel launches = two round-trips to global memory.

Fusing operations means keeping intermediate results in registers/shared memory.

Even if fusion makes the code more complex, the memory savings usually win.

4. Hardware Matters More Than You Think

Code optimized for Ampere (RTX 30-series) might run poorly on Hopper (H100).

Different architectures have:

  • Different cache sizes
  • Different optimal block sizes
  • Different memory hierarchies
  • Different instruction sets

Use auto-tuning. Let the hardware tell you what it wants.

5. “Good Enough” Can Be Better Than “Perfect”

My kernel is 10% slower than cuBLAS but uses 22% less memory.

For production LLM serving where memory is the bottleneck, this is a win.

Don’t obsess over matching cuBLAS—focus on the metric that actually matters for your use case.


The Tools That Saved My Sanity

Development

  • Google Colab Pro: Free T4 GPUs for experimentation
  • Modal: On-demand H100 access ($2.50/hour—way cheaper than AWS)
  • Triton Autotune: Automatically finds optimal configs

Debugging

  • NVIDIA Nsight Compute: Shows exactly why your kernel is slow
  • torch.cuda.memory_summary(): Tracks memory allocations
  • PyTorch Profiler: Identifies bottlenecks in end-to-end training
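
What "tracks memory allocations" means in practice is a pattern like this (model_forward and x are placeholders for whatever you're measuring):

import torch

torch.cuda.reset_peak_memory_stats()
out = model_forward(x)                        # placeholder for the forward pass under test
torch.cuda.synchronize()

print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
print(torch.cuda.memory_summary())            # full allocator breakdown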

Learning Resources

  • “Programming Massively Parallel Processors” (Kirk & Hwu)
  • Triton documentation + tutorials
  • Lei Mao’s blog on GPU optimization
  • Papers: Flash Attention, FlashAttention-2, DeepSeek-V3

What’s Next: The Unfinished Journey

This project isn’t “done”—it’s just… paused at a good checkpoint.

Near-Term Goals

  • FP8 precision: H100’s Tensor Cores can do 2x throughput with e4m3 format
  • Single-tile optimization: Read input tensors exactly once
  • Better RoPE fusion: Shave off that extra 2ms overhead

Moon Shots

  • Full MLA attention kernel: Fuse query projection + attention computation
  • Multi-GPU: Tensor/pipeline parallelism for really big models
  • vLLM integration: Make this production-ready for real serving
  • TMA (Tensor Memory Accelerator): Use H100’s async memory engine

The Real Goal

I want to close the gap from 0.91x to 1.0x+ of cuBLAS performance while keeping the memory advantage.

If I can beat NVIDIA’s hand-tuned kernels at their own game—even by 1%—that would be the ultimate validation of everything I learned.


Closing Thoughts: Why This Journey Mattered

Six months ago, I didn’t know what “memory-bound” meant. I thought GPUs were just “CPUs with more cores.”

Now I can:

  • Read profiler output and immediately spot bottlenecks
  • Design memory access patterns that minimize traffic
  • Debug pointer arithmetic in assembly-level IR
  • Reason about cache hierarchies and memory coalescing

But more importantly, I learned how to learn hard things:

  1. Start simple (even if it’s embarrassingly slow)
  2. Profile obsessively (data > intuition)
  3. Fix one thing at a time (no cowboy coding)
  4. Celebrate small wins (15x is huge, even if it’s not 30x)
  5. Share the journey (you’re reading this, so mission accomplished!)

For Anyone Starting This Journey

If you’re where I was six months ago—confused by terms like “occupancy” and “register pressure”—here’s my advice:

Start with the slowest possible kernel. Make it work. Then make it fast.

Don’t try to write the perfect kernel on day one. You’ll get overwhelmed and quit.

Instead:

  1. Write naive code that gives correct results (even if it’s 10x slower)
  2. Profile it (Nsight Compute is free!)
  3. Fix the biggest bottleneck (usually memory)
  4. Repeat until you hit diminishing returns
  5. Ship it

The best code is code that ships. The best optimization is the one you actually finish.


Want to Dive Deeper?

The full code (including all the bugs I hit and fixed) is on GitHub: [Your Repo Link]

Key files:

  • benchmark_mla.py - RTX 3050 experiments
  • fast.py - Final optimized kernel with H100 deployment
  • test_fast.py - Correctness tests and ablations
  • results/ - Nsight Compute profiles and pretty charts

If you have questions, war stories, or just want to geek out about GPU kernels:

  • Open an issue on the repo
  • Find me on LinkedIn
  • Tweet at me (I’m probably already talking about GPUs)

Acknowledgments

People who kept me sane:

  • The Triton team at OpenAI (especially the Discord community)
  • Modal team for H100 access and support
  • My laptop’s RTX 3050 (you did your best, buddy)
  • Coffee (so much coffee)

Resources that changed everything:

  • PMPP textbook (Kirk & Hwu) - Chapter 5 specifically
  • Lei Mao’s GPU programming blog
  • Flash Attention paper (Tri Dao et al.)
  • Random Stack Overflow answers at 3 AM

And you, for reading this whole thing. Seriously. Thanks for caring about the messy middle part of learning, not just the polished final results.


Final Stats (For the Resume)

What I built:

  • High-performance GPU kernel using OpenAI Triton
  • 15x speedup on consumer hardware (RTX 3050)
  • 21.7% memory reduction on production hardware (H100)
  • Full development lifecycle from prototype to production

What I learned:

  • Memory hierarchy optimization (HBM → L2 → Shared → Registers)
  • Block tiling and memory coalescing techniques
  • Kernel fusion strategies
  • Production debugging on real enterprise hardware
  • How to read profiler output and make data-driven optimizations

⭐ Star the repo if this helped you understand GPU optimization better!

💬 Issues/PRs welcome—I’m still learning too!


Made with 🔥 (and way too much debugging)
by someone who finally gets why memory bandwidth matters


P.S. - If you’re hiring GPU optimization engineers, my DMs are open 👀

This post is licensed under CC BY 4.0 by the author.