How to Run Inference on Ludwig Models Using TorchScript

December 5, 2022 · 6 min read
Ludwig Inference with TorchScript
Geoffrey Angus

Introduction

In Ludwig 0.6, we have introduced the ability to export Ludwig models into TorchScript, making it easier than ever to deploy models for highly performant model inference. In this blog post, we will describe the benefits of serving models using TorchScript and demonstrate how to train, export, and use the exported models on an example dataset.

How are models served today?

A common way to serve machine learning models is to wrap them in REST APIs and expose their endpoints. In fact, you can do this today in Ludwig using the serve command:

ludwig serve --model_path=<MODEL_PATH>
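Once the server is running, you can request predictions over HTTP by posting input feature values as form fields. Below is a minimal sketch using Python's requests library; the default port (8000), the /predict endpoint path, and the feature name description are assumptions to adapt to your own setup.

import requests

# Post a single input row as form fields; field names must match the model's
# input feature names ("description" is a hypothetical example).
response = requests.post(
    "http://localhost:8000/predict",
    data={"description": "just a person who tweets"},
)
print(response.json())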

This works great if you do not have particularly strict SLA requirements or if backwards compatibility is not a concern. However, if you need to serve a model in a production environment, you will likely need to use a more robust solution. Such a solution typically involves one or both of (1) some inference framework, such as TorchServe or NVIDIA Triton, and (2) a method of serializing models into inference-optimized formats such as TorchScript or ONNX. As of Ludwig 0.6, we now support the export of Ludwig models into TorchScript.

What is TorchScript?

TorchScript is a format that allows you to serialize a model and its weights in a way that can be loaded and executed in a production environment.

Serving TorchScript models has numerous advantages over serving vanilla PyTorch models:

  • TorchScript models are fully serialized and can be loaded without the original Python code. This means that you can load TorchScript models in a lightweight environment with minimal dependencies (see the Try It Yourself section below).
  • TorchScript models are often more performant than their vanilla PyTorch counterparts. This is because the underlying PyTorch computation graph is extracted and optimized for inference.
  • Finally, TorchScript models can be loaded and executed in a variety of environments, including C++. This makes it possible to use TorchScript models in production environments such as mobile devices.

What is new in Ludwig?

Vanilla Ludwig models are composed of three stages: (1) preprocessing, (2) prediction, and (3) postprocessing. While prediction is always done in PyTorch, preprocessing and postprocessing happen using a DataFrame engine such as Pandas or Dask.

Ludwig's three-stage inference pipeline: preprocessing and postprocessing run on a DataFrame engine (Pandas or Dask), while prediction runs in PyTorch.

In Ludwig 0.6, we introduced the export_torchscript CLI command, which takes a trained vanilla Ludwig model and exports it to the TorchScript format. To support this, we needed to ensure that everything from preprocessing to postprocessing was TorchScript-compatible, so that Ludwig users get the end-to-end inference experience they expect from a trained Ludwig model. This meant writing preprocessors and postprocessors for each feature type as torch.nn.Module subclasses, and introducing a new InferenceModule class able to execute the entire graph. Behind the scenes, the InferenceModule is composed of three separate TorchScript modules: the preprocessor, the predictor, and the postprocessor.

Ludwig's three-stage inference pipeline exported end to end as TorchScript modules: preprocessor, predictor, and postprocessor.

The three-stage pipeline design gives us several benefits:

  • Having separate TorchScript modules makes it possible to split these stages between devices, for example placing the preprocessor and postprocessor on CPU while placing the predictor on GPU (see the sketch after this list). A single monolithic TorchScript module would make this much more difficult.
  • If the staged pipeline is paired with a model serving tool like NVIDIA Triton, it becomes possible to independently scale the stages to ensure maximum efficiency. If you are interested in using NVIDIA Triton in this way, check out the export_triton CLI command.
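
To illustrate the first point, here is a minimal sketch of splitting the exported modules across devices. It assumes the artifact names produced by export_torchscript (shown later in this post), a CUDA-capable GPU, and that the preprocessor and predictor exchange flat dictionaries of tensors; depending on your setup, you may instead want to export a predictor targeted at your inference device.

import torch

# Keep the string-heavy preprocessing and postprocessing stages on CPU...
preprocessor = torch.jit.load("torchscript/inference_preprocessor.pt")
postprocessor = torch.jit.load("torchscript/inference_postprocessor.pt")
# ...and move the predictor to GPU.
predictor = torch.jit.load("torchscript/inference_predictor_cpu.pt").to("cuda")

def predict(raw_inputs):
    preproc_output = preprocessor(raw_inputs)
    # Move preprocessed tensors to the GPU for the forward pass.
    preproc_output = {name: t.to("cuda") for name, t in preproc_output.items()}
    raw_output = predictor(preproc_output)
    # Bring raw predictions back to CPU for postprocessing.
    raw_output = {name: t.to("cpu") for name, t in raw_output.items()}
    return postprocessor(raw_output)

A sketch of placing the preprocessor and postprocessor on CPU and the predictor on GPU.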

With this change, Ludwig users can now export trained models and deploy them seamlessly* in a variety of production environments. Next, we’ll take a look at exactly how to do so!

*For the most part; see the Limitations section below.

Try It Yourself

For this blog post, we are going to be working with the Twitter Bots dataset, a popular public dataset where the task is identifying bots based on account metadata. The next section will demonstrate how to:

  1. Train a multimodal (tabular+text) neural network pipeline with Ludwig
  2. Export the entire neural network pipeline into TorchScript
  3. Run inference on the exported TorchScript model

Training a Ludwig Model

Ludwig is a Declarative ML framework that enables training multimodal, multi-task Deep Learning models with simple YAML configuration files. It also comes with a couple of example datasets to help new users get started.

To get started, download the Twitter Bots dataset via the Ludwig CLI with the following command:

ludwig datasets download --output_dir=<DATA_DIRECTORY> twitter_bots
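
If you want to sanity-check the download, the Parquet file can be inspected with Pandas (the twitter_bots.parquet file name matches the path used in the commands below):

import pandas as pd

# Peek at the first few rows of the downloaded dataset.
df = pd.read_parquet(f"{DATA_DIRECTORY}/twitter_bots.parquet")
print(df.head())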

Once the dataset is downloaded, all we have to do to train a model is specify a new YAML configuration file that looks like the following:

input_features:
  - name: description
    type: text
    column: description
  - name: default_profile
    type: binary
    column: default_profile
  - name: default_profile_image
    type: binary
    column: default_profile_image
  - name: favourites_count
    type: number
    column: favourites_count
  - name: followers_count
    type: number
    column: followers_count
  - name: friends_count
    type: number
    column: friends_count
  - name: verified
    type: binary
    column: verified
  - name: average_tweets_per_day
    type: number
    column: average_tweets_per_day
  - name: account_age_days
    type: number
    column: account_age_days
output_features:
  - name: account_type
    type: binary
    column: account_type
model_type: ecd
trainer:
    epochs: 1

An example Ludwig config for training a Twitter Bots model

The above configuration file specifies a multimodal neural network that takes in 9 input features (8 tabular and 1 text) and outputs a single binary output feature. Each of the features in this case is a column from the downloaded Twitter Bots Parquet file. The model type is set to ecd, which selects the Encoder-Combiner-Decoder (ECD) architecture that Ludwig uses to train multimodal neural networks.

Finally, we set the epoch count to 1 for the sake of expediency in this blog post. If you are interested in training a performant model, feel free to set this value higher!

We can finally train a Ludwig model using the above configuration with the following command:

ludwig train \
    --config=<CONFIG_FILEPATH> \
    --dataset=<DATA_DIRECTORY>/twitter_bots.parquet \
    --output_directory=<EXPERIMENT_DIRECTORY>

After a few minutes, you should have a trained Ludwig model.
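
If you prefer Ludwig's Python API over the CLI, the following sketch should be roughly equivalent; config.yaml stands in for <CONFIG_FILEPATH>, and the DATA_DIRECTORY and EXPERIMENT_DIRECTORY variables mirror the placeholders above.

from ludwig.api import LudwigModel

# Train the same model programmatically.
model = LudwigModel(config="config.yaml")
train_stats, preprocessed_data, output_directory = model.train(
    dataset=f"{DATA_DIRECTORY}/twitter_bots.parquet",
    output_directory=EXPERIMENT_DIRECTORY,
)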

Exporting a Ludwig Model to TorchScript

We can now export our trained Ludwig model into TorchScript. We can do this by running the following command:

ludwig export_torchscript \
    --model_path=<EXPERIMENT_DIRECTORY>/model \
    --output_path=<EXPERIMENT_DIRECTORY>/torchscript

If model training was done on CPU, this command produces three new TorchScript artifacts and copies over a couple of other files for convenience. In the specified output directory, we should now have the following files:

torchscript/
    inference_preprocessor.pt
    inference_predictor_cpu.pt
    inference_postprocessor.pt
    training_set_metadata.json    # copied over from --model_path
    model_hyperparameters.json    # copied over from --model_path

These artifacts fully represent a single Ludwig model. The three TorchScript files represent our inference pipeline. The inference_preprocessor.pt file is a TorchScript version of the preprocessing steps done during training. The inference_predictor_cpu.pt file is a TorchScript version of the trained model. The inference_postprocessor.pt file is a TorchScript version of the postprocessing steps associated with the output.

Running Inference on a TorchScript Model

The pipelined nature of the exported Ludwig model gives the user a lot of flexibility in how they want to run inference. Below are a few ways to do so.

Method 1: Using the InferenceModule class

The most straightforward way to start using the model is by leveraging the InferenceModule class, a convenience class that allows users to load all three modules from a given directory. It additionally provides a predict method whose interface mimics LudwigModel.predict and enables inference on a Pandas DataFrame.

Example:

from pprint import pprint

import pandas as pd
from ludwig.models.inference import InferenceModule

inference_module = InferenceModule.from_directory(
    f"{EXPERIMENT_DIRECTORY}/torchscript")

input_df = pd.read_parquet(f"{DATA_DIRECTORY}/twitter_bots.parquet")
input_sample_df = input_df.head(2)

postproc_output, _ = inference_module.predict(input_sample_df, return_type=dict)
pprint(postproc_output)

An example of how to use the `InferenceModule` class for inference.

Output:

{'account_type': {'predictions': ['human', 'human'],
                  'probabilities': tensor([[1.0000e+00, 2.1253e-07],
                                           [1.0000e+00, 7.5131e-11]])}}

The primary limitation of this method is that it requires a Python backend to run. The subsequent methods use pure TorchScript and thus do not have this limitation.

Method 2: Convert InferenceModule Into a Single TorchScript Module

It is also possible to convert the entire pipeline into a single TorchScript module. This can be done by running torch.jit.script on the InferenceModule instance from Method 1. You can then save the resulting module to file and use it in any TorchScript-compatible backend.

The resulting model is pure TorchScript, so input DataFrames must be converted to dictionary objects before the forward pass. You can use the convenience function ludwig.utils.inference_utils.to_inference_module_input_from_dataframe to do this for you.

import json
from pprint import pprint

import pandas as pd
import torch

from ludwig.models.inference import InferenceModule
from ludwig.utils.inference_utils import to_inference_module_input_from_dataframe

inference_module = InferenceModule.from_directory(
    f"{EXPERIMENT_DIRECTORY}/torchscript")

# Convert the InferenceModule into a single TorchScript module
single_module = torch.jit.script(inference_module)

with open(f"{EXPERIMENT_DIRECTORY}/torchscript/model_hyperparameters.json") as f:
    config = json.load(f)

input_df = pd.read_parquet(f"{DATA_DIRECTORY}/twitter_bots.parquet")
input_sample_df = input_df.head(2)
input_sample_dict = to_inference_module_input_from_dataframe(
    input_sample_df, config)

postproc_output = single_module(input_sample_dict)
pprint(postproc_output)

An example of converting the `InferenceModule` into a pure TorchScript object.

Output:

{'account_type': {'predictions': ['human', 'human'],
                  'probabilities': tensor([[1.0000e+00, 2.1253e-07],
                                           [1.0000e+00, 7.5131e-11]])}}
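
As noted above, the scripted module can then be saved to disk with plain PyTorch and reloaded in any TorchScript-compatible backend. A minimal sketch, continuing the example above (the file name inference_module.pt is arbitrary):

# Save the scripted pipeline and load it back without the InferenceModule class.
torch.jit.save(single_module, f"{EXPERIMENT_DIRECTORY}/torchscript/inference_module.pt")

loaded_module = torch.jit.load(f"{EXPERIMENT_DIRECTORY}/torchscript/inference_module.pt")
postproc_output = loaded_module(input_sample_dict)
pprint(postproc_output)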

Method 3: Running Inference on Each Module Separately

The final way to use this model is by loading the modules individually and running a forward pass through each one in sequence. One might do this if using an inference framework that can assign resources to each pipeline stage accordingly, such as NVIDIA Triton.

As in Method 2, inputs must be fed in as dictionary objects. Also note that, because these modules are individually optimized for inference frameworks like NVIDIA Triton, the output of the postprocessor is a flattened dictionary (as opposed to the nested dictionary outputs from Methods 1 and 2).

import json
from pprint import pprint

import pandas as pd
import torch

from ludwig.utils.inference_utils import to_inference_module_input_from_dataframe

preprocessor = torch.jit.load(
    f"{EXPERIMENT_DIRECTORY}/torchscript/inference_preprocessor.pt")
predictor = torch.jit.load(
    f"{EXPERIMENT_DIRECTORY}/torchscript/inference_predictor-cpu.pt")
postprocessor = torch.jit.load(
    f"{EXPERIMENT_DIRECTORY}/torchscript/inference_postprocessor.pt")

input_df = pd.read_parquet(f"{DATA_DIRECTORY}/twitter_bots.parquet")
input_sample_df = input_df.head(2)

with open(f"{EXPERIMENT_DIRECTORY}/torchscript/model_hyperparameters.json") as f:
    config = json.load(f)

input_sample_dict = to_inference_module_input_from_dataframe(
    input_sample_df, config)

preproc_input = preprocessor(input_sample_dict)
raw_output = predictor(preproc_input)
postproc_output = postprocessor(raw_output)
pprint(postproc_output)

An example of running inference on each TorchScript module separately.

Output:

{'account_type::predictions': ['human', 'human'],
 'account_type::probabilities': tensor([[1.0000e+00, 2.1253e-07],
                                        [1.0000e+00, 7.5131e-11]])}

Limitations

While the TorchScript export code path introduces significant improvements to Ludwig's model serving capabilities, there are still a few limitations when it comes to exporting models to TorchScript.

As of Ludwig 0.6.4, we are just beginning to provide limited support for HuggingFace encoders; this means that, for now, models will likely need to be trained or fine-tuned using one of our TorchScript-compatible encoders instead. We have implemented TorchScript-compatible GPT-2, CLIP, and BERT tokenizers for such use cases.

The second is that a small number of feature types must be featurized before being fed into the inference module: Image, Audio, and Date. Image and Audio features, which can often be trained in Ludwig using file paths alone, must be loaded in as Tensors before use with the inference module. Date features, which can be trained in Ludwig using the raw string format, must be preprocessed, since the built-in datetime module is not TorchScript-able. For more info, please take a look at the docs. All other feature types supported by Ludwig have been made compatible with the new TorchScript implementation.
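
As an illustration of the image case, one might load an image into a tensor before handing it to the inference module. The snippet below is a hypothetical sketch: the Twitter Bots model above has no image features, the feature name profile_image and the file profile_pic.png are made up, and the exact dtype and container expected by a given model may differ (see the docs).

import torch
from torchvision.io import read_image

# Featurize the image ourselves instead of passing a file path.
image_tensor = read_image("profile_pic.png").to(torch.float32)  # tensor of shape [C, H, W]
inputs = {"profile_image": [image_tensor]}  # one tensor per row for the image column
# `inputs` can then be passed to the exported pipeline as in Methods 2 and 3.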

Wrap-up

The new TorchScript export feature makes it possible to decouple preprocessing, prediction, and postprocessing in order to serve highly performant models. With this new feature, Ludwig users can deploy their models into production in a way that is backend-independent and highly backwards compatible.

We’d love to hear your feedback on this new feature. Please let us know if you have any questions or suggestions by posting in the Ludwig Slack channel or by opening an issue on GitHub.

If you are interested in going even further, Predibase is an end-to-end Declarative ML platform to train and deploy production-grade machine learning models in just a few lines of code. It is built with Ludwig at its core and supports all of the features described in this post and more, including high-performance serving in NVIDIA Triton at the click of a button. If you are interested in trying it out, feel free to get in touch.

Happy serving!
