How Flyte is Simplifying the Orchestration of Machine Learning Pipelines
1.0 Flyte Decks
Available from Flyte 1.1.0
Exploratory Data Analysis (EDA) is the practice of analyzing data through visualization to better understand its intricacies. This significant phase of the data preparation process precedes building the ML models. Flyte Decks let developers render data visualizations within the Flyte UI.
2.0 PyTorch Types
PyTorch is a powerful and dominant ML framework. One of its notable (and 🙃) features is manual device conversion: a GPU tensor doesn’t work on a CPU. Flyte now automates this conversion with the newly added PyTorch types.
The <span class="code-inline">train</span> task runs on a GPU, whereas the <span class="code-inline">predict</span> task runs on a CPU. Have you noticed the omission of <span class="code-inline">to(torch.device("cpu"))</span> in the <span class="code-inline">predict</span> task? 😅 The conversion happens automatically within the Flyte types.
Note: <span class="code-inline">PyTorchCheckpoint</span> is a special type of checkpoint to serialize and deserialize PyTorch models. It checkpoints <span class="code-inline">torch.nn.Module</span>’s state, hyperparameters and optimizer state as described in the PyTorch best practices recipe.
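To make the automated step concrete, here is a rough sketch of the manual device handling in plain PyTorch. It does not use Flyte and is not Flyte's implementation; the helper names `save_tensor` and `load_tensor` are hypothetical, and the sketch assumes only that `torch` is installed.

```python
import io

import torch


def save_tensor(tensor: torch.Tensor) -> bytes:
    """Serialize a tensor to bytes, as a producer task might before handing it off."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()


def load_tensor(blob: bytes) -> torch.Tensor:
    """Deserialize a tensor onto whatever device the consumer has available."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps the stored tensor, so a tensor saved from a GPU
    # can still be loaded on a CPU-only machine.
    return torch.load(io.BytesIO(blob), map_location=device)


# The producer uses a GPU when one exists, otherwise the CPU.
source_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.arange(6, dtype=torch.float32, device=source_device).reshape(2, 3)

blob = save_tensor(weights)
restored = load_tensor(blob)
```

Without the `map_location` argument, loading a GPU-saved tensor on a CPU-only machine raises an error. That is the kind of bookkeeping the text says the Flyte types take care of.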
3.0 Native Support for ONNX Models
Available from Flytekit 1.1.1
ML frameworks converge at one point: ONNX (Stable Diffusion)
from typing import List

import numpy as np
import onnxruntime as rt
import tensorflow as tf
from flytekit import Resources, task
from flytekit.types.file import ONNXFile
from flytekitplugins.onnxtensorflow import TensorFlow2ONNX, TensorFlow2ONNXConfig
from tensorflow.keras import layers, models
from typing_extensions import Annotated


# The annotated return type carries the ONNX export settings
# (input signature and opset) alongside the model.
@task(requests=Resources(mem="1000Mi", cpu="2"))
def train(
    train_images: np.ndarray, train_labels: np.ndarray
) -> Annotated[
    TensorFlow2ONNX,
    TensorFlow2ONNXConfig(
        input_signature=(
            tf.TensorSpec((None, 32, 32, 3), tf.double, name="input"),
        ),
        opset=13,
    ),
]:
    # Small CNN for 32x32 RGB images with 10 output classes.
    model = models.Sequential()
    model.add(
        layers.Conv2D(
            32,
            (3, 3),
            activation="relu",
            input_shape=(32, 32, 3),
        )
    )
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation="relu"))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation="relu"))
    # The final layer emits raw logits, hence from_logits=True below.
    model.add(layers.Dense(10))

    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        ),
        metrics=["accuracy"],
    )
    model.fit(train_images, train_labels, epochs=2)

    return TensorFlow2ONNX(model=model)


@task(requests=Resources(mem="1000Mi", cpu="2"))
def onnx_predict(
    model: ONNXFile,
    test_images: np.ndarray,
) -> List[np.ndarray]:
    # Load the ONNX file into an ONNX Runtime session on the CPU.
    m = rt.InferenceSession(
        model.download(),
        providers=["CPUExecutionProvider"],
    )
    # Request every model output; "input" matches the name in the input signature.
    onnx_pred = m.run(
        [n.name for n in m.get_outputs()],
        {"input": test_images},
    )
    return onnx_pred
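
The `onnx_predict` task returns the raw session outputs. Since the final layer is `Dense(10)` and the loss uses `from_logits=True`, those outputs are unnormalized logits. Below is a minimal post-processing sketch in NumPy. It assumes a small hand-made array stands in for `onnx_pred` (three classes instead of ten, for brevity), and the `softmax` helper is illustrative rather than part of Flyte or ONNX Runtime.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert logits to probabilities along the last axis."""
    # Subtracting the row-wise maximum keeps np.exp from overflowing.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Stand-in for the list returned by onnx_predict: one output array
# holding logits for two images.
onnx_pred = [np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])]

probs = softmax(onnx_pred[0])
labels = probs.argmax(axis=-1)  # predicted class index per image
```

Here `labels` holds the index of the largest logit in each row, which is the predicted class for each image.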
4.0 Spark Pipelines
Available from Flytekit 1.1.1
Spark is one of the most-used integrations in Flyte. To support passing a Spark pipeline between Flyte tasks, a Spark ML pipeline type has been added.
The <span class="code-inline">PipelineModel</span> can now be serialized and deserialized like any other Flyte type.
5.0 whylogs Integration
Available from Flytekit 1.1.1
whylogs is an open-source logging library for data and ML models. It creates statistical summaries of datasets, which can be used to track changes in the data, define data constraints, and visualize key summary statistics. whylogs can be used from within Flyte through the newly added integration.
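
To make "statistical summary" and "data constraint" concrete, here is a library-free sketch of the idea. It deliberately does not use the whylogs API: `profile_column`, `check_constraints`, and the constraint names are hypothetical, and the numbers are made up for illustration.

```python
from statistics import mean
from typing import Callable, Dict, List

Profile = Dict[str, float]


def profile_column(values: List[float]) -> Profile:
    """Summarize a column with a few basic statistics."""
    return {
        "count": len(values),
        "min": min(values),
        "max": max(values),
        "mean": mean(values),
    }


def check_constraints(
    profile: Profile, constraints: Dict[str, Callable[[Profile], bool]]
) -> Dict[str, bool]:
    """Evaluate each named rule against the profile."""
    return {name: rule(profile) for name, rule in constraints.items()}


# A reference batch and a newer batch of the same (made-up) column.
reference = profile_column([21.0, 35.0, 42.0, 58.0])
incoming = profile_column([19.0, 33.0, 47.0, 131.0])

constraints = {
    "min_non_negative": lambda p: p["min"] >= 0,
    "max_below_120": lambda p: p["max"] < 120,
    "mean_close_to_reference": lambda p: abs(p["mean"] - reference["mean"]) <= 10,
}

report = check_constraints(incoming, constraints)
# report == {"min_non_negative": True, "max_below_120": False,
#            "mean_close_to_reference": False}
```

The outlier of 131 in the newer batch breaks both the range rule and the drift rule. This is the sort of change that comparing summaries across batches is meant to surface.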
pip install flytekitplugins-whylogs
Running the integration’s example with modified constraints generates the following reports:
This article covered a handful of ML features newly added to Flyte that can simplify building and deploying ML models. Give these features a shot, and let us know what you think of them. You can also take a look at our roadmap to see what’s coming next. Join our Slack if you have any questions!

An ML-powered product comprises several pieces composed into an iterative cycle: data collection, data cleaning, data labeling, model development, model evaluation, model deployment and model observability. Each stage in this process has its own set of requirements and automation possibilities. On a mission to simplify each step of the model development and deployment processes, we released a suite of ML features.