Samhita Alla

How to Serve ML Models with Banana

Orchestrate your ML training pipeline and serve your model at scale in 10 minutes or less.

Training an ML model requires you to procure, prepare, and analyze the data you’ll use to build and train it. When new data arrives, you may need to repeat these steps, a process known as “retraining.” To ensure reproducibility, you may need to version each retrained model. To keep execution cost-effective, you may also want to cache intermediate outputs so that unchanged steps aren't recomputed. Finally, to complete the loop, the model must be deployed so that end users can generate predictions.

I set out on a journey to build an application that would orchestrate a fine-tuning pipeline and enable on-demand, scalable model serving. My approach:

  • Build a fine-tuning pipeline leveraging 🤗 Hugging Face transformers.
  • Allocate additional memory for resource-intensive operations and a GPU to expedite model training.
  • Cache the dataset to reduce the execution time.
  • Upon completion of training, push the trained model to 🤗 Hub.
  • Enable the user to approve or reject the model deployment.
  • If approved, retrieve the model from the Hub and deploy it.
  • Invoke the endpoint to generate predictions.

In addition, I wanted to ensure that the end-to-end pipeline maintained data lineage and versioning.

In the following sections, I'll walk you through the code of the Flyte x Banana integration. Let's dive in!

Note: Do you want to jump directly to the code? It’s on GitHub.

⛓️ The fine-tuning pipeline

Let’s use a `bert-base-uncased` pre-trained model and fine-tune it on the Yelp dataset. We can break the pipeline down into four steps:

  1. Downloading the dataset
  2. Tokenizing the dataset
  3. Splitting the dataset into training and evaluation subsets
  4. Training an ML model on the subsets

Let’s use Flytekit, Flyte’s Python SDK, to build the pipeline.

📥 Downloading the dataset

Download the `yelp_review_full` dataset and save it to a directory. Return it as a FlyteDirectory, a Flyte type that automatically uploads and downloads directories as they pass between Flyte tasks. Request enough memory, CPU, and ephemeral storage for the download, and cache the task output so that subsequent runs can skip this step.

@task(
    cache=True,
    cache_version="1.0",
    requests=Resources(mem="1Gi", cpu="2", ephemeral_storage=ephemeral_storage),
)
def download_dataset() -> FlyteDirectory:
    dataset = load_dataset("yelp_review_full")

    local_dir = create_local_dir(dir_name="yelp_data")

    dataset.save_to_disk(dataset_dict_path=local_dir)
    return FlyteDirectory(path=str(local_dir))
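Roughly speaking, Flyte decides whether it can reuse a cached output by looking at the task, its `cache_version`, and its input values. The sketch below is a conceptual, stdlib-only model of that behaviour, not Flyte's implementation; `cache_key`, `run_cached`, `_CACHE`, and `fake_download` are hypothetical names. It illustrates why bumping `cache_version` (for example, after changing how the dataset is saved) forces the task to rerun even when its inputs are unchanged.

```python
import hashlib
import json
from typing import Any, Callable, Dict

# Hypothetical in-memory store standing in for Flyte's cache backend.
_CACHE: Dict[str, Any] = {}


def cache_key(task_name: str, cache_version: str, inputs: Dict[str, Any]) -> str:
    """Derive a deterministic key from the task name, cache_version, and inputs."""
    payload = json.dumps(
        {"task": task_name, "version": cache_version, "inputs": inputs},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def run_cached(fn: Callable[..., Any], cache_version: str, **inputs: Any) -> Any:
    """Return the stored result on a key match; otherwise run fn and store its output."""
    key = cache_key(fn.__name__, cache_version, inputs)
    if key not in _CACHE:
        _CACHE[key] = fn(**inputs)
    return _CACHE[key]


# Usage: a fake download task that records each real execution.
calls = []


def fake_download(name: str) -> str:
    calls.append(name)
    return f"/tmp/{name}"


run_cached(fake_download, "1.0", name="yelp_review_full")  # miss: runs
run_cached(fake_download, "1.0", name="yelp_review_full")  # hit: reuses output
run_cached(fake_download, "2.0", name="yelp_review_full")  # new cache_version: reruns
print(len(calls))  # 2
```

In Flyte itself you only set `cache=True` and `cache_version`; the bookkeeping above happens in the platform.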

🧹 Tokenizing the data

Load the pre-trained `bert-base-uncased` tokenizer, apply it to the Yelp dataset, and save the output to a directory. As with the download, cache the tokenized data and request the resources the task needs, so that reruns with the same input skip tokenization.

@task(
    cache=True,
    cache_version="1.0",
    requests=Resources(mem="1Gi", cpu="2", ephemeral_storage=ephemeral_storage),
)
def tokenize(dataset: FlyteDirectory) -> FlyteDirectory:
    downloaded_path = dataset.download()
    loaded_dataset = load_from_disk(downloaded_path)
    tokenized_dataset = loaded_dataset.map(tokenize_function, batched=True)

    local_dir = create_local_dir(dir_name="yelp_tokenized_data")

    tokenized_dataset.save_to_disk(dataset_dict_path=local_dir)
    return FlyteDirectory(path=str(local_dir))
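`loaded_dataset.map(tokenize_function, batched=True)` calls `tokenize_function` (defined in the repository, not shown here) on slices of each split, 1,000 rows at a time by default, and adds the columns it returns to the dataset. The dependency-free sketch below mimics that batching behaviour so you can see the shape of the data flowing into training. `map_batched` and `toy_tokenize_function` are hypothetical, and the toy tokenizer is only a stand-in for BERT's WordPiece tokenizer; `MAX_LENGTH` is kept tiny for readability.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]

MAX_LENGTH = 8  # hypothetical; BERT models accept sequences of up to 512 tokens
PAD_ID = 0


def map_batched(rows: Batch, fn: Callable[[Batch], Batch], batch_size: int = 1000) -> Batch:
    """Apply fn to consecutive slices of a columnar dataset and merge its output columns."""
    n_rows = len(next(iter(rows.values())))
    new_columns: Batch = {}
    for start in range(0, n_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in rows.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    merged = {name: list(values) for name, values in rows.items()}
    merged.update(new_columns)  # fn's columns are added (or replace same-named ones)
    return merged


def toy_tokenize_function(batch: Batch) -> Batch:
    """Stand-in tokenizer: lowercase, split on whitespace, map words to ids, pad/truncate."""
    input_ids, attention_mask = [], []
    for text in batch["text"]:
        ids = [1 + sum(map(ord, word)) % 1000 for word in text.lower().split()][:MAX_LENGTH]
        attention_mask.append([1] * len(ids) + [0] * (MAX_LENGTH - len(ids)))
        input_ids.append(ids + [PAD_ID] * (MAX_LENGTH - len(ids)))
    return {"input_ids": input_ids, "attention_mask": attention_mask}


data = {"text": ["Great food!", "Terrible service, slow staff"], "label": [4, 0]}
tokenized = map_batched(data, toy_tokenize_function, batch_size=1)
print(sorted(tokenized))  # ['attention_mask', 'input_ids', 'label', 'text']
```

The original columns (`text`, `label`) survive, and every row gains fixed-length `input_ids` and an `attention_mask` that marks real tokens with 1 and padding with 0.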

🖖 Splitting the dataset

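Draw shuffled training and evaluation subsets of `dataset_size` rows from the tokenized `train` and `test` splits. Each subset is a `datasets.Dataset` wrapped in a `StructuredDataset`; the `flytekitplugins-huggingface` plugin supplies the handlers that let Flyte pass this type between tasks.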

@task(requests=Resources(mem="1Gi", cpu="2", ephemeral_storage=ephemeral_storage))
def get_train_eval(
    tokenized_dataset: FlyteDirectory,
) -> datasets_tuple:
    downloaded_path = tokenized_dataset.download()
    loaded_tokenized_dataset = load_from_disk(downloaded_path)

    small_train_dataset = (
        loaded_tokenized_dataset["train"].shuffle(seed=42).select(range(dataset_size))
    )
    small_eval_dataset = (
        loaded_tokenized_dataset["test"].shuffle(seed=42).select(range(dataset_size))
    )
    return datasets_tuple(
        train_dataset=StructuredDataset(dataframe=small_train_dataset),
        eval_dataset=StructuredDataset(dataframe=small_eval_dataset),
    )
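The `shuffle(seed=42).select(range(dataset_size))` pattern gives a reproducible random subset: the same seed always selects the same rows, so reruns train and evaluate on identical data. The sketch below shows the same idea with Python's `random` module; `shuffled_subset` is a hypothetical helper, and because 🤗 Datasets uses NumPy's random generator, the exact rows it picks will differ from this sketch even with the same seed.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shuffled_subset(rows: Sequence[T], size: int, seed: int = 42) -> List[T]:
    """Shuffle row indices with a fixed seed, then keep the first `size` rows."""
    if size > len(rows):
        raise ValueError(f"size={size} exceeds dataset length {len(rows)}")
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # local RNG: global random state is untouched
    return [rows[i] for i in indices[:size]]


reviews = [f"review-{i}" for i in range(100)]
first = shuffled_subset(reviews, size=10)
second = shuffled_subset(reviews, size=10)
print(first == second)  # True: same seed, same subset
```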

🤖 Fine-tuning BERT model

Allocate a GPU to the `train` task to accelerate fine-tuning, and use `secret_requests` to retrieve the 🤗 token. Load the pre-trained `bert-base-uncased` model, initialize `TrainingArguments`, and use `Trainer` to train the model on the tokenized data. Finally, publish the trained model to the 🤗 Hub and return the commit SHA of the published model on the Hub, the Flyte execution ID (to maintain lineage), and the repository name.

Note: Flyte reads secrets from a secrets management system, and Kubernetes secrets are the default. We'll see how to create the secrets later in this post.

@task(
    requests=Resources(mem=mem, cpu="2", gpu=gpu, ephemeral_storage=ephemeral_storage),
    secret_requests=[
        Secret(
            group=HF_SECRET_GROUP,
            key=HF_SECRET_NAME,
            mount_requirement=Secret.MountType.FILE,
        )
    ],
)
def train(
    small_train_df: StructuredDataset, small_eval_df: StructuredDataset, hf_user: str
) -> dict:
    HUGGING_FACE_HUB_TOKEN = flytekit.current_context().secrets.get(
        HF_SECRET_GROUP, HF_SECRET_NAME
    )
    repo = f"{hf_user}/{HUGGINGFACE_REPO}"
    execution_id = flytekit.current_context().execution_id.name
    small_train_dataset = small_train_df.open(Dataset).all()
    small_eval_dataset = small_eval_df.open(Dataset).all()

    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=5
    )

    training_args = TrainingArguments(
        output_dir=HUGGINGFACE_REPO,
        evaluation_strategy="epoch",
        push_to_hub=True,
        hub_token=HUGGING_FACE_HUB_TOKEN,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=small_train_dataset,
        eval_dataset=small_eval_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.push_to_hub(
        commit_message=f"End of training - Flyte execution ID {execution_id}"
    )
    return {
        "sha": model_info(repo).sha,
        "execution_id": execution_id,
        "repo": repo,
    }
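`train` passes a `compute_metrics` function to `Trainer`, but its definition isn't shown. `Trainer` calls it at each evaluation (here, at the end of every epoch, because of `evaluation_strategy="epoch"`) with an `EvalPrediction` that unpacks into `(logits, labels)`. The sketch below is a hypothetical, dependency-free accuracy metric with that calling convention; the repository's version may use NumPy or the `evaluate` library instead.

```python
from typing import Dict, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest logit (the first one wins on ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def compute_metrics(
    eval_pred: Tuple[Sequence[Sequence[float]], Sequence[int]],
) -> Dict[str, float]:
    """Accuracy: fraction of rows whose highest-scoring class matches the label."""
    logits, labels = eval_pred
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == int(y)) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Usage: three reviews, five classes (one per star rating).
logits = [
    [0.1, 2.0, 0.3, 0.0, 0.0],  # predicts class 1
    [1.5, 0.2, 0.0, 0.0, 0.0],  # predicts class 0
    [0.0, 0.0, 0.0, 0.0, 3.0],  # predicts class 4
]
labels = [1, 0, 3]
metrics = compute_metrics((logits, labels))  # two of three correct: accuracy = 2/3
```

Indexing rows with `row[j]` instead of `row.index(...)` keeps the helper working on NumPy rows as well as plain lists.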

🗄️ Sending model metadata to GitHub

Because Banana deploys from GitHub, keep all Banana deployment files in a GitHub repository. Once a Banana GitHub action is added to that repository, a push event triggers the deployment.

To initiate a Banana deployment from within a Flyte workflow, we therefore need to trigger a push event; in this case, we push the model metadata.

Take the model metadata returned by the `train` task and write it to a `model_metadata.json` file. The GitHub API requires file contents to be base64-encoded. Before creating the commit, fetch the latest commit SHA of the target branch; GitHub uses it as `expectedHeadOid` to verify that the branch hasn't moved in the meantime. Finally, use the `subprocess` library to call the `createCommitOnBranch` GraphQL mutation.

@task(
    secret_requests=[
        Secret(
            group=SECRET_GROUP, key=SECRET_NAME, mount_requirement=Secret.MountType.FILE
        )
    ],
    requests=Resources(mem="1Gi", ephemeral_storage=ephemeral_storage),
)
def push_to_github(
    model_metadata: dict, gh_owner: str, gh_repo: str, gh_branch: str
) -> str:
    token = flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_NAME)

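    additions = [
        {
            "path": "banana/model_metadata.json",
            # Use standard (not URL-safe) Base64 for the GraphQL Base64String input
            "contents": base64.b64encode(json.dumps(model_metadata).encode()).decode("utf-8"),
        }
    ]

    # Fetch the head commit of the target branch (not always `main`); GitHub
    # rejects the commit if expectedHeadOid doesn't match the branch head
    sha_result = subprocess.run(
        f"""
        curl -s -H "Authorization: bearer {token}" \
        -H "Accept: application/vnd.github.VERSION.sha" \
        "https://api.github.com/repos/{gh_owner}/{gh_repo}/commits/{gh_branch}"
        """,
        shell=True,
        capture_output=True,
        text=True,
    )

    github_sha = sha_result.stdout.strip()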

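    # "\\$" yields a literal "\$" so bash doesn't expand $input inside the unquoted
    # heredoc (a bare "\$" is an invalid escape warning in recent Python versions)
    cmd = f"""
curl https://api.github.com/graphql -s -H "Authorization: bearer {token}" --data @- << GRAPHQL | jq '.data.createCommitOnBranch.commit.url[0:56]'
{{
   "query": "mutation (\\$input: CreateCommitOnBranchInput!) {{
   createCommitOnBranch(input: \\$input) {{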
       commit {{
         url
       }}
   }}
   }}",
   "variables": {{
     "input": {{
       "branch": {{
         "repositoryNameWithOwner": "{gh_owner}/{gh_repo}",
         "branchName": "{gh_branch}"
       }},
       "message": {{
       "headline": "Update the model artifact"
       }},
       "fileChanges": {{
         "additions": {json.dumps(additions)}
       }},
       "expectedHeadOid": "{github_sha}"
     }}
   }}
}}
GRAPHQL
"""

    result = subprocess.run(
        cmd,
        shell=True,
        capture_output=True,
        text=True,
    )

    return result.stdout
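Assembling the mutation inside a shell heredoc means juggling `\$` escapes and doubled braces. As an alternative (not the approach used in the repository), the request body can be built as a Python dict and serialized with `json.dumps`, which also escapes the newlines in the query string for you. `build_commit_request` and `send_graphql` are hypothetical helpers; the field names mirror the `createCommitOnBranch` input used above, and the file contents are encoded with standard Base64.

```python
import base64
import json
import urllib.request
from typing import Any, Dict

CREATE_COMMIT_MUTATION = """
mutation ($input: CreateCommitOnBranchInput!) {
  createCommitOnBranch(input: $input) {
    commit { url }
  }
}
"""


def build_commit_request(
    model_metadata: Dict[str, Any],
    gh_owner: str,
    gh_repo: str,
    gh_branch: str,
    head_sha: str,
    path: str = "banana/model_metadata.json",
) -> Dict[str, Any]:
    """Build the JSON body for a createCommitOnBranch mutation that writes one file."""
    contents = base64.b64encode(json.dumps(model_metadata).encode()).decode("ascii")
    return {
        "query": CREATE_COMMIT_MUTATION,
        "variables": {
            "input": {
                "branch": {
                    "repositoryNameWithOwner": f"{gh_owner}/{gh_repo}",
                    "branchName": gh_branch,
                },
                "message": {"headline": "Update the model artifact"},
                "fileChanges": {"additions": [{"path": path, "contents": contents}]},
                # The commit is rejected if the branch head moved after head_sha was read
                "expectedHeadOid": head_sha,
            }
        },
    }


def send_graphql(body: Dict[str, Any], token: str) -> Dict[str, Any]:
    """POST the body to GitHub's GraphQL endpoint and return the parsed JSON response."""
    request = urllib.request.Request(
        "https://api.github.com/graphql",
        data=json.dumps(body).encode(),
        headers={"Authorization": f"bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        # Check the "errors" key: GraphQL failures can arrive with HTTP 200
        return json.load(response)
```

For example, `send_graphql(build_commit_request(metadata, owner, repo, branch, head_sha), token)` would replace the heredoc, `curl`, and `jq` steps with a single Python call.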

Refer to the end-to-end pipeline available on the GitHub repository for a comprehensive overview.

▶️ Running the pipeline locally

To execute the pipeline locally, you must first store the secrets in local files.

  • Obtain access tokens for Hugging Face and GitHub, and store them in files within a secrets directory, as follows:
<your-secrets-dir>/deployment-secrets/flyte-banana-creds
<your-secrets-dir>/hf-secrets/flyte-banana-hf-creds
  • Within a `.env` file, set the following two variables to enable local code execution:
FLYTE_SECRETS_DEFAULT_DIR=<your-secrets-dir>
DEMO="1"
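When running locally, flytekit looks for each requested secret in a file named after the secret key, inside a directory named after the secret group, under `FLYTE_SECRETS_DEFAULT_DIR`. The hypothetical helper below mimics that lookup (it is not flytekit's code) so you can confirm your token files sit where the tasks expect them before launching the pipeline. The group and key in the usage (`deployment-secrets`, `flyte-banana-creds`) match the Kubernetes secret created later in this post and are otherwise assumptions; the `/etc/secrets` fallback is likewise an assumption.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def read_local_secret(group: str, key: str, base_dir: Optional[str] = None) -> str:
    """Return the contents of <base_dir>/<group>/<key>, stripped of surrounding whitespace."""
    base = base_dir or os.environ.get("FLYTE_SECRETS_DEFAULT_DIR", "/etc/secrets")
    path = Path(base) / group / key
    if not path.is_file():
        raise FileNotFoundError(f"Expected a secret file at {path}")
    return path.read_text().strip()


# Usage: build the expected layout in a throwaway directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    secret_file = Path(tmp) / "deployment-secrets" / "flyte-banana-creds"
    secret_file.parent.mkdir(parents=True)
    secret_file.write_text("ghp_example_token\n")  # placeholder, not a real token
    token = read_local_secret("deployment-secrets", "flyte-banana-creds", base_dir=tmp)

print(token)  # ghp_example_token
```

Pointing `base_dir` at your real `<your-secrets-dir>` gives a quick check that a typo in a directory name won't surface later as a failed task.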

Next, install the necessary requirements in a virtual environment.

To run the Flyte workflow, use the following command:

pyflyte run ml_pipeline.py <workflow-name> --gh_owner <your-github-username> --gh_repo <your-github-repo> --gh_branch <your-github-repo-branch> --hf_user <your-huggingface-username>

The pipeline retrieves the data, tokenizes it, and trains a model. It then waits for you to approve the push event before pushing the model metadata to GitHub, which would trigger a Banana deployment.

🍌 Serving on Banana

Banana is an ML inference platform that runs models on serverless GPUs.

🛠️ Setting it up

To activate the Banana deployment, you need to:

  1. Create an account.
  2. Configure your Banana account by creating a deployment and linking it to your fork of the project's GitHub repository. To do this, navigate to Team > Integrations > GitHub > Manage Repos and Deploy > Deploy from GitHub > your repository.
  3. That's all! Every push to the configured GitHub repository will now trigger a deployment.

🪚 Adding code for inference

The Banana inference code lives in two files: `app.py` and `server.py`.

In the `app.py` file, the following steps should be taken:

  • Define an `init()` function that fetches the model from the 🤗 Hub and loads it onto a GPU.
  • Define an `inference()` function that tokenizes the user-given prompt and generates a prediction.
import json

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


# init() runs once on server startup.
# Load your model to GPU as a global variable here using the variable name "model"
def init():
    global model, tokenizer, device

    with open("model_metadata.json") as f:
        model_metadata = json.load(f)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Pin the exact Hub revision produced by the Flyte training run
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_metadata["repo"],
        num_labels=5,
        revision=model_metadata["sha"],
    ).to(device)
    model.eval()

    # Load the same (uncased) tokenizer used during fine-tuning, once per server
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


# inference() runs for every server call and reuses the preloaded globals.
def inference(model_inputs: dict) -> dict:
    # Parse out your arguments
    prompt = model_inputs.get("prompt")
    if prompt is None:
        return {"message": "No prompt provided"}

    encoding = tokenizer(
        prompt, padding="max_length", truncation=True, return_tensors="pt"
    ).to(device)

    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**encoding)
    prediction = outputs.logits.argmax(-1)

    # Return the predicted label (0-4) as a dictionary
    return {"result": prediction.item()}

For the `server.py` code, refer to the GitHub repository.

🧪 Testing the endpoint

To run the Banana server locally, execute the command `python server.py`. Confirm that the Banana API endpoint is functional by running the following test case:

import argparse

import requests


def generate_predictions(args):
    model_inputs = {"prompt": args.prompt}
    res = requests.post(args.url, json=model_inputs)
    print(res.json())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--url", default="http://localhost:8000/", type=str, help="API endpoint URL"
    )
    parser.add_argument(
        "--prompt",
        default="The service is terrible, the staff seem to be generally clueless, the management is inclined to blame the staff for their own mistakes, and there's no sense of FAST in their fast food.",
        type=str,
        help="Prompt",
    )
    args = parser.parse_args()
    generate_predictions(args=args)

Note: Install the Banana requirements prior to running the command.

📤 Deploying the model

To deploy the model on Banana, prepare a Dockerfile and place it at the root of your GitHub repository:

#####################
# BANANA DOCKERFILE #
#####################

# Must use cuda version 11+
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

WORKDIR /root

# Install git
RUN apt-get update && apt-get install -y git

# Install python packages
RUN pip3 install --upgrade pip
ADD banana/requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# We add the banana boilerplate here
ADD banana/server.py /root

# Add your custom app code, init() and inference()
ADD banana/app.py /root

# Add model metadata
ADD banana/model_metadata.json /root

EXPOSE 8000

CMD python3 -u server.py

Once the files are in place, run the Flyte pipeline locally again. Approve the model metadata push to GitHub, and the model should be built and deployed on Banana!

🫐 Running the pipeline on a Flyte cluster

Create a Dockerfile that includes the necessary requirements to package and register your Flyte workflow.

FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

WORKDIR /root
ENV VENV /opt/venv
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

RUN apt-get update && apt-get install -y build-essential git-lfs curl
RUN pip3 install awscli

# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
COPY ./requirements.txt /root
RUN pip install -r /root/requirements.txt

COPY workflows /root/workflows

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE $tag

Create Kubernetes secrets to store the GitHub and Hugging Face tokens as follows:

kubectl create secret generic deployment-secrets --namespace flytesnacks-development --from-file=flyte-banana-creds=<your-secrets-dir>/deployment-secrets/flyte-banana-creds

kubectl create secret generic hf-secrets --namespace flytesnacks-development --from-file=flyte-banana-hf-creds=<your-secrets-dir>/hf-secrets/flyte-banana-hf-creds

To run the Flyte workflow on an actual Flyte backend, set up a Flyte cluster. The simplest way to get started is to run the `flytectl demo start` command, which spins up a mini replica of a Flyte deployment.

Register the tasks and workflows with the following command. The demo cluster includes a Docker registry that you can use to push and pull the workflow image:

pyflyte register --image <flyte-docker-image> ml_pipeline.py

Then launch the registered workflow from the Flyte UI. To deploy the retrained model on Banana, click "Approve." This saves the model metadata to the GitHub repository, and the resulting push triggers a deployment on Banana.

Execution graph in the Flyte UI

🧪 Testing the Banana deployment

To test your deployment, retrieve the API key and model key for your Banana deployment, store them in your local environment as `BANANA_API_KEY` and `BANANA_MODEL_KEY` (for example, in a `.env` file), and then run the following test:

import argparse
import os

import banana_dev as banana
from dotenv import load_dotenv

load_dotenv()

api_key = os.getenv("BANANA_API_KEY")
model_key = os.getenv("BANANA_MODEL_KEY")


def generate_predictions(args):
    print(banana.run(api_key, model_key, {"prompt": args.prompt})["modelOutputs"][0])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        default="The service is terrible, the staff seem to be generally clueless, the management is inclined to blame the staff for their own mistakes, and there's no sense of FAST in their fast food.",
        type=str,
        help="Prompt",
    )
    args = parser.parse_args()
    generate_predictions(args=args)

Upon running the code, you should observe a predicted label being returned.
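The predicted label is a class index. In `yelp_review_full`, labels 0 to 4 correspond to 1 to 5 stars, so a small helper can turn a response into a star rating. `to_star_rating` is hypothetical; it assumes the response shapes used by the two test scripts in this post: `{"result": n}` from the local server, and a Banana SDK response whose `modelOutputs` list wraps that same payload.

```python
from typing import Any, Dict


def to_star_rating(response: Dict[str, Any]) -> int:
    """Convert an inference response into a 1-5 star rating.

    Accepts the local server's {"result": n} payload or a Banana SDK response
    whose "modelOutputs" list wraps that payload (an assumed shape).
    """
    if "modelOutputs" in response:
        response = response["modelOutputs"][0]
    if "result" not in response:
        # e.g. {"message": "No prompt provided"} from inference()
        raise ValueError(response.get("message", "Unexpected response shape"))
    label = int(response["result"])
    if not 0 <= label <= 4:
        raise ValueError(f"Label {label} is outside the 0-4 range")
    return label + 1  # yelp_review_full: label 0 = 1 star, ..., label 4 = 5 stars


print(to_star_rating({"result": 0}))  # 1
print(to_star_rating({"modelOutputs": [{"result": 4}]}))  # 5
```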

🥡 Takeaways

Hopefully, you’ve learned a bunch about Flyte and Banana! Orchestration and inference can go hand-in-hand seamlessly as demonstrated in this piece. You can ensure your ML pipelines are versioned, cached and reproducible, and at the same time, run online inference at scale on GPUs. 

Here are some key takeaways from this application:

  • Ensure coherence between retraining ML models and deployment
  • Human-in-the-loop can power your deployment with Flyte orchestration
  • GPU-powered serverless inference with Banana
  • Comprehensive data lineage across the entire model development and deployment pipelines for easier debugging
  • Every model is versioned through the 🤗 Hub
  • Flyte versioning ensures versioned ML pipeline executions

Flyte and Banana offer the potential to create production-grade ML pipelines with ease. If this resonates with your needs, I encourage you to try these tools.

Thanks for reading this far! I hope you found this helpful. Don’t forget to give the Flyte repository a star! If you have any questions, ask them in Flyte Slack.