We recently hosted an episode of our ML Real Talk Series on the topic of adapter-based training and why techniques like LoRA (Low-Rank Adaptation) are a game-changer for fine-tuning. Adapters offer a streamlined, cost-effective way to fine-tune open-source LLMs and achieve strong performance at a fraction of the cost of other methods. In this blog, we cover the key takeaways from the conversation to help you get started using adapters for fine-tuning LLMs in your day-to-day work.
Brief Intro to Fine-tuning and Benefits of Adapter-based Training
It’s important to first understand the fundamentals, so let’s start with key definitions:
- Fine-tuning: Fine-tuning plays a vital role in machine learning, particularly when working with large pre-trained open-source models like Mistral, Zephyr, and Gemma. Fine-tuning updates the model's parameters by feeding it new task-specific data: a loss is calculated between the model's output and the expected output, and through backpropagation the weights are adjusted so the model performs better on similar inputs in the future. The goal is to enhance task-specific performance without compromising the model's overall capabilities.
- Adapters: Adapters are a small set of model weights produced during fine-tuning. A couple of sections below, we'll discuss how you can efficiently train highly performant adapters.
Fine-tuning: Showing your LLM new data and altering its weights
When to fine-tune and how it differs from RAG
- RAG (Retrieval Augmented Generation) is another method for improving an LLM's responses: instead of modifying the model, it injects new data into the prompt at request time. The new data is chunked, the most relevant pieces are retrieved, and the LLM uses them to answer the query (see the sketch after this list). For instance, if you provide documentation on pandas and ask how to read a data frame from a CSV file in pandas, the retrieval step will pull the relevant passages from the documentation to help the LLM generate a response. Essentially, RAG leans on reading comprehension rather than relying solely on the model's internal weights to answer questions.
- Fine-tuning is particularly effective at teaching an LLM domain expertise (such as classifying legal documents) or adapting it to a specific tone or style of communication. While RAG is well-suited for querying databases or documents, fine-tuning allows LLMs to generate new content in a specific domain, such as writing code in a private coding language or summarizing in a particular domain and voice. In addition, fine-tuning can teach LLMs to respond in a specific output format, such as producing JSON without a boilerplate preamble like "Sure, happy to do that."
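To make the contrast concrete, here is a toy Python sketch of the RAG prompt-injection step. The retrieve function is a naive word-overlap stand-in for a real vector or keyword search and is purely illustrative:

# Toy sketch: retrieved context is injected into the prompt at request time;
# the model's weights are never modified.
def retrieve(query: str, docs: list[str], k: int = 2) -> list[str]:
    # naive "retrieval": rank documents by word overlap with the query
    q = set(query.lower().split())
    return sorted(docs, key=lambda d: -len(q & set(d.lower().split())))[:k]

docs = [
    "pandas.read_csv(path) reads a CSV file into a DataFrame.",
    "DataFrame.to_parquet(path) writes a DataFrame to Parquet.",
]
query = "How do I read a data frame from a CSV file in pandas?"
context = "\n".join(retrieve(query, docs))
prompt = f"Use the documentation below to answer.\n\n{context}\n\nQuestion: {query}"
print(prompt)  # this augmented prompt is what gets sent to the LLM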
Legacy fine-tuning challenges
As explained above, full fine-tuning entails adjusting all of the model's weights. This can be slow because it requires computing gradients for every weight to minimize the loss between the expected output and the model's generated output. This comprehensive modification also involves billions of floating-point computations and constant data movement within the GPU memory hierarchy.
Fine-tuning is also memory-intensive because gradients and optimizer states must be stored alongside the weights, multiplying the memory requirements. For instance, when using the Adam optimizer, a common rule of thumb is to budget roughly three times as much GPU RAM as the model itself occupies during the backward pass, which accounts for the model parameters plus the optimizer state (two additional floats per parameter). Consequently, if your model fits on a single NVIDIA A10G GPU, training it may require four such GPUs.
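As a back-of-the-envelope illustration (the precision choices below, 16-bit weights and gradients with 32-bit Adam moment buffers, are our own assumptions, and activation memory is ignored), here is a quick Python estimate:

# Rough full fine-tuning memory estimate: weights + gradients + Adam's two moments.
def full_finetune_memory_gb(num_params: float) -> float:
    weights = num_params * 2            # fp16 parameters
    gradients = num_params * 2          # one fp16 gradient per parameter
    adam_moments = num_params * 4 * 2   # two fp32 moments per parameter
    return (weights + gradients + adam_moments) / 1e9

print(f"{full_finetune_memory_gb(7e9):.0f} GB")  # ~84 GB for a 7B model,
# i.e. roughly four 24 GB A10G GPUs, in line with the example above.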
Low-Rank Adaptation (LoRA) to the rescue: Fine-tuning made fast and efficient
That's where adapters come in. Adapters let the LLM adapt to new scenarios without changing its original parameters, which preserves the LLM's general knowledge and avoids catastrophic forgetting when learning new tasks. Adapters also shrink the number of parameters that need to be updated, cutting the time and compute required: training overhead can be reduced by up to 70% compared to full fine-tuning.
A popular and efficient technique for training adapters is Low-Rank Adaptation (LoRA). LoRA operates on the premise that the model is already close to the optimal solution, so only a small correction to its weights is needed. It freezes the original weights and trains only the update (the delta) to them. To keep that update compact, LoRA factorizes it into two low-rank matrices, A and B, drastically lowering the number of trainable parameters (sketched below). The result is an adapter that performs on par with full fine-tuning while training substantially fewer weights and using far less storage.
Adapter-based fine-tuning: Freezing the LLM’s original weights and surgically inserting new ones using the A-B decomposition of the weight update matrix.
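To make the A-B decomposition concrete, here is a minimal PyTorch-style sketch of a LoRA linear layer. It is an illustrative simplification rather than a reference implementation, and the rank and scaling values are arbitrary choices:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Wraps a frozen nn.Linear and adds a trainable low-rank update B @ A.
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze the original weights
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # original path plus the low-rank adapter correction
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scaling

layer = LoRALinear(nn.Linear(1024, 1024), rank=8)
y = layer(torch.randn(4, 1024))  # same output shape as the base layer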
5 reasons why adapters are the future of fine-tuning
Now that we have a baseline understanding of fine-tuning, adapters, and LoRA, let’s dig into the reasons why adapters are the future of fine-tuning.
1. No trade-off in performance
In a study published in 2021, LoRA was evaluated against full fine-tuning of RoBERTa and DeBERTa across various benchmarks. Overall, performance was comparable: full fine-tuning held a slight edge on tasks such as SST-2 (sentiment analysis), while LoRA came out ahead on tasks like MNLI, and on average, for DeBERTa.
From LoRA paper: LoRA fine-tuned model performance on par with full fine-tuned model performance.
2. Extremely low memory footprint: 0.2%
Looking at the total number of trainable parameters for each model below, full fine-tuning of RoBERTa-large updates more than 350 million parameters, and DeBERTa-XXL more than 1.5 billion. By contrast, LoRA trains only about 0.2% of those parameters for RoBERTa and 0.3% for DeBERTa, a massive reduction in memory footprint.
How is this reduction achieved? The secret lies in decomposing the adapter into the A and B matrices with reduced rank, and the rank can be cut drastically: as low as 8, 4, 2, or even 1. For instance, if the original weight matrix has 1024 x 1024 = 1,048,576 parameters, a rank-8 adapter needs only 2 x 1024 x 8 = 16,384 parameters, about 1.6% of the original count, and a rank-1 adapter just 0.2% (see the short calculation after the figure below).
Achieving up to a 99.8% reduction in trainable weights through rank reduction.
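Here is the same arithmetic for a few ranks, using the 1024 x 1024 weight matrix from the example above:

d = 1024                 # original square weight matrix: d x d parameters
full = d * d             # 1,048,576
for r in (8, 4, 2, 1):
    lora = 2 * d * r     # A is (r x d), B is (d x r)
    print(f"rank {r}: {lora:,} trainable params ({lora / full:.2%} of full)")
# rank 8 -> 1.56% of the original weights; rank 1 -> 0.20%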
This slashes storage costs: only the original frozen W matrix plus the tiny set of adapter weights needs to be stored, and the latter are small enough to be essentially negligible. Consequently, fine-tuning can be done on the same infrastructure used for model deployment.
There is also a notable benefit for checkpointing. Writing checkpoints during the training of large models can be time-consuming due to data transfer between GPU memory, CPU memory, and disk; for very large models like Llama-70B, a separate job may even be needed to handle checkpointing asynchronously because offloading the weights to disk takes so long. With LoRA, only about 0.2% of the weights need to be written, so this step becomes nearly instantaneous: the data transferred shrinks from gigabytes to megabytes, allowing much more frequent checkpointing without significant delays.
3. Faster to train
Another benefit of the reduced memory usage is that the batch size can be increased, accelerating training. For example, if LoRA cuts your memory footprint to a quarter of what full fine-tuning requires, and your compute was originally sized for full fine-tuning, you can quadruple or even octuple your batch size. That adjustment makes training significantly faster by better utilizing your existing resources.
4. Unlocks multi-model deployments
One more benefit is the ability to share one base model across many deployed fine-tuned models with LoRA Exchange (LoRAX). For instance, if you had 100 fine-tuned models derived from Llama-7B to address the unique requirements of 100 different clients, you would normally need 100 separate deployments, each constantly hosting a model that serves only a fraction of your overall traffic, which becomes very costly as you scale the number of LLMs in production.
With LoRAX, you can serve all of these models from one single deployment because transferring the small adapters from CPU to GPU takes only milliseconds. Adapters are loaded from object storage into GPU memory just-in-time at runtime, in a non-blocking manner. Thanks to their compact footprint (approximately 10MB), loading times under 200ms have been achieved, which is especially cost-effective when an adapter is then used to generate many tokens over time (a minimal usage sketch follows the figures below).
Cost of Serving Fine-Tuned LLMs: LoRAX vs Dedicated vs GPT-3.5-Turbo.
Behind the scenes, adapters are offloaded from GPU to CPU to disk with a "least-recently-used" policy, preventing adapter overload from causing CUDA Out Of Memory errors. Building on the work of Chen et al., multi-adapter batching is also implemented, ensuring a sub-linear increase in latency with each adapter added to the LoRAX server. In short, with LoRAX you can serve hundreds of fine-tuned models for roughly the cost of deploying a single LLM, a significant efficiency and scalability gain.
Deployment of only one base model for several fine-tuned models in LoRAX.
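As a rough illustration of the serving pattern (the adapter names below are hypothetical, and the SDK calls mirror the walkthrough later in this post), many per-customer adapters can be queried through one shared base deployment:

from predibase import Predibase

pb = Predibase(api_token="my-api-token")
client = pb.deployments.client("mistral-7b")  # one shared base-model deployment

# Hypothetical adapter repos, one per customer, all served by the same deployment.
adapters = {"acme": "acme_support/1", "globex": "globex_contracts/2"}

def answer(customer: str, prompt: str) -> str:
    # LoRAX swaps the requested adapter into GPU memory just-in-time;
    # only the base model stays permanently resident.
    return client.generate(
        prompt,
        adapter_id=adapters[customer],
        max_new_tokens=128,
    ).generated_text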
5. They are only getting better!
We are expanding beyond serving multiple adapters for a single model. Upcoming work includes text embedders to power your RAG system, text classification capabilities, and multi-head decoders like Medusa that can speed up generation threefold or more. This makes it possible to run a variety of tasks within a single unified workflow.
For instance, imagine a RAG system that embeds an incoming query, retrieves the relevant documents, runs text classification to flag fraudulent content, and then generates a "your request is being processed" email response for the user using Medusa. The entire pipeline can run through a single deployment, streamlining operations and improving efficiency.
Coming soon: a wide variety of fine-tuned models served using a single base model.
Best practices
After covering the benefits of adapters, we wanted to share some best practices from the field on how to perform adapter-based training.
1. How much data do you need to train an adapter?
When training an adapter, having the right amount of data is crucial. While best practices are still evolving, a few hundred data points are generally sufficient to fine-tune an adapter. Treat this as a heuristic: start small, evaluate, and rely on your intuition and experience to decide whether to add more data.
2. Use synthetic data
If you don’t have enough data but still want to train an adapter, perhaps because relying on GPT-4 is too expensive or you face privacy constraints down the road, you can use synthetic data to train your adapter.
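As a rough sketch of the idea, you can have a stronger "teacher" model label seed examples and save the results in the prompt/completion format used for fine-tuning. The teacher_llm stub below is a placeholder assumption; in practice you would swap in a real call, for example a call like client.generate from the walkthrough later in this post:

import csv

def teacher_llm(prompt: str) -> str:
    # Placeholder: replace with a real LLM call (e.g. GPT-4 or a hosted deployment).
    return '{"person": [], "organization": [], "location": [], "miscellaneous": []}'

seed_inputs = [
    "EU rejects German call to boycott British lamb .",
    "MLK was born in Atlanta",
]

with open("synthetic_ner.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["prompt", "completion"])
    writer.writeheader()
    for text in seed_inputs:
        prompt = f"Label the named entities in: {text}"
        writer.writerow({"prompt": prompt, "completion": teacher_llm(prompt)})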
3. Standardize around a single base model
Make the most of LoRAX by training multiple adapters around a single base model. Your system will be more efficient to serve, and your adapters will be easier to compare.
Overview of training data and model fine-tuning in Predibase
To illustrate, let's walk through training an adapter similar to those used in LoRA Land with Predibase. LoRA Land is a demo we built on our website showcasing over 25 fine-tuned Mistral 7B adapters that match or outperform GPT-4. Notably, all of these adapters are served from a single GPU on our end, with automatic scaling in place to absorb increased demand seamlessly.
First, we connect to Predibase through the SDK:
# Optional notebook magics to auto-reload edited modules
%load_ext autoreload
%autoreload 2
from predibase import Predibase
# Authenticate with your Predibase API token
pb = Predibase(api_token="my-api-token")
Then, we select our deployment and get a client for running inference:
client = pb.deployments.client("mistral-7b")
Next, we fine-tune the LLM by providing a dataset for training. The dataset must have two columns, “prompt” and “completion”: “completion” holds the target output, while “prompt” holds the original “input” value formatted with the composite prompt_template below (a sketch of assembling such a dataset follows the template code).
base_model_prompt_template = "<s>[INST] {prompt} [/INST]"
fine_tuning_prompt_template = """
Your task is a Named Entity Recognition (NER) task. Predict the category of each entity,
then place the entity into the list associated with the category in an output
JSON payload. Below is an example:
Input: EU rejects German call to boycott British lamb .
Output: {{"person": [], "organization": ["EU"], "location": [], "miscellaneous": ["German", "British"]}}
Now, complete the task.
Input: {input} Output:
"""
prompt_template = base_model_prompt_template.format(prompt=fine_tuning_prompt_template)
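For illustration, here is a hedged sketch of how such a dataset might be assembled: the raw examples are hypothetical, and the commented-out upload call is an assumption, so check the Predibase docs for the exact method:

import pandas as pd

# Hypothetical raw NER examples; in practice these come from your labeled corpus.
raw = [
    {"input": "EU rejects German call to boycott British lamb .",
     "output": '{"person": [], "organization": ["EU"], "location": [], "miscellaneous": ["German", "British"]}'},
]

# Build the two required columns: "prompt" (formatted input) and "completion" (target).
df = pd.DataFrame({
    "prompt": [prompt_template.format(input=r["input"]) for r in raw],
    "completion": [r["output"] for r in raw],
})
df.to_csv("ner_train.csv", index=False)

# Assumption: the file is then connected in Predibase (via the SDK or the UI) and
# referenced as `dataset` in the fine-tuning call below, e.g. something like:
# dataset = pb.datasets.from_file("ner_train.csv", name="ner_train")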
repo = pb.repos.create("my_adapter", description="Fine-tuning on NER dataset with Predibase.")
adapter = pb.finetuning.jobs.create(
    config={
        "base_model": "mistral-7b",  # base model to adapt (note the quotes)
        "epochs": 3,
    },
    dataset=dataset,
    repo="my_adapter",
    description="Fine-tuning an adapter on NER dataset with Predibase.",
)
Once the adapter has been trained, we run inference with it and retrieve the response:
client.generate(
    prompt_template.format(input="MLK was born in Atlanta"),  # same composite template as above
    adapter_id="my_adapter/1",  # adapter repo and version (here, version 1; can be any trained version)
    max_new_tokens=128,
).generated_text
In the UI, you can inspect the training process that went into this model.
Visualize results and metadata associated with the training in Predibase.
The UI also offers a user-friendly visualization of the learning curves, which proves invaluable for debugging purposes.
Learning curves displayed on the UI for easy debugging.
For instance, you may observe the training loss decreasing while the test loss increases. This discrepancy indicates that your LLM is overfitting: it is memorizing the training data without generalizing to new examples. In such cases, regularization techniques like dropout or weight decay can help mitigate the overfitting.
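If you want to try these remedies in a fine-tuning config, it might look something like the sketch below. The parameter names "lora_dropout" and "weight_decay" are assumptions rather than confirmed Predibase config keys, so check the fine-tuning configuration reference for the exact options supported:

adapter = pb.finetuning.jobs.create(
    config={
        "base_model": "mistral-7b",
        "epochs": 3,
        "lora_dropout": 0.1,    # assumed key: dropout applied to the adapter layers
        "weight_decay": 0.01,   # assumed key: L2 regularization on trainable weights
    },
    dataset=dataset,
    repo="my_adapter",
    description="Regularized run to reduce overfitting.",
)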
Next Steps
We hope this discussion was valuable to you. You can now watch a replay of our webinar on YouTube (including our live demo).
Get started today with a free trial and follow our quick-start guide to fine-tune and serve a number of popular models like Code Llama, LLaMA-2 and Mistral for free on Predibase! You can explore these insights either through the UI or via the SDK within your own Jupyter notebook for a hands-on experience.
We can't wait to see what you'll build!