Evan Sadler

Introducing the FlyteCallback for Hugging Face

Hugging Face and Flyte share a common goal: creating a developer-friendly environment, promoting an open-source ethos and fostering a vibrant community. At the same time, their strengths lie in distinct, complementary areas. 

To appreciate this dynamic, consider their mission statements:

Build, train and deploy state-of-the-art models powered by the reference open source in machine learning. — Hugging Face

The infinitely scalable and flexible workflow orchestration platform that seamlessly unifies data, ML and analytics stacks. — Flyte

Together, Hugging Face and Flyte offer developers a blend of simplicity and power. Hugging Face simplifies handling data and models; Flyte streamlines infrastructure and orchestration. Even tasks like fine-tuning Large Language Models (LLMs) on multiple GPUs become surprisingly straightforward when developers use both platforms together.

Until now, however, a significant challenge of GPU training has persisted, particularly for LLMs and other large models: the hefty price tag. Spot instances can provide a cost-efficient solution, but they don't address the complexity of managing this infrastructure.
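In Flyte, opting into spot instances is already a one-line change: marking the task interruptible. Here's a minimal sketch, assuming a cluster configured with a spot node pool (the task name is hypothetical):

from flytekit import task

# Interruptible tasks may be scheduled on spot (preemptible) instances
# and can be evicted at any time, so they must tolerate restarts.
@task(interruptible=True)
def cheap_task() -> str:
    return "this ran on a spot instance"

The missing piece is resuming a long training run gracefully after a preemption.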

That's where our practical solution comes in: the newly introduced FlyteCallback for Hugging Face's Trainer. Specifically developed to alleviate these complexities, FlyteCallback seamlessly integrates spot instances into your existing Hugging Face workflows. It presents a direct and practical approach to balancing cost efficiency and usability.

Before diving into the FlyteCallback, let's briefly touch on the key building blocks: Flyte checkpoints, Flyte Decks and the Hugging Face Trainer. Understanding these elements will provide a solid foundation for appreciating the functionality and benefits of the FlyteCallback.

Flyte checkpoints

Each Flyte task comes with its own checkpoint. For users, reading from or writing to these checkpoints feels as familiar as dealing with a local file. Behind the scenes, Flyte manages the more complex process of storing these checkpoints in the cloud and retrieving them when necessary. Learn more about them here.

from flytekit import current_context, task

@task
def train_model(num_epochs: int) -> int:
    cp = current_context().checkpoint
    # load previous progress, if any
    prev = cp.read()
    start = 0 if prev is None else int(prev.decode())
    # record progress after each completed epoch
    for epoch in range(start, num_epochs):
        cp.write(f"{epoch + 1}".encode())
    return num_epochs

Flyte Decks

Setting `disable_deck=False` in the task decorator adds a Flyte Deck to the task. Visible from the user interface, Decks provide an excellent space for displaying charts and reports. Flytekit comes pre-equipped with renderers for specific types like DataFrames, Markdown, images and more. For those who prefer more control, there's always the option to use custom HTML.

import flytekit 
from flytekit import task
from flytekitplugins.deck.renderer import TableRenderer
import pandas as pd

@task(disable_deck=False)
def t1(df: pd.DataFrame):
    flytekit.Deck("DataFrame", TableRenderer().to_html(df))
    flytekit.Deck("Hello", “Hello World!”)

Hugging Face Trainer

The Hugging Face Trainer simplifies the management of model-training life cycles and offers convenient data preprocessing based on widespread deep-learning patterns. It also supports a wide variety of training strategies, including multi-GPU and mixed-precision training, making it versatile across use cases. It's highly configurable as well, so exploring its breadth of features can be beneficial (learn more here).

from transformers import Trainer, TrainingArguments

# assumes `model` and `datasets` were prepared earlier
training_args = TrainingArguments("trainer", num_train_epochs=2, save_strategy="epoch")
trainer = Trainer(
    model,
    training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"]
)

Putting them together

The callback integrates the Hugging Face Trainer with Flyte's checkpointing system and Flyte Decks. Now, training logs are visible in the UI, and more excitingly, we can easily resume training after a random failure or a spot-instance preemption.

We wanted training Hugging Face models to feel like a natural extension of Flyte. Thankfully, the Trainer supports callbacks for this very reason. Even better, Hugging Face maintains a set of community-owned callbacks built into the transformers package.
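Under the hood, the integration simply hooks into the Trainer's callback events. Here's a minimal sketch of the core idea, for illustration only (use the built-in FlyteCallback rather than this hypothetical class):

from flytekit import current_context
from transformers import TrainerCallback

class FlyteCheckpointSketch(TrainerCallback):
    # Illustrative only; transformers ships a full FlyteCallback.
    def on_save(self, args, state, control, **kwargs):
        # Each time the Trainer writes a local checkpoint, persist it
        # through Flyte's checkpoint API so it survives a preemption.
        local_dir = f"{args.output_dir}/checkpoint-{state.global_step}"
        current_context().checkpoint.save(local_dir)

The real callback also surfaces training logs in a Flyte Deck, which is how they end up in the UI.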

Besides the thrill of being able to contribute to a library like transformers, the results are pretty great: checkpointing, spot instances and reporting require only minor adjustments. See the example below:

# Note: This example skips over some setup steps for brevity.
from flytekit import current_context, task, Resources
from transformers import Trainer, TrainingArguments
from transformers.integrations import FlyteCallback


@task(requests=Resources(gpu="1"), disable_deck=False, interruptible=True)
def train_hf_transformer():
    cp = current_context().checkpoint
    training_args = TrainingArguments(
        "trainer",
        save_strategy="epoch",
        num_train_epochs=100
    )
    trainer = Trainer(model, training_args, callbacks=[FlyteCallback()])
    trainer.train(resume_from_checkpoint=cp.restore())

Developing the FlyteCallback was a great experience that showcased the openness and extensibility of both Flyte and Hugging Face. I really appreciate the supportive feedback and enthusiasm from both teams. The entire process was wrapped up within a week, a timeline that speaks to the ease of collaboration.

Looking ahead, we're excited about two promising features: asynchronous checkpoints and real-time Decks. Today, checkpointing halts task execution, wasting compute resources, a bottleneck that's particularly notable when dealing with large models. Asynchronous checkpoints promise to alleviate this by becoming non-blocking. Similarly, while current Flyte Decks only surface once a task concludes, real-time Decks will provide updates mid-execution, which is essential for long training and scoring procedures.

Stay tuned for more updates, and happy coding!