From 30ms to 2ms: My Wild Ride Optimizing GPU Kernels (And Why Memory is Actually Everything)
A journey through Triton, tiling, and the day I finally understood what 'memory-bound' really means
Hey there! A few months ago, I thought I understood GPUs. I mean, I’d written some CUDA code, knew what threads and blocks were, and could throw around terms like “parallelism” at parties.
Then I tried to optimize a Multi-Head Latent Attention kernel, and everything I thought I knew… shattered.
This is the story of how a naive 313ms implementation became a 21ms kernel—not through magic, but through slowly, painfully learning what GPUs actually care about. Spoiler: it’s not about doing more math faster. It’s about not making the GPU wait for memory.
If you’ve ever stared at a “memory-bound” profiler report and thought “but… my GPU has thousands of cores!”, grab some coffee. This one’s for you.
The Beginning: When Your Kernel is Slower Than… Everything
Setting: My RTX 3050 laptop (yeah, we’re starting humble—a $300 GPU from 2022)
Task: Implement Multi-Head Latent Attention (MLA) from the DeepSeek-V3 paper
First attempt: 313.55 milliseconds per forward pass
PyTorch baseline: 10.41 milliseconds
Wait, what?
My “optimized” GPU kernel was 30x slower than just using PyTorch’s regular torch.matmul.
I remember sitting there, confused. The GPU has 2,560 CUDA cores. Why was it crawling?
The Moment That Changed Everything
I fired up NVIDIA’s Nsight Compute profiler (after googling “how to profile CUDA kernels” for the third time), and saw this:
Memory Throughput: 95% (Saturated) 🔴
Compute Throughput: <5% (Starved) 🔴
DRAM Excessive Reads: >60% 🔴
My brain: “What does that even mean?”
The profiler: “Your GPU is literally just sitting there waiting for memory. Those 2,560 cores? They’re idle 95% of the time, bored, twiddling their metaphorical thumbs.”
That was my first real lesson: A fast GPU with slow memory access is like hiring a Formula 1 driver to deliver pizza in rush hour traffic.
Detour: What Even IS Multi-Head Latent Attention?
Before we dive deeper, let me explain what I was trying to build (because honestly, understanding the problem helped me fix it later).
Traditional attention in transformers is memory-hungry. For every token, you store massive Key and Value matrices—like keeping a photo album where every picture is 4096 pixels wide.
MLA’s clever trick: compress first, decompress later.
The Standard (Memory-Exploding) Way:
1. Down-project hidden states to compressed representations: h → c_KV (2048 dims → 512 dims)
2. Up-project back to full size: c_KV → K, V (512 dims → 4096 dims) ← THIS KILLS MEMORY
3. Apply RoPE (rotary position embeddings)
4. Run attention
The problem? Step 2 writes GIGABYTES of intermediate matrices to global memory.
Think of it like this:
- CPU approach: “I’ll unpack this compressed ZIP file, save it to disk, then read it back”
- What we SHOULD do: “Just read from the ZIP directly when I need something”
That insight—that we could compute attention without materializing the decompressed tensors—was breakthrough #1.
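To make the memory explosion concrete, here's a minimal eager-PyTorch sketch of the standard path using the shapes from this post (the tensor and weight names are illustrative, not DeepSeek's actual parameter names):

import torch

# 65,536 tokens, 2048-dim hidden states, 512-dim latent, 4096-dim decompressed K/V
n_tok, D, d_c, d_kv = 65_536, 2048, 512, 4096
h     = torch.randn(n_tok, D,    device="cuda", dtype=torch.float16)
W_DKV = torch.randn(D,     d_c,  device="cuda", dtype=torch.float16)  # down-projection
W_UK  = torch.randn(d_c,   d_kv, device="cuda", dtype=torch.float16)  # up-projection (K)
W_UV  = torch.randn(d_c,   d_kv, device="cuda", dtype=torch.float16)  # up-projection (V)

c_KV = h @ W_DKV    # compressed latent: 65,536 x 512  (~64 MB in fp16)
K    = c_KV @ W_UK  # materialized K:    65,536 x 4096 (~512 MB in fp16) ← step 2, the killer
V    = c_KV @ W_UV  # materialized V:    65,536 x 4096 (~512 MB in fp16) ← ditto
# ...and then RoPE and attention run against those huge K/V tensors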
My “Naive” Implementation: The One Where Everything Goes Wrong
My first kernel was embarrassingly simple, with one Triton program per token:
@triton.jit
def naive_kernel(h_ptr, W_ptr, output_ptr,
                 D: tl.constexpr, output_dim: tl.constexpr):
    token_id = tl.program_id(0)  # "I'm program #42!"
    # Load this token's hidden state (row-major layout assumed)
    offs_d = tl.arange(0, D)
    h = tl.load(h_ptr + token_id * D + offs_d)
    # Re-read the ENTIRE D x output_dim weight matrix, column by column (whoops)
    for col in range(output_dim):
        w = tl.load(W_ptr + offs_d * output_dim + col)            # one weight column
        out = tl.sum(h * w, axis=0)                               # dot product by hand
        tl.store(output_ptr + token_id * output_dim + col, out)
Grid size: 65,536 programs (one per token)
What each program did: read the entire 2048×512 weight matrix from global memory
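The launch was just as simple. Here's a sketch against the simplified kernel above (the flattened h, W, and out tensors are assumptions):

# One program per token: 65,536 programs, zero cooperation between them
num_tokens = h.shape[0]   # 65,536 flattened tokens
grid = (num_tokens,)
naive_kernel[grid](h, W, out, D=2048, output_dim=512)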
Picture this: 65,536 workers all trying to read the same giant book simultaneously. No sharing. No coordination. Just… chaos at the memory controller.
The GPU’s memory bandwidth (192 GB/s on my RTX 3050) was like a single-lane bridge with 65,000 cars trying to cross at once.
The Profiler Didn’t Lie
L2 Cache Hit Rate: 30%
Average Memory Latency: 450 cycles
Warp Stall Reasons: "Memory Dependency" (68%)
Translation: “Your threads are spending 450 clock cycles just… waiting. For. Memory.”
That hurt to see.
Breakthrough #1: The “Block Tiling” Epiphany
I was stuck. Then I remembered something from a random blog post about matrix multiplication:
“Don’t process one element at a time. Process blocks.”
It clicked while I was sketching on paper (yes, actual paper—sometimes the best debugging tool is a notebook).
Old way: Each thread reads weight matrix independently
New way: A block of threads shares one copy of the weight matrix
The Analogy That Made It Click
Imagine a classroom of 64 students (threads) taking an exam. Each question requires looking up a formula in a textbook.
Bad approach: Each student has their own textbook (65,536 books total)
Smart approach: The textbook is on the wall—all 64 students share it
In GPU terms:
- “Their own textbook” = global memory (HBM, slow, far away)
- “On the wall” = shared memory (SRAM, fast, on-chip)
The Code That Changed Everything
@triton.jit
def tiled_kernel(h_ptr, W_ptr, c_KV_ptr, M,
                 D: tl.constexpr, d_c: tl.constexpr,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Each program handles BLOCK_M = 64 tokens at once (not just 1!)
    block_id = tl.program_id(0)
    offs_m = block_id * BLOCK_M + tl.arange(0, BLOCK_M)   # the 64 tokens this block owns
    # Sweep the output columns in BLOCK_N-wide tiles
    for n in range(0, d_c, BLOCK_N):
        offs_n = n + tl.arange(0, BLOCK_N)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Walk the shared D dimension in BLOCK_K-sized tiles
        for k in range(0, D, BLOCK_K):
            offs_k = k + tl.arange(0, BLOCK_K)
            # Cooperative load: ONE weight tile, staged through shared memory...
            w_tile = tl.load(W_ptr + offs_k[:, None] * d_c + offs_n[None, :])
            h_tile = tl.load(h_ptr + offs_m[:, None] * D + offs_k[None, :],
                             mask=offs_m[:, None] < M, other=0.0)
            # ...and all 64 tokens in this block reuse it
            acc += tl.dot(h_tile, w_tile)
        tl.store(c_KV_ptr + offs_m[:, None] * d_c + offs_n[None, :],
                 acc, mask=offs_m[:, None] < M)
Before: 65,536 programs → 65,536 independent reads of the weight matrix
After: 1,024 blocks → 1,024 weight reads (each shared by the 64 tokens in that block)
Memory traffic reduction: 64x
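A quick back-of-the-envelope check (assuming fp16 weights and, pessimistically, zero caching) shows why this one change dominated everything else:

# Worst-case DRAM traffic for the 2048 x 512 fp16 weight matrix
weight_bytes  = 2048 * 512 * 2             # ~2 MiB per full read of W
naive_traffic = 65_536 * weight_bytes      # every program re-reads W: ~128 GiB
tiled_traffic = 1_024 * weight_bytes       # one read per 64-token block: ~2 GiB
print(naive_traffic // tiled_traffic)      # 64
# At ~192 GB/s, 128 GiB would take ~0.7 s uncached; the 30% L2 hit rate is the only
# reason the naive kernel "only" took 313 ms.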
The Numbers That Made Me Smile
After implementing block tiling:
| Metric | Before | After | Change |
|---|---|---|---|
| Kernel Time | 313.55 ms | 21.09 ms | 15x faster 🚀 |
| L2 Cache Hit Rate | 30% | 82% | +52 pts |
| Memory Stalls | 68% | 23% | -45 pts |
For the first time, my kernel wasn’t completely embarrassing.
Breakthrough #2: The RoPE Fusion Puzzle
Just when I thought I was done, I hit the next wall: RoPE (Rotary Position Embeddings).
RoPE is this clever trick where you rotate key/query vectors based on their position in the sequence. The math looks like this:
# Rotate pairs of elements
k_rotated[0] = k[0] * cos(θ) - k[1] * sin(θ)
k_rotated[1] = k[0] * sin(θ) + k[1] * cos(θ)
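In eager PyTorch this rotation is easy, because you can slice with a step of 2 (hold that thought, it matters in a minute). A minimal reference implementation, with theta standing in for the per-position, per-pair angles:

import torch

def rope_reference(k: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # k: (..., d) with d even; theta: (..., d // 2); interleaved even/odd convention
    k_even, k_odd = k[..., 0::2], k[..., 1::2]     # strided slicing: trivial in PyTorch
    cos, sin = torch.cos(theta), torch.sin(theta)
    out = torch.empty_like(k)
    out[..., 0::2] = k_even * cos - k_odd * sin
    out[..., 1::2] = k_even * sin + k_odd * cos
    return out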
My initial implementation:
- Kernel 1: Compute k = h @ W_KR → write to memory
- Kernel 2: Read k from memory → apply RoPE → write back
This is like going to the bank to withdraw cash, walking home to count it, then walking back to deposit it. Pure inefficiency.
The “Why Not Both?” Moment
I realized: what if I just… fused them?
New pipeline:
- Compute k = h @ W_KR in registers
- Apply RoPE rotation while still in registers
- Write final result to memory
Eliminated: One full round-trip to global memory
But Triton Had Other Plans…
Here’s where things got weird. I tried the obvious approach:
k_even = k[:, 0::2] # Get even indices (0, 2, 4, ...)
k_odd = k[:, 1::2] # Get odd indices (1, 3, 5, ...)
# Apply RoPE rotation
rotated_even = k_even * cos - k_odd * sin
rotated_odd = k_even * sin + k_odd * cos
Triton’s response: CompileError: Cannot slice with step != 1
I spent an entire evening trying to convince the compiler to let me slice tensors. No luck.
The Pointer Arithmetic Hack
Then, at 2 AM (why do all breakthroughs happen at 2 AM?), I realized: don’t slice after loading. Load with the slice built-in.
Instead of fighting the compiler at the compute stage, I did slicing at the memory load stage using strided pointers:
# Load even rows: row 0, 2, 4, 6...
even_ptr = tl.make_block_ptr(
    base=W_KR_ptr,
    shape=(d_c // 2, D),
    strides=(row_stride * 2, col_stride),  # Double the row stride!
    offsets=(0, k * BLOCK_K),
    block_shape=(BLOCK_N, BLOCK_K),
    order=(1, 0),                          # row-major: columns vary fastest
)
# Load odd rows: row 1, 3, 5, 7...
odd_ptr = tl.make_block_ptr(
    base=W_KR_ptr + row_stride,            # Start one row later
    shape=(d_c // 2, D),
    strides=(row_stride * 2, col_stride),  # Still double stride
    offsets=(0, k * BLOCK_K),
    block_shape=(BLOCK_N, BLOCK_K),
    order=(1, 0),
)
w_even = tl.load(even_ptr)
w_odd = tl.load(odd_ptr)
The trick: By setting stride to 2 * normal_stride, the GPU’s memory controller automatically skips every other row. No compute-stage slicing needed!
This is like telling a librarian “get me every even-numbered book” instead of asking for all books and filtering yourself.
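For completeness, here's roughly what happens after those loads, still inside the same K-loop (a sketch, not the exact kernel: h_tile, cos, sin, and the two output block pointers are assumed to be set up elsewhere):

# Inside the K-loop: accumulate the even/odd halves of k = h @ W_KR
acc_even += tl.dot(h_tile, tl.trans(w_even))   # (BLOCK_M, BLOCK_N) partial "even" k
acc_odd  += tl.dot(h_tile, tl.trans(w_odd))    # (BLOCK_M, BLOCK_N) partial "odd" k

# After the K-loop: apply RoPE while everything is still in registers...
k_rot_even = acc_even * cos - acc_odd * sin
k_rot_odd  = acc_even * sin + acc_odd * cos

# ...and only now touch global memory, once, for the final result
tl.store(out_even_ptr, k_rot_even)
tl.store(out_odd_ptr,  k_rot_odd)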
The Result
Fused RoPE kernel: 22.30ms (vs 21.09ms without RoPE)
Separate kernels would’ve been: ~35ms
Saved ~13ms just by keeping data in registers!
The H100 Plot Twist: When “Fast Hardware” Isn’t Enough
After weeks of optimization on my RTX 3050, I got access to an H100 through Modal.
My expectation: “This has 3.35TB/s memory bandwidth (17x my laptop!). My kernel will SCREAM.”
Reality: It was slower than PyTorch. Again.
PyTorch (cuBLAS): 1.40 ms
My Kernel: 2.37 ms
Wait, WHAT? After all that work?
The Humbling Reality of Auto-Tuning
The problem: I’d hardcoded block sizes optimized for my RTX 3050 (Ampere architecture). The H100 is Hopper architecture with:
- Bigger caches
- More SMs (Streaming Multiprocessors)
- Tensor Memory Accelerator (TMA)
- Different optimal tile sizes
My kernel was like bringing a bicycle to a Formula 1 race—technically it works, but you’re not using the hardware’s potential.
Enter: The Autotuner
Triton has this magical thing called @triton.autotune that tries different configurations and picks the fastest:
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4),
# ... 12 more configs
],
key=['M', 'D', 'd_c'], # Autoselect based on problem size
)
@triton.jit
def flash_mla_kernel(...):
# Same kernel code
The first time the kernel runs with a new problem size (the key), the autotuner benchmarks every configuration, then caches the winner for all later calls with that same key.
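Nothing changes at the call site; the sweep happens lazily. A rough usage sketch (the grid lambda and argument order are placeholders, since the real signature is elided above):

# The grid callable receives the winning config's meta-parameters, so it adapts
# to whatever BLOCK_M / BLOCK_N the autotuner picks
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(d_c, meta['BLOCK_N']))

flash_mla_kernel[grid](h, W, c_KV, M, D, d_c)  # first call for this (M, D, d_c): benchmarks everything
flash_mla_kernel[grid](h, W, c_KV, M, D, d_c)  # same key: reuses the cached winner
print(flash_mla_kernel.best_config)            # peek at which config won (recent Triton versions)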
Results after autotuning:
Kernel Time: 2.37ms → 1.54ms (35% faster!)
Peak Memory: 368 MB → 288 MB (21.7% savings)
Still not beating cuBLAS (1.40ms), but way closer. And we’re using less memory, which means:
- Bigger batch sizes possible
- Longer sequences fit
- Better multi-tenancy in production
The Memory Win That Mattered
While chasing speed, I accidentally discovered something valuable:
80 MB saved = 28% larger batch size
In production LLM serving, this is HUGE. It’s the difference between serving 100 requests/second vs 128 requests/second on the same hardware.
Sometimes “almost as fast but more memory-efficient” is more valuable than “slightly faster.”
The Bugs That Almost Broke Me
Bug #1: The Invisible tl
@triton.jit
def my_kernel(...):
x = tl.arange(0, 64) # NameError: 'tl' is not defined
Wait, what? I imported it at the top!
The problem: Triton’s JIT compiler runs in a separate namespace. It couldn’t see my imports.
The solution: Force it into the global namespace before compilation:
import triton.language as tl
globals()['tl'] = tl # Make it visible to JIT
@triton.jit
def my_kernel(...):
x = tl.arange(0, 64) # Now it works!
This took me 3 hours to debug. Three. Hours.
Bug #2: The Stride Catastrophe
RuntimeError: CUDA error: an illegal memory access was encountered
The kernel was crashing randomly on the H100 but worked fine on my laptop.
After adding bounds checks, print statements, and borderline crying, I found it:
# ❌ WRONG
h.stride(0) # Returns stride of 3D tensor
# ✅ CORRECT
h_flat = h.view(B * T, D) # Flatten first
h_flat.stride(0) # Now returns correct 2D stride
I was passing 3D tensor strides to a kernel expecting 2D data. The H100 has stricter memory alignment, so it caught the bug my RTX 3050 was silently ignoring.
Lesson learned: ALWAYS flatten your tensors before passing to kernels. Always.
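What that looks like at the launch site, as a sketch (my_kernel, grid, and the argument order here are hypothetical):

# Flatten (B, T, D) -> (B*T, D) before handing strides to the kernel
h_flat = h.contiguous().view(-1, h.shape[-1])
my_kernel[grid](h_flat, out,
                h_flat.stride(0), h_flat.stride(1),  # genuine 2D strides
                M=h_flat.shape[0], D=h_flat.shape[1])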
Bug #3: Shared Memory Exhaustion
OutOfResourcesError: out of resource: shared memory
Some autotuner configs tried to allocate 280KB of shared memory per block. The H100 only has 227KB per SM.
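Here's a rough estimate of how the biggest config blows past that limit, assuming fp16 tiles and that both input tiles are multi-buffered across num_stages (the exact accounting depends on the Triton version, so treat this as ballpark only):

BLOCK_M, BLOCK_N, BLOCK_K, num_stages = 256, 256, 64, 5          # largest config above
bytes_per_stage = (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N) * 2    # h-tile + W-tile, fp16
print(bytes_per_stage * num_stages / 1024)                       # ~320 KB vs ~227 KB available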
The fix: Let the autotuner fail gracefully:
@triton.autotune(
configs=[...],
key=['M', 'D', 'd_c'],
reset_to_zero=['c_KV_ptr'],
restore_value=['h_ptr'],
warmup=25,
rep=100
)
The autotuner automatically skips configs that fail to compile. Crisis averted.
The Numbers: Before and After
Development Journey (RTX 3050)
| Stage | Implementation | Time (ms) | Improvement |
|---|---|---|---|
| Start | PyTorch baseline (cuBLAS) | 10.41 | Baseline |
| Attempt 1 | Naive Triton | 313.55 | 30x slower than baseline (ouch) |
| Optimization 1 | Block Tiling | 21.09 | ~15x faster than naive Triton 🎉 |
| Optimization 2 | Fused RoPE | 22.30 | small regression vs. block tiling |
| Final Tuning | Optimized RoPE | 21.09 | ~0.5x of cuBLAS ✅ |
Production (H100)
| Metric | PyTorch | Triton (Optimized) | Improvement |
|---|---|---|---|
| Latency | 1.397 ms | 1.540 ms | 0.91x speed |
| Peak Memory | 368 MB | 288 MB | -21.7% 🎯 |
| Memory Saved | - | 80 MB | +28% batch capacity |
What This Means in Production
With 80MB savings per forward pass:
- Before: Max batch size 64
- After: Max batch size 82 (+28%)
- Throughput: 100 req/s → 128 req/s
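The +28% is just the memory ratio, assuming peak memory scales roughly linearly with batch size (which held in these benchmarks):

peak_before, peak_after = 368, 288   # MB per forward pass (H100 numbers above)
scale = peak_before / peak_after     # ~1.28
print(round(64 * scale))             # max batch 64 -> 82, i.e. +28%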
What I Learned (The Real Takeaways)
1. Memory is Everything
On modern GPUs, compute is cheap. Memory movement is expensive.
A kernel that does 2x more math but reads memory once will beat a kernel that does 1x math but reads memory three times.
The profiler is your best friend. If “Memory Throughput” is >80% and “Compute Throughput” is <40%, you have a memory problem, not a compute problem.
2. Block Tiling is Not Optional
For any real workload, you must use tiling to:
- Amortize memory reads across multiple outputs
- Maximize cache reuse
- Keep the GPU busy while waiting for memory
The pattern is always:
Load data tile into shared memory (once)
→ Compute many results using that tile (many times)
→ Repeat
3. Fusion Beats Pipelining
Two separate kernel launches = two round-trips to global memory.
Fusing operations means keeping intermediate results in registers/shared memory.
Even if fusion makes the code more complex, the memory savings usually win.
4. Hardware Matters More Than You Think
Code optimized for Ampere (RTX 30-series) might run poorly on Hopper (H100).
Different architectures have:
- Different cache sizes
- Different optimal block sizes
- Different memory hierarchies
- Different instruction sets
Use auto-tuning. Let the hardware tell you what it wants.
5. “Good Enough” Can Be Better Than “Perfect”
My kernel is 10% slower than cuBLAS but uses 22% less memory.
For production LLM serving where memory is the bottleneck, this is a win.
Don’t obsess over matching cuBLAS—focus on the metric that actually matters for your use case.
The Tools That Saved My Sanity
Development
- Google Colab Pro: Free T4 GPUs for experimentation
- Modal: On-demand H100 access ($2.50/hour—way cheaper than AWS)
- Triton Autotune: Automatically finds optimal configs
Debugging
- NVIDIA Nsight Compute: Shows exactly why your kernel is slow
- torch.cuda.memory_summary(): Tracks memory allocations
- PyTorch Profiler: Identifies bottlenecks in end-to-end training
Learning Resources
- “Programming Massively Parallel Processors” (Kirk & Hwu)
- Triton documentation + tutorials
- Lei Mao’s blog on GPU optimization
- Papers: Flash Attention, FlashAttention-2, DeepSeek-V3
What’s Next: The Unfinished Journey
This project isn’t “done”—it’s just… paused at a good checkpoint.
Near-Term Goals
- FP8 precision: H100's Tensor Cores can do 2x throughput with the e4m3 format
- Single-tile optimization: Read input tensors exactly once
- Better RoPE fusion: Shave off that extra 2ms overhead
Moon Shots
- Full MLA attention kernel: Fuse query projection + attention computation
- Multi-GPU: Tensor/pipeline parallelism for really big models
- vLLM integration: Make this production-ready for real serving
- TMA (Tensor Memory Accelerator): Use H100’s async memory engine
The Real Goal
I want to close the gap from 0.91x to 1.0x+ of cuBLAS performance while keeping the memory advantage.
If I can beat NVIDIA’s hand-tuned kernels at their own game—even by 1%—that would be the ultimate validation of everything I learned.
Closing Thoughts: Why This Journey Mattered
Six months ago, I didn’t know what “memory-bound” meant. I thought GPUs were just “CPUs with more cores.”
Now I can:
- Read profiler output and immediately spot bottlenecks
- Design memory access patterns that minimize traffic
- Debug pointer arithmetic in assembly-level IR
- Reason about cache hierarchies and memory coalescing
But more importantly, I learned how to learn hard things:
- Start simple (even if it’s embarrassingly slow)
- Profile obsessively (data > intuition)
- Fix one thing at a time (no cowboy coding)
- Celebrate small wins (15x is huge, even if it’s not 30x)
- Share the journey (you’re reading this, so mission accomplished!)
For Anyone Starting This Journey
If you’re where I was six months ago—confused by terms like “occupancy” and “register pressure”—here’s my advice:
Start with the slowest possible kernel. Make it work. Then make it fast.
Don’t try to write the perfect kernel on day one. You’ll get overwhelmed and quit.
Instead:
- Write naive code that gives correct results (even if it’s 10x slower)
- Profile it (Nsight Compute is free!)
- Fix the biggest bottleneck (usually memory)
- Repeat until you hit diminishing returns
- Ship it
The best code is code that ships. The best optimization is the one you actually finish.
Want to Dive Deeper?
The full code (including all the bugs I hit and fixed) is on GitHub: [Your Repo Link]
Key files:
- benchmark_mla.py - RTX 3050 experiments
- fast.py - Final optimized kernel with H100 deployment
- test_fast.py - Correctness tests and ablations
- results/ - Nsight Compute profiles and pretty charts
If you have questions, war stories, or just want to geek out about GPU kernels:
- Open an issue on the repo
- Find me on LinkedIn
- Tweet at me (I’m probably already talking about GPUs)
Acknowledgments
People who kept me sane:
- The Triton team at OpenAI (especially the Discord community)
- Modal team for H100 access and support
- My laptop’s RTX 3050 (you did your best, buddy)
- Coffee (so much coffee)
Resources that changed everything:
- PMPP textbook (Kirk & Hwu) - Chapter 5 specifically
- Lei Mao’s GPU programming blog
- Flash Attention paper (Tri Dao et al.)
- Random Stack Overflow answers at 3 AM
And you, for reading this whole thing. Seriously. Thanks for caring about the messy middle part of learning, not just the polished final results.
Final Stats (For the Resume)
What I built:
- High-performance GPU kernel using OpenAI Triton
- 15x speedup on consumer hardware (RTX 3050)
- 21.7% memory reduction on production hardware (H100)
- Full development lifecycle from prototype to production
What I learned:
- Memory hierarchy optimization (HBM → L2 → Shared → Registers)
- Block tiling and memory coalescing techniques
- Kernel fusion strategies
- Production debugging on real enterprise hardware
- How to read profiler output and make data-driven optimizations
⭐ Star the repo if this helped you understand GPU optimization better!
💬 Issues/PRs welcome—I’m still learning too!
Made with 🔥 (and way too much debugging)
by someone who finally gets why memory bandwidth matters
P.S. - If you’re hiring GPU optimization engineers, my DMs are open 👀