Riding the wave of open-source LLaMA's success, Meta has unveiled LLaMA-2, a remarkable leap in the world of open-source large language models. Offering 7B, 13B, and 70B parameter variants, LLaMA-2 stands shoulder-to-shoulder with the likes of ChatGPT in terms of performance. It boasts not only improved quality and double the context length but also a free commercial-use license, which makes it a great base LLM for your tasks. However, it typically won't work well out of the box for your specific ML task, since it was trained on general text data from the web during the pretraining stage.
The fine-tuning opportunity and challenge
If you want to use it for your use case, you will likely want to fine-tune the model on your task-specific data. Our experiments have shown that we can get significantly better performance on tasks such as JSON generation by fine-tuning LLaMA-2 variants compared to using GPT3.5 and GPT4.
Fine-tuned LLaMA-2 outperforming GPT4 for JSON Generation — 250 times smaller, 40% more accurate in our experiments!!
Fine-tuning can be very effective, but comes with its own set of challenges. Most engineers and data scientists attempting to fine-tune models quickly realize that productionizing open-source models is harder than it seems.
Three primary challenges that we see teams struggle with when fine-tuning LLMs:
- Complex Tooling: staying abreast of the latest fine-tuning techniques from research (e.g. LoRA, quantization, etc.), learning how to implement them, and then stitching together various open-source tools and frameworks is a cumbersome task for any individual.
- Unreliable Fine-Tuning: high-end GPUs (like A100s) are in short supply. As a result, most teams are left renting commodity GPUs that they struggle to use effectively. This results in frequent out-of-memory errors, among other issues, that bring training to a halt. Equally important is setting proper defaults to avoid overfitting and NaN loss.
- Costly Model Serving: building a production system that meets SLAs and scales efficiently for a growing set of fine-tuned, task-specific models requires deep expertise in building serving infrastructure, which is frankly out of reach for most engineering teams.
Predibase makes fine-tuning Llama-2 (and any open-source model) easy
We have architected Predibase to solve the infra challenges plaguing engineering teams and make fine-tuning as easy and efficient as possible in the following ways:
- Abstracting away the code for fine-tuning using Ludwig, an open-source project that lets you define model training through configurations. This makes it easy to iterate on your fine-tuning journey: try different parameter-efficient strategies, experiment with different variants of the prompt you use for fine-tuning, and combine prompting and fine-tuning in the same workflow.
- Making it easy to iterate on prompt templates: The choice of prompt can have a significant effect on how well the model can be trained. For example, when fine-tuning LLMs that were previously instruction-tuned like LLaMA-2-chat, training is more effective when using the same prompt template as the LLM was trained with. In Predibase, the prompt template, data string, and task are treated as separate concepts, so it's trivial to iterate on each independently without having to rewrite your preprocessing code or manually maintain many different versions of your data.
- Automagically right-sizing compute resources to scale LLM fine-tuning jobs to your tasks, in terms of both dataset and model size, so that your fine-tuning task is always successful. This takes away the hassle of setting up your infrastructure for training, figuring out how to perform distributed model training at scale, and debugging memory-pressure issues on both CPU and GPU.
- Distributed training out of the box: Predibase generates the appropriate DeepSpeed configuration for model sharding across GPUs, enables half-precision training, offloads parameters and optimizer states appropriately to maximize throughput and keep CPU and GPU utilization high, and a lot more. All of this is done behind the scenes so you don't have to worry about configuring these parameters, making your training process reliable and scalable while keeping costs low.
- Dynamically serve 100s of fine-tuned models for the price of one: Predibase's serving infrastructure automatically scales up and down to meet the demands of any production environment. On top of that, we implement a novel serving architecture called LoRA Exchange (LoRAX), which allows your team to serve many fine-tuned LLMs together for an over 100x cost reduction versus dedicated deployments. Fine-tuned models can be loaded and queried in seconds.
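To make the LoRAX idea more concrete, here is a toy, self-contained sketch of the serving pattern it relies on: one shared base model stays resident in GPU memory while many small task-specific adapters are loaded on demand. This is an illustration only, not Predibase's implementation, and all of the class and function names below are made up.
# Toy sketch of the LoRAX serving pattern (illustrative only; not Predibase's code).
from dataclasses import dataclass, field

@dataclass
class ToyBaseModel:
    name: str  # stands in for billions of shared base-model parameters on the GPU

@dataclass
class ToyAdapter:
    task: str  # stands in for a few MB of LoRA weights for one fine-tuned task

@dataclass
class ToyLoraxServer:
    base: ToyBaseModel
    adapters: dict = field(default_factory=dict)  # adapter_id -> ToyAdapter

    def prompt(self, adapter_id: str, text: str) -> str:
        # Load the adapter on first use; later requests for the same task reuse it.
        adapter = self.adapters.setdefault(adapter_id, ToyAdapter(task=adapter_id))
        return f"[{self.base.name} + {adapter.task}] response to: {text}"

server = ToyLoraxServer(base=ToyBaseModel(name="llama-2-7b"))
print(server.prompt("code-generation", "Reverse the words in a string."))
print(server.prompt("json-extraction", "Extract the line items from this invoice."))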
Four Easy Steps to Fine-Tune Llama-2 in Predibase
Predibase makes it very easy to fine-tune your large language model and customize it to your task. It just requires 4 steps:
- Download the Predibase SDK and authenticate it using your credentials
- Connect your dataset in 2 lines of code
- Configure your fine-tuning job in 4 lines of code
- Kick off your fine-tuning job in 1 line of code
For this short tutorial, we will fine-tune LLaMA-2-7b on a small subset of the Code Alpaca dataset using QLoRA for parameter-efficient fine-tuning.
The goal is to use Llama-2-7b for code generation. The model will take natural language as input and should return code as output. We'll first see how the base Llama-2-7b does with just prompting, and then fine-tune the model.
As an example, if we prompt the model with this instruction:
Create an array of length 5 which contains all even numbers between 1 and 10.
We want the model to produce exactly this response:
array = [2, 4, 6, 8, 10]
If you'd like to follow along with the rest of this blog post, check out this Colab notebook! If you haven't already signed up for the Predibase free trial, you can start here; you'll need an account as you work through the notebook.
1. Downloading the Predibase SDK and Logging In
To get started, you'll need to download the Predibase SDK and grab your Predibase API token. We recommend using the Predibase SDK either through Google Colab or through a Jupyter Notebook. You can install the SDK inline by running the following command:
!pip install -U predibase
Next, in a new cell, run the following command to login to Predibase:
!pbase login
This will prompt you for your Predibase API token. To grab your token, navigate to the Settings page and click Generate API Token.
Generating an API token in Predibase takes just one click
Once you’re authenticated, you should see a confirmation like this.
Example API token confirmation message
2. Connect Your Dataset
You can either use the full version of the dataset or download a small, 800-row subset using the following command:
!wget https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/datasets/code_alpaca_800.csv
For this blog post, we'll use this small subset.
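Before connecting the dataset, it can be worth taking a quick look at it locally. The snippet below is optional and assumes you have pandas installed; the columns it prints (instruction, input, and output) are the ones we'll reference throughout the rest of this post.
# Optional sanity check of the downloaded CSV (assumes pandas is installed).
import pandas as pd

df = pd.read_csv("code_alpaca_800.csv")
print(df.shape)              # expect roughly 800 rows
print(df.columns.tolist())   # instruction, input, output
print(df.head(3))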
Next, you need to initialize the Predibase Client using the SDK, which is the main entry point to interact with the entire Predibase SDK ecosystem.
from predibase import PredibaseClient
pc = PredibaseClient()
Finally, you can upload the dataset to Predibase from your notebook.
dataset = pc.upload_dataset("code_alpaca_800.csv", name="code_alpaca_800_demo")
3.1 Prompt Base Models
Before we fine-tune, we should see how Llama-2-7b (base and chat variants) does on a simple example of the task we want our LLM to do.
When you sign up for Predibase, we have a variety of shared multi-tenant LLM deployments that you can query out of the box, including Llama-2 (7b, 13b, and 70b base and chat variants), Mistral-7b, and Zephyr-7b. We're always updating these based on models that show great out-of-the-box performance and become popular. It's easy to query these deployed models using the SDK. You can also optionally deploy any causal LM from Hugging Face using Predibase.
To see what deployments are active at any time, you can run:
pc.list_llm_deployments()
Predibase supports all variants of Llama-2 right out of the box, from 7B to 70B, with the option to spin up dedicated deployments using any open-source Hugging Face model.
For our use case, let's query Llama-2-7b base and Llama-2-7b-chat. This can be done using just 2 lines of code: line 1 grabs the base deployment that we want to query, and line 2 prompts it with a custom input.
llm_deployment = pc.LLM("pb://deployments/llama-2-7b")
result = llm_deployment.prompt("""
Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.
### Instruction: Write an algorithm in Java to reverse the words in a string.
### Input: The quick brown fox
### Response:
""", max_new_tokens=256)
print(result.response)
Here’s the output response from the base Llama-2-7b model:
### Instruction: Write an algorithm in Java to reverse the words in a string.
### Input: The quick brown fox
### Response:
### Instruction: Write an algorithm in Java to reverse the words in a string.
### Input: The quick brown fox
### Response:
### Instruction: Write an algorithm in Java to reverse the words in a string.
### Input: The quick brown fox
### Response:
### Instruction: Write an algorithm in Java to reverse the words in a string.
In this case, we can see that it just returns the input over and over again until the token limit is reached. It also doesn’t know when to stop producing a response. This is pretty common when the model doesn’t know how to respond.
If we now prompt the instruction-tuned Llama-2-7b-chat model (using "pb://deployments/llama-2-7b-chat" as the base deployment), we get the following response:
Of course! Here is an algorithm in Java to reverse the words in a string:
public static void reverseWords(String str) {
    // Create a new string to hold the reversed words
    String reversedStr = "";
    // Split the input string into individual words
    String[] words = str.split(" ");
    // Reverse the order of the words
    for (int i = words.length - 1; i >= 0; i--) {
        reversedStr += words[i];
    }
    // Return the reversed string
    return reversedStr;
}
Using this algorithm, if we call the `reverseWords()` method with the input string "The quick brown fox", it will return the reversed string "fox brown quick".
In this case, the chat variant of the model does produce valid code, which is already a significant improvement over the base model. However, it tends to be “chatty” and adds a bunch of additional explanation text before and after to provide more context, which is very much in line with what you’d expect a model tuned for chat to respond with.
What we want is for the model to produce only correct, working code.
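Since the goal is for the model to return exactly the expected code and nothing else, a very simple way to score a response during experimentation is an exact-match check after trimming whitespace. This helper is purely illustrative and not part of the Predibase SDK:
# Illustrative helper (not part of the Predibase SDK): does the generated response
# match the expected code exactly, ignoring surrounding whitespace?
def is_exact_code_match(generated: str, expected: str) -> bool:
    return generated.strip() == expected.strip()

# Using the array example from earlier in this post:
expected = "array = [2, 4, 6, 8, 10]"
print(is_exact_code_match("array = [2, 4, 6, 8, 10]\n", expected))                         # True
print(is_exact_code_match("Sure! Here is the code:\narray = [2, 4, 6, 8, 10]", expected))  # False: chatty preamble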
3.2 Configure Your Fine-Tuning Job
To start fine-tuning with the Predibase SDK, the first thing you need to do is pick the LLM you’d like to fine-tune. We’ll be fine-tuning Llama-2-7b-chat in this blog, so we can initialize the base LLM model for fine-tuning using a single line of code.
# Specify the Huggingface LLM you want to fine-tune
llm = pc.LLM("hf://meta-llama/Llama-2-7b-chat-hf")
Once you pick the LLM you want to fine-tune, you can optionally decide what kind of fine-tuning you’d like to do. Predibase offers 3 fine-tuning options out of the box:
- 4-bit QLoRA fine-tuning (default)
- 8-bit QLoRA fine-tuning
- bf16 LoRA fine-tuning
To see all the template options, you can run the following commands:
tmpls = llm.get_finetune_templates()
tmpls.compare()
which will produce the following output:
Predibase provides best practice fine-tuning templates out of the box
Each template variant trades off fine-tuned model performance with the cost of training, so pick the one that makes the most sense for your use case and training budget.
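As a rough rule of thumb (back-of-envelope numbers, not Predibase's exact accounting), the base weights of a 7B-parameter model occupy about 14 GB in bf16, about 7 GB in 8-bit, and about 3.5 GB in 4-bit, which is a big part of why the 4-bit QLoRA template fits on smaller commodity GPUs:
# Back-of-envelope weight memory for a 7B-parameter model under each template.
# These numbers ignore activations, optimizer state, and the LoRA parameters
# themselves, so treat them as rough lower bounds rather than exact requirements.
params = 7e9
for name, bytes_per_param in [("bf16 LoRA", 2), ("8-bit QLoRA", 1), ("4-bit QLoRA", 0.5)]:
    gb = params * bytes_per_param / 1e9
    print(f"{name}: ~{gb:.1f} GB for base weights")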
If you do want to perform 8-bit or bf16-based LoRA fine-tuning, you can pick the appropriate template using the template name.
# Select a template of your choice and fine-tune it!
my_tmpl = tmpls["lora_bf16"]
Next, we're going to define a prompt template. In our case, the Code Alpaca dataset has 3 columns: instruction, input, and output.
Sample of the Code Alpaca dataset structure
The instruction and optional input columns will be used as inputs, and the output column is what we're going to teach the model to predict. These columns can be inserted into the prompt template using {} as we've done below. This prompt template will be applied to every row in your dataset during fine-tuning.
In general, prompt templates are useful when fine-tuning since they help provide context that guides the model during the fine-tuning process, and they help boost fine-tuned model performance when your fine-tuning dataset is very small. See this paper here if you’re interested in learning more about why prompts are useful when fine-tuning. Here’s the prompt template we’re going to use for our task:
# Define the template used to prompt the model for each example
prompt_template = """Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.
### Instruction: {instruction}
### Input: {input}
### Response:
"""
The last piece is to let Predibase know the name of our output feature, which in our case is just called output. With all of these pieces set up, we just have to call the llm.finetune() method and pass these objects in.
We can put all these pieces together in a single code block using the default 4-bit QLoRA template:
# Define the template used to prompt the model for each example
prompt_template = """Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.
### Instruction: {instruction}
### Input: {input}
### Response:
"""
# Specify the Huggingface LLM you want to fine-tune
llm = pc.LLM("hf://meta-llama/Llama-2-7b-chat-hf")
# Configure your fine-tuning job
job = llm.finetune(
prompt_template=prompt_template,
target="output",
dataset=dataset,
# repo="optional-custom-model-repository-name"
)
# Wait for the job to finish and get training updates and metrics
model = job.get()
That’s it! You just have to run this cell to start fine-tuning on Predibase!
Beneath the surface, Predibase does a variety of optimizations that are abstracted away:
1. Right-Sizing Compute: Based on your dataset and model size, Predibase intelligently decides what compute is suitable for your task, and always prioritizes the GPU configuration options that have higher availability and are cheaper for you to use. This scales from a single T4 GPU for Llama-2-7b fine-tuning using QLoRA to multiple A100 GPUs for Llama-2-70B LoRA-based fine-tuning without quantization.
2. Fine-Tuning Optimizations: Predibase applies a variety of fine-tuning optimizations to make your job run efficiently on smaller commodity GPUs for moderate-sized LLMs and also when fine-tuning Llama-2-70B. Some of these include: 4- and 8-bit quantization; LoRA for parameter-efficient fine-tuning (known as QLoRA when combined with quantization); gradient/activation checkpointing; gradient accumulation (see the sketch after this list); paged optimizers; micro and macro batch size tuning; and sequence length tuning.
3. Built-in fault tolerance for training: Predibase also has a variety of fault-tolerance mechanisms built into the system during training, including but not limited to:
- Handling CPU and GPU out-of-memory errors with smart retrying logic to make the job succeed
- Implementing error handling and retry mechanisms for communication failures, network issues, or other transient errors that may occur during distributed training
- Saving intermediate checkpoints to continue training from points of failure
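As a concrete illustration of one of the optimizations named above, gradient accumulation simulates a large effective batch size on a small GPU by summing gradients over several micro-batches before each optimizer step. The sketch below is a generic PyTorch example with a tiny stand-in model, not Predibase's internal code:
# Generic gradient accumulation loop (illustrative; a stand-in linear model, not an LLM).
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 8  # 8 micro-batches contribute to each optimizer step

for step in range(64):
    x = torch.randn(2, 16)   # tiny micro-batch that fits in memory
    y = torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accumulation_steps).backward()  # scale so accumulated gradients average correctly
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()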
4. Run the SDK command
When you run model = job.get(), you begin to see the following logs in the SDK:
Example of Predibase SDK logs
Predibase automatically creates a repository to track your fine-tuning jobs and experiments, shows you what compute is allocated for your training job, logs training progress based on the number of steps in the epoch, and finally logs metrics after each epoch of fine-tuning on each of the training, validation, and test sets.
In this case, you can see that the loss decreased epoch over epoch on both the training and test sets, and the BLEU score also increased on both sets. These are strong indicators of good learning during the fine-tuning process.
You also get a link to the model’s training progress in Predibase’s App/UI if you are interested in tracking progress that way:
Predibase provides model training dashboards to track progress
Serving Fine-Tuned Models Using LoRAX
Once your model is done fine-tuning, it’s quick and easy to test its performance.
At Predibase, we’ve built a new type of LLM serving infrastructure optimized for productionizing many fine-tuned models together with a shared set of GPU resources – an approach we call LoRA Exchange (LoRAX). Unlike conventional methods for serving large language models, LoRAX allows users to pack upwards of 100 fine-tuned task-specific models into a single GPU, dramatically reducing the cost of serving fine-tuned models.
To use LoRAX, you just need a reference to your fine-tuned model and the base deployment it was fine-tuned from. Then you can dynamically query your model using a simple call to deployment.prompt(). Predibase also supports a REST API that you can use to query your fine-tuned models.
# Since our model was fine-tuned from a Llama-2-7b base, we'll use the shared deployment with the same model type.
base_deployment = pc.LLM("pb://deployments/llama-2-7b-chat")
# Now we just specify the adapter to use, which is the model we fine-tuned.
adapter_deployment = base_deployment.with_adapter(model)
# Recall that our model was fine-tuned using a template that accepts an {instruction}
# and an {input}. This template is automatically applied when prompted.
result = adapter_deployment.prompt(
{
"instruction": "Write an algorithm in Java to reverse the words in a string.",
"input": "The quick brown fox"
},
max_new_tokens=256)
print(result.response)
Running inference on our fine-tuned model using the same input produces the following output:
import java.util.Scanner;

public class ReverseWords {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String input = scanner.nextLine();
        String reversed = new String();
        for (int i = input.length() - 1; i >= 0; i--) {
            reversed += input.charAt(i);
        }
        System.out.println(reversed);
    }
}
The generated ReverseWords class in Java from our fine-tuned model reads a line of text from the standard input (usually the keyboard) using the Scanner class, reverses the order of the characters in the input string, and then prints the reversed string to the standard output (usually the console).
This is an example of fine-tuning performance you can expect to see even with just 800 rows of data on the smallest variant of Llama-2. You can expect to see even better performance when fine-tuning on larger datasets using larger Llama variants like Llama-2-13B and Llama-2-70b, both of which are supported by Predibase.
Start Customizing Llama-2 with Predibase's Free Trial
As you can see, Predibase makes it very easy to go from fine-tuning datasets to serving fine-tuned models using just a few commands in the Predibase SDK!
Interested in trying this out on your own? Sign up for $25 in free credits with a free trial of Predibase (no credit card required!) and you can fine-tune and serve any open-source LLM!
Acknowledgements
Thanks to Abhay Malik and Michael Ortega for their contributions to this blog post.