Teaching AI to Write GPU Code: A Deep Dive into Reinforcement Fine-Tuning

February 14, 2025 · 7 min read
Torch-Triton-Hero
Arnav Garg
Arnav Garg
Travis Addair
Travis Addair
Will Van Eaton
Will Van Eaton

1. Introduction & Motivation

Imagine you have a PyTorch function that works fine for small-scale experiments on a CPU, but it bogs down when running larger tasks on a GPU. Unlocking the full potential of a GPU typically requires rewriting your code into specialized CUDA or Triton kernels—a task that demands deep GPU programming expertise.

While frameworks like Triton make the process more accessible than writing raw CUDA C++, creating optimized, architecture-specific kernels still poses a steep learning curve. This gap between higher-level Torch code and low-level GPU optimization is exactly where reinforcement fine-tuning (RFT) can step in.

Webinar Fine-Tuning DeepSeek 1

Can we teach a LLM how to convert PyTorch code into functional, accurate and optimized Triton/CUDA kernels?

In this post, we’ll discuss how we taught an AI model to convert PyTorch code into efficient Triton kernels using a reinforcement learning algorithm inspired by PPO called Group Relative Preference Optimization (GRPO). We’ll walk through our dataset, the iterative training loop, the reward functions we designed—and how the model gradually learned to produce correct and efficient GPU code.

To learn more about this use case, watch our on-demand webinar “Fine-Tuning DeepSeek: Unlocking the Power of Reinforcement Learning

If you have a use case that could benefit from RFT, join the waitlist for early access!

2. Background on Triton & Reinforcement Learning

Triton in a Nutshell

Triton is a Python-based language and compiler that translates high-level Python code into GPU machine instructions. Think of it as a middle ground between PyTorch and CUDA: you get more control over thread blocks and memory access patterns without fully diving into low-level GPU C++ code.

Why Use Reinforcement Learning (RL) For This Task?

This task is particularly well-suited for using reinforcement learning because:

  1. No Large Labeled Dataset Needed: We started with just a handful of PyTorch code snippetsexamples.
  2. Code is Verifiable: We can deterministically test whether the generated Triton kernels code compiles and produces correct outputs.
  3. Large Search Space: The search space for valid solutions is large and RL balances exploring new Kernel implementations with exploiting proven approaches it learned during model pretraining.
Webinar Fine-Tuning DeepSeek 2

We lack labeled datasets for this task, but have a way to verify any generated output.

Reinforcement learning differs from traditional supervised learning in that it doesn’t rely on large labeled  datasets. Instead, our language model explores possible solutions and receives reward signals for good or bad outputs. Over time, it updates its strategy to maximize rewards.

Group Relative Preference Optimization takes advantage of one or more reward functions that can evaluate a generated code snippet against certain criteria (format correctness, compilation success, runtime behavior).

3. Setting Up the Task

Minimal Dataset

We began with a tiny hand curated dataset of 13 examples, each containing:

  • A PyTorch function (e.g., a matrix multiply or simple activation function).
  • A set of test cases to verify correctness.

Because there isn’t a widely available dataset mapping PyTorch code to Triton kernels, we curated these from open-source GitHub projects by finding valid triton kernels, writing the equivalent PyTorch code ourselves, and adding our own test cases to execute against both the PyTorch code and the Triton kernels.

The actual training data only consists of the 13 PyTorch code examples.

System & User Messages

We constructed appropriate system and user prompts instructing the model to:

  • Create Triton kernels for the given PyTorch function
  • Use specific tags (e.g., <triton_code> ... </triton_code>) so we could extract the kernel cleanly.
  • Import the Triton library.
  • Maintain a consistent function signature and avoid calling PyTorch functions that sidestep the need for real GPU code.
Webinar Fine-Tuning DeepSeek (7) 1

4. Defining the Rewards

Designing a robust reward function is the heart of any RL system. We needed our model to learn formatting, compilation, correctness, and performance (eventually). Here’s how we approached it:

4.1. Formatting Reward

Webinar Fine-Tuning DeepSeek 3

Goal: Encourage a consistent code structure that’s easy to parse and run.

Implementation:

  • String Checks: Did the output include code within <triton_code> tags?
  • Triton Imports: Does the generated code import triton and use @triton.jit?
  • Partial Credit For Good Triton Semantics: We also assigned fractional scores for getting certain parts right, such as only using valid triton language methods, using zeroed out torch output buffers, using masks during load/stores, etc. .
  • Sum partial credits from all the criteria to assign a final score between 0 and 1.

4.2. Compilation Reward

Webinar Fine-Tuning DeepSeek 4

Goal: Reward the model for generating code that compiles and executes without throwing exceptions (no syntax errors or invalid calls) even if it gets the final answer wrong

Implementation:  If the code could be executed in a separate Python process without crashing, the model receives a positive reward of 1

4.3. Correctness Reward

Webinar Fine-Tuning DeepSeek (4) 1

Goal: Ensure the kernel’s output matches that of the PyTorch function on multiple test inputs.

Implementation:

  • Multiple Test Cases: We started with two test inputs, then expanded to four for finer-grained feedback.
  • Reward Scaling: Calculate (total number of test cases that passed ) / (total number of test cases), which is a value between 0 and 1. 
  • Anti RewardHacking: We monkey-patch the kernel call to detect if the model just returned a PyTorch operation result (instead of letting the Triton kernel do the work). If the outputs matched when the kernel was replaced with a no-op, the reward was set to 0.

5. Training Loop & Iterations

5.1. How GRPO Works

  1. Generate: For each prompt (PyTorch code snippet), generate N completions using temperature-based sampling.
  2. Evaluate: For each generated completion, run our reward checks (format, compilation, correctness) and assign a reward per reward function.
  3. Update: Compute advantages (which completions outperformed the average and which ones did worse than the average) and backpropagate those signals into the model’s parameters to update them. We use LoRA as our training method. .
  4. Repeat: Over thousands of steps, the model refines its strategy to maximize rewards. Typically, this it first learns the format based rewards before learning how to maximize other rewards
GRPO

5.2. Early Challenges & Refinements

Reward Hacking: The model initially learned shortcuts, like returning the result of torch.sum() instead of truly computing a sum via a Triton kernel. This was a form of clever reward hacking.

  • Fix: We penalized completions that succeeded on test cases where the output from the initial kernel execution matched a second run where we replaced the generated kernel with a no-op kernel.
reward-hacking

Sparse Rewards: Binary pass/fail gave the model little direction when it was “almost right.”

  • Fix: Introduced partial credit, particularly in the format reward function and via the introduction of a compilation based reward, so generating code that compiles but fails correctness still earns a small reward.

Limited Test Cases: With only two test cases, the model wasn’t getting enough signal if it was making directional progress towards a correct kernel.

  • Fix: We doubled to using four test cases, giving a more nuanced reward signal if the kernel leads to even getting a subset of these calls correct

6. Results & Discussion

Over roughly 5,000 training steps, our model’s accuracy on held-out examples climbed to about 40% (meaning 40% of the time, the generated Triton kernel fully matched the PyTorch outputs on all test cases). Although 40% may sound modest, it’s a remarkable jump from near-0% at the start:

Webinar Fine-Tuning DeepSeek 6

Learning curves from initial training run.

  • Faster to Format Compliance: Within ~100–200 steps, the model reliably produced code that included the correct tags and imports. Within 1000 steps, it learned to get most parts of the formatting reward function correct.
  • Gradual Correctness Gains: True correctness took longer to learn, rising steadily once the model nailed down syntax and compilation.
  • Learned to Avoid Hacking: Through our monkey-patching based detection algorithm, the model eventually recognized no reward could be gained by cheating, so it focused on improving its Triton kernel implementation.
Webinar Fine-Tuning DeepSeek (9) 1

Learning curves from final training run with partial credit, anti-reward hacking measures and compile reward function.

Example Triumphs

  1. Parallelization: In some outputs, the model set up block sizes and memory loads correctly, matching best practices for GPU performance.
  2. Buffer Initialization: It accurately used torch.zeros or torch.full for output buffers, ensuring consistent final results.
buffer

Persistent Gaps

  • Edge Cases: Some kernels failed for unusual input shapes or large sizes.
  • Performance: Our reward function didn’t yet incentivize minimal runtime, so the code was correct but not always speed-optimized.

7. Conclusion & Future Outlook

This exploration showcases how reinforcement fine-tuning can tackle complex code-generation tasks without a massive labeled dataset. By iterating on reward functions, we steadily pushed our model to generate correct Triton kernels that compile, run, and match PyTorch outputs using just 13 examples of PyTorch code.

Where We Go Next

  1. Performance Optimization: Introducing a runtime-based reward could encourage faster kernels, pushing the model to learn better memory layouts and parallelization strategies.
  2. Larger Test Suites: More test cases = more granular feedback, helping it learn better. 
  3. Generalizing to Other Tasks: This method extends beyond PyTorch ↔ Triton. We can adapt it to tasks like Java → Python transpilation, SQL optimization, or any code with verifiable correctness.

Try It Yourself

If you’re interested in experimenting with reinforcement fine-tuning for code generation or other tasks, explore our new RFT features at pbase.ai/rft. With built-in tools for custom reward functions, partial-credit scoring, and code execution, you can design your own RL agent to tackle the next big code challenge.

FAQ

What is Reinforcement Learning (RL)?

Reinforcement Learning (RL) is a type of machine learning where an agent learns to make decisions by interacting with an environment. It receives rewards for good decisions and penalties for bad ones, and it uses this to optimize its actions over time to maximize rewards in the long run.

How is Reinforcement Fine-Tuning (RFT) different from traditional supervised fine-tuning?

Reinforcement Fine-Tuning (RFT) differs from traditional supervised fine-tuning in that it doesn’t rely on a  labeled datasets. Instead, it uses a reward-based system to iteratively improve model performance based on predefined success criteria (defined in reward functions) and the outputs produced by the model after each step of training.

Why is Reinforcement Learning useful for generating GPU code?

Writing optimized GPU code requires expert knowledge of memory management, thread execution, and parallelization. RL enables AI to explore different code implementations, receiving rewards based on correctness, compilation success, and efficiency.

What is Triton, and why is it important for GPU optimization?

Triton is a Python-based compiler that allows users to write high-performance GPU kernels without delving into CUDA C++. It bridges the gap between high-level deep learning frameworks like PyTorch and low-level GPU programming.

How does RL optimize Triton kernel generation?

In this approach, an RL-trained model generates Triton kernels from PyTorch code. It receives rewards based on:

  • Formatting correctness (proper syntax, imports, and structure)
  • Compilation success (ensuring the code runs without errors)
  • Output accuracy (matching PyTorch function results on a few example input-output test cases)
  • Potential future enhancements like runtime efficiency

What is Group Relative Preference Optimization (GRPO)?

GRPO is an RL-based optimization method similar to Proximal Policy Optimization (PPO). Unlike PPO, which requires training a separate reward model or value network to estimate returns, GRPO eliminates this step by using programmable, verifiable reward functions and evaluating actions based on their relative performance within a group. This is achieved by generating multiple outputs for each input, scoring them using a set of reward functions, and computing advantages (how much better one generated output is to another for the same input) by comparing each action's reward to the group's average. This gradually helps the model learn how to maximize scores on the reward functions and learn the task. 

How much training data is required for Reinforcement Fine-Tuning?

Unlike supervised learning, RFT can start with a small dataset. In the Triton kernel optimization example, only 13 labeled examples were initially used, with RL iteratively improving performance over thousands of training steps. Our experiments on other tasks show that RL performance can continue to improve with a large diversity of inputs, so you can keep increasing the dataset size for even better performance since it will help the model learn a more general strategy.

How can I try Reinforcement Fine-Tuning?

You can join the waitlist for early access to RFT on Predibase at pbase.ai/rft!

Related Articles

Join Our Community!