LoRA Exchange (LoRAX): Serve 100s of Fine-Tuned LLMs for the Cost of 1

October 18, 2023 · 8 min read
Travis Addair
Geoffrey Angus

Developers looking to put generative AI into production using their organization’s proprietary data are quickly discovering that smaller, faster, specialized LLMs like LLaMA-2-7b beat bigger general-purpose models like GPT-4 when fine-tuned to a particular task. But putting all those fine-tuned LLMs into production only makes sense if serving your own collection of specialized models is at least as cost-effective as using one general-purpose LLM through an API – an impossibility when every fine-tuned LLM requires its own set of dedicated GPU resources.

At Predibase, we’ve built a new type of LLM serving infrastructure optimized for productionizing many fine-tuned models together with a shared set of GPU resources – an approach we call LoRA Exchange (LoRAX). Unlike conventional methods for serving large language models, LoRAX allows users to pack upwards of 100 fine-tuned task-specific models into a single GPU, dramatically reducing the cost to serve fine-tuned models.

In this blog post, we’ll introduce our state-of-the-art solution to the challenges of serving fine-tuned LLMs in production, explain how Predibase’s LoRAX serving infrastructure works under the hood, and show you how you can get started fine-tuning and serving LLMs in Predibase today for free. LoRAX was also recently open-sourced.


Fine-Tuning and Serving LLMs with LoRA

The conventional approach to fine-tuning a deep neural network is to update all the parameters of the model as a continuation of the training process. For large language models with billions of parameters, this requires a massive amount of GPU memory (each trainable parameter incurs roughly 4x additional memory overhead during training) and wastes storage (tens of gigabytes per model checkpoint).

To make fine-tuning less resource-hungry, parameter-efficient fine-tuning techniques like Low Rank Adaptation (LoRA) introduce adapters – a small number of new parameters that are trained while the original model parameters remain frozen. Despite training only a small fraction of the parameters, LoRA achieves performance comparable to full fine-tuning. At serving time, the original model parameters and the new adapter parameters can be loaded together as a single deployment.
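
To make this concrete, here’s a minimal sketch of a LoRA-augmented linear layer in PyTorch (illustrative only, not Predibase’s implementation): the base weights stay frozen, and only the low-rank factors A and B are trained, adding a small residual update to the layer’s output.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer augmented with a small trainable low-rank adapter (sketch)."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        # The original (base model) parameters stay frozen.
        for p in self.base.parameters():
            p.requires_grad_(False)
        # Only these low-rank factors are trained: rank * (in + out) new parameters.
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + scaling * (B A) x: the adapter contributes a small residual update.
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

layer = LoRALinear(nn.Linear(4096, 4096), rank=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable:,} of {total:,} parameters")  # ~65K of ~16.8M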

Figure: Anatomy of a fine-tuned LLM deployment. The fine-tuned adapter weights (blue) account for only a tiny fraction of the total memory overhead (<10%).

The downside of this is that if multiple models are fine-tuned with LoRA, each needs to be deployed together with the original LLM on a dedicated set of resources, which can quickly add up. While general-purpose LLM APIs like ChatGPT offer dirt-cheap per-token pricing, the same cannot be said for serving fine-tuned LLMs, which conventionally require standing up a dedicated deployment. Doing this yourself in AWS with on-demand pricing for a g5.2xlarge to serve a custom llama-2-7b model will cost you $1.21 per hour, or about $900 per month to serve 24x7. Assuming you have tens to hundreds of fine-tuned LLMs to serve, your cloud bill soon balloons to tens of thousands of dollars per month, regardless of how often you’re querying the LLM service.
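
As a quick back-of-envelope check on how the dedicated-deployment approach scales (using only the on-demand price quoted above):

# Rough cost of one dedicated deployment per fine-tuned model at the
# on-demand g5.2xlarge price quoted above.
hourly_rate = 1.21               # USD per hour
hours_per_month = 24 * 30
cost_per_model = hourly_rate * hours_per_month   # ~$871 per month, serving 24x7

for num_models in (1, 10, 100):
    print(f"{num_models:>3} fine-tuned models: ~${num_models * cost_per_model:,.0f}/month")
# 1 model   -> ~$871/month
# 10 models  -> ~$8,712/month
# 100 models -> ~$87,120/month, paid whether or not anyone queries them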

Figure: Naive approach to serving multiple fine-tuned models. User 3’s model (green) is just being provisioned, resulting in a full deployment of the entire serving stack.

While a dedicated deployment per fine-tuned model is operationally simple to implement, it’s far from optimal. As shown in the figures above, the majority of the overhead comes from the base model parameters, which are identical across deployments. The part of the deployment that is unique to each fine-tuned model – the adapter weights – accounts for less than 10% of the total parameters, a tiny fraction of the GPU memory available in most cases.

This all raises the question: what if we could pack multiple fine-tuned models into a single deployment by reusing the common base model parameters?

Introducing LoRA Exchange (LoRAX)

LoRA Exchange (LoRAX) is a new approach to LLM serving infrastructure specifically designed for serving many fine-tuned models at once using a shared set of GPU resources. Compared with conventional dedicated LLM deployments, LoRAX consists of three novel components:

  1. Dynamic Adapter Loading, allowing each set of fine-tuned LoRA weights to be loaded from storage just-in-time as requests come in at runtime, without blocking concurrent requests.
  2. Tiered Weight Caching, to support fast exchanging of LoRA adapters between requests, and offloading of adapter weights to CPU and disk to avoid out-of-memory errors.
  3. Continuous Multi-Adapter Batching, a fair scheduling policy for optimizing aggregate throughput of the system that extends the popular continuous batching strategy to work across multiple sets of LoRA adapters in parallel.

Dynamic Adapter Loading

Unlike conventional serving infrastructure that preloads all model weights during initialization, LoRAX only loads the pretrained base LLM weights during initialization, and dynamically loads each set of fine-tuned LoRA adapters just-in-time at runtime.

Figure: Overview of dynamic adapter loading with multiple concurrent fine-tuned model requests. User 3’s model (green) is loaded in the background while the other requests proceed as usual.

To avoid blocking ongoing requests from other users, the LoRAX system maintains an individual request queue per fine-tuned adapter. While the new fine-tuned model’s adapter weights are being dynamically loaded in, all of its associated requests will wait in queue while other requests proceed as usual. In practice, we’ve observed the overhead of dynamically loading in a new adapter to be on the order of 200ms, much less than the typical text generation response time, making it possible to start interactively evaluating your fine-tuned models immediately after training completes. This small loading cost is further amortized away as subsequent tokens are generated and requests to the same model are submitted.
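
To illustrate the control flow, here’s a simplified asyncio sketch (not the LoRAX implementation): the base model is loaded once, adapter weights are fetched in a background task the first time they’re requested, and only requests for that adapter wait on the load. The fetch and decode steps below are hypothetical stand-ins.

import asyncio

class DynamicAdapterServer:
    """Sketch of just-in-time adapter loading with non-blocking background fetches."""

    def __init__(self, base_model: str):
        self.base_model = base_model   # pretrained base weights, loaded once at startup
        self.adapters = {}             # adapter_id -> loaded LoRA weights
        self.loading = {}              # adapter_id -> in-flight load task

    async def _fetch_adapter_weights(self, adapter_id: str) -> dict:
        # Stand-in for downloading the small LoRA weights from object storage (~200ms).
        await asyncio.sleep(0.2)
        return {"adapter_id": adapter_id}

    async def _decode(self, adapter: dict, prompt: str) -> str:
        # Stand-in for a real generation step using base weights + this adapter.
        await asyncio.sleep(0.05)
        return f"[{adapter['adapter_id']}] response to: {prompt}"

    async def generate(self, adapter_id: str, prompt: str) -> str:
        if adapter_id not in self.adapters:
            if adapter_id not in self.loading:
                # Kick off the load in the background; other adapters keep serving.
                self.loading[adapter_id] = asyncio.create_task(
                    self._fetch_adapter_weights(adapter_id)
                )
            # Only requests for this (still-loading) adapter wait here.
            self.adapters[adapter_id] = await self.loading[adapter_id]
            self.loading.pop(adapter_id, None)
        return await self._decode(self.adapters[adapter_id], prompt)

async def main():
    server = DynamicAdapterServer("llama-2-7b")
    print(await asyncio.gather(
        server.generate("customer-support-adapter", "hello"),     # triggers a background load
        server.generate("code-gen-adapter", "reverse a string"),  # loads concurrently
    ))

asyncio.run(main())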

Tiered Weight Caching

As more fine-tuned models are loaded into a single LLM deployment, the memory overhead increases. To avoid out-of-memory (OOM) errors, the LoRAX system implements a tiered weight caching strategy that offloads adapter weights from GPU → CPU → disk, trading off adapter exchange latency against memory overhead.

LoRAX’s caching strategy aims to strike a balance between keeping as many adapters on the GPU as possible, while leaving enough room for handling long sequences and large request batches. When an adapter does need to be evicted from the GPU, we transition it to CPU (host memory) using a least-recently used (LRU) policy. This policy also extends down to lower layers of the cache, so that we evict weights from CPU and (if necessary) delete them from local ephemeral disk as new adapter weights are loaded in. In the worst case, weights can be redownloaded from object storage.
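
Here’s a minimal sketch of the tiering logic, assuming simple per-tier slot counts (illustrative only, not LoRAX’s code): each tier is an LRU-ordered map, evictions cascade from GPU to CPU to disk as capacity is exceeded, and a miss in every tier falls back to re-downloading from object storage.

from collections import OrderedDict

class TieredAdapterCache:
    """Sketch of a GPU -> CPU -> disk LRU cache for LoRA adapter weights."""

    def __init__(self, gpu_slots: int, cpu_slots: int):
        # Each tier is an LRU-ordered map of adapter_id -> weights with a slot limit
        # (disk is treated as effectively unbounded here for simplicity).
        self.tiers = {
            "gpu": (OrderedDict(), gpu_slots),
            "cpu": (OrderedDict(), cpu_slots),
            "disk": (OrderedDict(), None),
        }

    def get(self, adapter_id, download_fn):
        # Look for the adapter in GPU, then CPU, then disk; promote it to GPU on a hit.
        for tier, _capacity in self.tiers.values():
            if adapter_id in tier:
                weights = tier.pop(adapter_id)
                break
        else:
            weights = download_fn(adapter_id)  # worst case: re-download from object storage
        self._put("gpu", adapter_id, weights)
        return weights

    def _put(self, tier_name, adapter_id, weights):
        tier, capacity = self.tiers[tier_name]
        tier[adapter_id] = weights
        tier.move_to_end(adapter_id)                        # mark as most recently used
        if capacity is not None and len(tier) > capacity:
            lru_id, lru_weights = tier.popitem(last=False)  # evict least recently used
            next_tier = {"gpu": "cpu", "cpu": "disk"}[tier_name]
            self._put(next_tier, lru_id, lru_weights)       # cascade down a tier

cache = TieredAdapterCache(gpu_slots=2, cpu_slots=4)
for name in ["a", "b", "c", "a"]:  # loading "c" pushes the LRU adapter from GPU to CPU
    cache.get(name, download_fn=lambda adapter_id: f"weights-for-{adapter_id}")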

Putting all of this together allows you to pack upwards of 100 models into a single deployment without needing to scale out to additional replicas, unless request volume demands it.

Continuous Multi-Adapter Batching

One of the most important techniques for enabling high throughput text generation has been continuous batching, whereby multiple requests can be dynamically batched together between each token generation step as new requests come in and old requests complete. 

This presents a challenge when exchanging LoRA weights between requests. If only one set of adapter weights can be in use at a time, then any requests coming in for a different fine-tuned model will either need to wait until all the active requests for a particular adapter complete (worsening latency for the inactive adapters) or consistently swap between adapters every few steps (worsening throughput by negating the effects of continuous batching). 

LoRAX implements a fair scheduling policy that optimizes for aggregate throughput while ensuring liveness for each distinct adapter being requested: 

  1. At any given time, some number of adapters N (limited by GPU memory) will be marked as “active”, with their weights loaded onto the GPU and available for use during decoding.
  2. Requests from active adapters will be drained from their respective queues and batched together continuously. A simple mask ensures that the correct adapter is applied to each request in the batch when computing the activations for each layer (see the figure and sketch below).
  3. After a configurable amount of time has elapsed, the scheduling system will move to the next set of adapters in a round-robin fashion. In practice, this means the adapter that has been in the active set the longest will be evicted, and the adapter with a non-empty request queue that has been waiting the longest will become active. The amount of time to wait before exchanging active adapters can be increased to prioritize throughput, or decreased to prioritize latency.

Figure: Decoding using multiple adapters within a single batch. Masks ensure that only the right adapter is used for processing each element of the batch.
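
To make the masking idea concrete, here’s a simplified PyTorch sketch (not the fused kernels a production system would use): each request in the batch carries an adapter index, the shared base computation runs once over the whole batch, and a boolean mask routes each row through its own adapter’s low-rank update.

import torch

def multi_adapter_lora_forward(x, base_weight, lora_a, lora_b, adapter_ids):
    """Apply a different LoRA adapter to each row of a single batch (sketch).

    x:           [batch, d_in] activations entering one layer
    base_weight: [d_out, d_in] frozen base weights shared by every request
    lora_a[i]:   [rank, d_in] and lora_b[i]: [d_out, rank] weights for adapter i
    adapter_ids: [batch] index of the adapter each request was fine-tuned with
    """
    y = x @ base_weight.T                 # base computation, shared across the batch
    for i in range(len(lora_a)):
        mask = adapter_ids == i           # rows belonging to adapter i
        if mask.any():
            # Low-rank update applied only to adapter i's rows.
            y[mask] += x[mask] @ lora_a[i].T @ lora_b[i].T
    return y

# Example: a batch of 4 requests spread across 2 active adapters.
batch, d_in, d_out, rank = 4, 16, 16, 4
x = torch.randn(batch, d_in)
base_w = torch.randn(d_out, d_in)
lora_a = [torch.randn(rank, d_in) * 0.01 for _ in range(2)]
lora_b = [torch.randn(d_out, rank) * 0.01 for _ in range(2)]
adapter_ids = torch.tensor([0, 1, 0, 1])
print(multi_adapter_lora_forward(x, base_w, lora_a, lora_b, adapter_ids).shape)  # torch.Size([4, 16])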

Using LoRAX to Fine-Tune and Serve LLaMA-2-7b

Now that we’ve established the core ideas behind LoRAX that make it an efficient serving strategy for fine-tuned LLMs, let’s walk through an example showing how it can be used in conjunction with Predibase’s managed fine-tuning infrastructure to rapidly serve and prompt your fine-tuned LLMs.

Predibase: Infrastructure for Open Source AI

The Predibase platform provides a unified infrastructure system for building and serving fine-tuned LLMs specialized to your tasks. Predibase is built on top of the open source Ludwig framework developed by Uber AI. Ludwig’s declarative interface acts as the glue that binds the fine-tuning and serving systems together, allowing you to build custom LLMs and start prompting them with just a couple of lines of code, without losing the flexibility or control provided by lower-level frameworks.

Predibase builds on the declarative foundations from Ludwig to abstract away the complexity of managing a production LLM platform. Predibase automatically determines the right compute resources needed for your training and serving jobs, optimizes resource utilization to prevent OOMs and other failure events, and orchestrates the entire lifecycle of your jobs with reliability and fault tolerance out of the box.

Figure: Overview of Predibase’s hybrid infrastructure, separating orchestration and metadata management from model fine-tuning and serving.

Predibase implements a hybrid architecture that efficiently schedules jobs to run in different environments based on a variety of factors including:

  1. Security and privacy: run in Predibase’s managed cloud or in your VPC.
  2. Data locality: run in the same cloud region as your data warehouse to minimize transfer costs.
  3. GPU availability: run wherever the right GPUs can be found at a reasonable price, including Predibase’s own dedicated clusters.

LoRAX with the Predibase Python SDK

Let’s walk through an example showing how to fine-tune and query LLaMA-2-7b using LoRAX through the Predibase Python SDK. If you want to follow along with each step, check out the accompanying Colab notebook here.

First, navigate here to get started with Predibase for free for 14 days. For this tutorial, we’re going to assume you’re running in Predibase’s managed cloud. Once you’re logged in, navigate to the Settings page to obtain an API token, which will allow you to start running jobs via the Python SDK.

After running pip install -U predibase to install the Python SDK, you need only log in and then you’re ready to start fine-tuning LLMs:

# First-time setup to provide your API token
$ pbase login
API Token: ...

🚀 Welcome to Predibase, 'username'!

Next, open up a Python notebook or interpreter and create a PredibaseClient. This will be the main entrypoint for submitting jobs to Predibase:

from predibase import PredibaseClient

pc = PredibaseClient()

Predibase Cloud comes with a number of popular open source LLMs available to prompt out-of-the-box. These LLMs are shared across tenants, so response times will vary based on load, but paid users also have the option of creating dedicated deployments unique to their tenant.

Let’s start by prompting the shared Llama-2-7b-chat model and see if it can generate some Java code for us:

llama2_7b = pc.LLM("pb://deployments/llama-2-7b-chat")
result = llama2_7b.prompt("Write an algorithm in Java to reverse the words in a string.")
print(result.response)

You likely got something like this for the response:

Of course! Here is an algorithm in Java to reverse the words in a string: 
public static void reverseWords(String str) {
    // Split the input string into an array of words
    String[] words = str.split(" ");
    // Reverse the array of words
    for (int i = words.length - 1; i >= 0; i--) {
        System.out.print(words[i] + " ");
    }
}
Please provide the input string for which you would like me to reverse the words.

Not bad, but a bit verbose. What if we tried fine-tuning the base llama-2-7b model to make it generate exactly the code we’re looking for?

Download the CodeAlpaca dataset, which we’ll use to feed the model examples of expected inputs and outputs during fine-tuning:

!wget https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/datasets/code_alpaca_800.csv

And now let’s fine-tune:

prompt_template = """Below is an instruction that describes a task, paired with an input
    that may provide further context. Write a response that appropriately
    completes the request.

    ### Instruction: {instruction}

    ### Input: {input}

    ### Response:
"""

llm = pc.LLM("hf://meta-llama/Llama-2-7b-hf")
dataset = pc.upload_dataset("code_alpaca_800.csv")
model_ft = llm.finetune(
    prompt_template=prompt_template,
    target="output",
    dataset=dataset,
).get()

We now have a Llama-2-7b model fine-tuned for code generation, but how well does it perform in practice? Let’s use LoRAX to test our model by dynamically loading the fine-tuned adapter weights into the existing Llama-2-7b base model deployment. In Predibase, this is as simple as a few lines of code:

llama2_7b = pc.LLM("pb://deployments/llama-2-7b")
llama2_7b_ft = llama2_7b.with_adapter(model_ft)
result = llama2_7b_ft.prompt(data=dict(
    instruction="Write an algorithm in Java to reverse the words in a string.",
    input="The quick brown fox",
))
print(result.response)

Taking a quick look at the response, we see that the fine-tuned model is much more precise and to-the-point than the chat model, while still providing correct code:

public static String reverseWords(String str) {
    String[] words = str.split(" ");
    StringBuilder sb = new StringBuilder();
    for (int i = words.length - 1; i >= 0; i--) {
        sb.append(words[i]).append(" ");
    }
    return sb.toString();
}

public static void main(String[] args) {
    String str = "the quick brown fox";
    System.out.println(reverseWords(str));
}

And that’s it! You’ve just successfully fine-tuned and served a task-specific LLM in Predibase without spinning up any costly new resources, thanks to LoRAX.

Closing Thoughts

At Predibase, we believe the future is fine-tuned: specialized LLMs for your tasks. Rather than boiling the ocean with hundreds of billions of parameters so an LLM can handle generalist tasks like generating French poetry when all you need is a point-of-sale chatbot, smaller, specialized LLMs are the most cost-effective and performant way to put generative AI into production. But today’s LLM serving infrastructure wasn’t built for serving a variety of specialized models. Achieving that vision of a more fine-tuned future means rethinking the serving stack to be fine-tuning first.

LoRAX is the first step towards a truly fine-tuning first platform for building specialized LLMs. In the weeks and months ahead, we’ll be sharing more about how we’re expanding on this foundation to leverage fine-tuning for such things as:

  • Speeding up inference with smarter, task-specific decoding.
  • Extending the context length of models to handle very long input sequences.
  • Fine-tuning with as few as 10 to 100 examples.

If you’re interested in fine-tuning and serving specialized LLMs, be sure to check out Predibase and sign up for a 2-week free trial to get started with fine-tuning. And make sure to check out the open-source Ludwig project and join the community.

Acknowledgements

Thanks to Magdy Saleh, Jeffrey Tang, Wael Abid, Arnav Garg, Noah Yoshida, Julian Bright, and Piero Molino for their contributions to this work.
