pytorch-grpc-serving: Serving PyTorch Models for Inference as gRPC API

Oct 7, 2022

Note: This blog post is part of my ongoing experiments with model training, deployment, and monitoring in the bitbeast repository. If you liked this blog post, please upvote it on Hacker News.

Source Code: GitHub

A framework that started with the goal of helping researchers build deep learning models has come a very long way. Congratulations to the PyTorch governance team, Meta, and all contributors on PyTorch joining the Linux Foundation. PyTorch has gone from research to production at every scale. It is easy not only to use or fine-tune a state-of-the-art (SOTA) model in PyTorch, but also to deploy it on large-scale systems or even on mobile phones. Here are some of the go-to solutions for serving PyTorch models for inference.

AWS TorchServe

Launched in 2020, TorchServe has been a go-to solution for deploying PyTorch models; it is roughly PyTorch's counterpart to TensorFlow Serving. Built by AWS together with Facebook (now Meta), TorchServe comes with the features of an enterprise-grade serving framework. It exposes REST and gRPC APIs for getting predictions from the PyTorch model of your choice. You write a Python handler that defines the stages of a request, such as preprocess, inference, and postprocess. TorchServe also gives fine-grained control over the server, model processes, worker groups, and so on, and it supports Workflows when you need DAG-like processing during inference.
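To make the three handler stages concrete, here is a framework-free sketch of how a handler chains them. In a real TorchServe deployment you would subclass BaseHandler from ts.torch_handler.base_handler, and TorchServe's own handle(data, context) also receives a context object, which is omitted here. The class name StagedHandler, the toy model callable, and the UTF-8 text payloads are hypothetical stand-ins so the example runs without TorchServe installed.

```python
from typing import Any, Callable, Dict, List


class StagedHandler:
    """Hypothetical stand-in mirroring a TorchServe handler's three stages."""

    def __init__(self, model: Callable[[List[str]], List[float]]) -> None:
        self.model = model

    def preprocess(self, requests: List[Dict[str, Any]]) -> List[str]:
        # A handler receives a batch of requests whose payload usually sits
        # under "data" or "body". Assumption: payloads are UTF-8 text here.
        return [(row.get("data") or row.get("body")).decode("utf-8") for row in requests]

    def inference(self, inputs: List[str]) -> List[float]:
        # Run the (toy) model on the whole batch at once.
        return self.model(inputs)

    def postprocess(self, outputs: List[float]) -> List[Dict[str, float]]:
        # One response entry per request, in the same order as the batch.
        return [{"score": round(score, 4)} for score in outputs]

    def handle(self, requests: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        # Chain the stages: preprocess -> inference -> postprocess.
        return self.postprocess(self.inference(self.preprocess(requests)))


# Toy "model": scores each text by its length.
handler = StagedHandler(model=lambda texts: [len(t) / 10 for t in texts])
print(handler.handle([{"data": b"pizza"}, {"body": b"pot pie"}]))
# [{'score': 0.5}, {'score': 0.7}]
```

Keeping the stages separate is what lets a serving framework batch requests between preprocess and inference.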

Beautiful PyTorch Ecosystem

Thanks to the amazing PyTorch ecosystem, plenty of tools and examples are available for serving models. Here are some popular ones for deploying PyTorch models:

Hidden gems in MultiPy or torch::deploy

MultiPy is a work in progress for now, but it is a promising experiment from the PyTorch team for using Python for model inference. It offers an alternative path to TorchScript, which is the de facto standard for deployment. Their arXiv paper explains how running multiple Python interpreters inside a single process sidesteps contention on the Global Interpreter Lock (GIL), which lets Python-based inference scale.

Note: This package/API is not yet stable, and using it requires building PyTorch from source.


gRPC for Inference

gRPC offers benefits such as bidirectional streaming, asynchronous communication between services, and low latency, which matter most in large-scale microservice architectures. This post from the AWS Machine Learning Blog discusses how to reduce inference latency using gRPC. The aim here is to use gRPC in the simplest possible way to start serving your PyTorch models with minimal requirements.

Utilizing TorchScript to its full potential

If you need a refresher on TorchScript, check out the Introduction to TorchScript tutorial.

Let us start by defining a TorchScript module that loads a pretrained model, runs inference, and returns the top-K classes with their probabilities. It also performs the additional step of mapping class indices to human-readable labels.

from typing import Dict

import torch
from torchvision.models import resnet50, ResNet50_Weights


class YourTorchScriptModule(torch.nn.Module):
    """Your TorchScript Module"""
    def __init__(self) -> None:
        super().__init__()

        # load a pretrained ImageNet model (swap in your choice of model and weights)
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()

        # map ImageNet class indices to labels (one label per line in the file)
        with open('imagenet_classes.txt', encoding='utf-8') as f:
            self.categories = [s.strip() for s in f.readlines()]

    def forward(self, img_tensor: torch.Tensor, topk: int = 5) -> Dict[str, float]:
        """Forward pass: return the top-K labels with their probabilities"""
        # do a forward pass and turn the logits into class probabilities
        output = self.model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)

        # get the top-K probabilities and class indices
        top_prob, top_class = torch.topk(probabilities, topk)

        # return {"label": probability} for the first image, e.g. {"pizza": 0.44922224}
        # (TorchScript needs explicit int()/float() casts and a typed dict)
        result: Dict[str, float] = {}
        for idx in range(topk):
            result[self.categories[int(top_class[0][idx])]] = float(top_prob[0][idx])
        return result
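Scripting compiles forward into TorchScript, which can be saved and later loaded without the original Python class. As a quick sanity check of that round trip, the sketch below scripts a toy module with the same top-K and label-mapping logic, saves it to an in-memory buffer, and loads it back. TopKLabels, its 4-feature Linear layer, and the three labels are hypothetical stand-ins for the ImageNet model. For the real module you would call torch.jit.script(YourTorchScriptModule()) and save the result to a file instead.

```python
import io
from typing import Dict, List

import torch


class TopKLabels(torch.nn.Module):
    """Hypothetical toy: a tiny linear 'classifier' plus index-to-label mapping."""

    def __init__(self, labels: List[str]) -> None:
        super().__init__()
        self.model = torch.nn.Linear(4, len(labels))
        self.categories = labels  # TorchScript infers List[str] from the value

    def forward(self, x: torch.Tensor, topk: int = 2) -> Dict[str, float]:
        probabilities = torch.nn.functional.softmax(self.model(x), dim=1)
        top_prob, top_class = torch.topk(probabilities, topk)
        result: Dict[str, float] = {}
        for idx in range(topk):
            result[self.categories[int(top_class[0][idx])]] = float(top_prob[0][idx])
        return result


def script_and_reload(module: torch.nn.Module) -> torch.jit.ScriptModule:
    """Compile with torch.jit.script, serialize, and load the result back."""
    scripted = torch.jit.script(module.eval())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)  # rewind before reading the serialized archive
    return torch.jit.load(buffer)


loaded = script_and_reload(TopKLabels(["pizza", "potpie", "hot pot"]))
print(loaded(torch.randn(1, 4), 2))
```

The loaded module no longer depends on the TopKLabels class, which is what lets a server load the model without importing the training code.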

Refer to torchscript.py in the repository for an example of scripting the quantized ResNet50 model. With the multi-weight API in recent TorchVision releases (0.13 and later), it is easy to load the preprocessing transforms for a pretrained model directly from its weights.

Writing the gRPC Service

gRPC uses Protocol Buffers, an open-source mechanism for serializing structured data. If you are new to gRPC, check out the gRPC Quickstart tutorial.

Let us write the protobuf definition for our Inference gRPC service.

syntax = "proto3";
import "google/protobuf/empty.proto";
message PredictionRequest {
    bytes input = 1;
}
message PredictionResponse {
    bytes prediction = 1;
}
service Inference {
    rpc Prediction(PredictionRequest) returns (PredictionResponse) {}
}
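Because both messages contain a single bytes field numbered 1, their wire encoding is tiny. It consists of a one-byte tag (field number 1 with wire type 2, meaning length-delimited), a varint length, and then the raw bytes. The stdlib-only sketch below hand-encodes that layout to show what the generated classes' SerializeToString() produces for these two messages. The helper names are hypothetical, and in practice you would use the generated code instead.

```python
from typing import Tuple

# field number 1, wire type 2 (length-delimited) -> (1 << 3) | 2 == 0x0A
FIELD_1_BYTES_TAG = (1 << 3) | 2


def encode_varint(value: int) -> bytes:
    """Protobuf base-128 varint: 7 bits per byte, MSB set on all but the last byte."""
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)
        else:
            out.append(byte)
            return bytes(out)


def decode_varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Read a varint starting at pos; return (value, next_position)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def encode_single_bytes_field(payload: bytes) -> bytes:
    """Wire bytes for PredictionRequest(input=...) or PredictionResponse(prediction=...)."""
    if not payload:
        return b""  # proto3 omits a field that holds its default (empty) value
    return encode_varint(FIELD_1_BYTES_TAG) + encode_varint(len(payload)) + payload


def decode_single_bytes_field(buf: bytes) -> bytes:
    """Inverse of encode_single_bytes_field; understands only field 1 / wire type 2."""
    payload, pos = b"", 0
    while pos < len(buf):
        tag, pos = decode_varint(buf, pos)
        if tag != FIELD_1_BYTES_TAG:
            raise ValueError(f"unexpected tag {tag}")
        length, pos = decode_varint(buf, pos)
        payload = buf[pos:pos + length]  # a later occurrence overrides an earlier one
        pos += length
    return payload


print(encode_single_bytes_field(b"abc"))  # b'\n\x03abc'
```

Only field numbers and wire types appear on the wire, never the message or field names. That is why renaming a field is safe but renumbering it is not.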

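Running the following command generates the message classes and the server and client stubs from this definition (for a file named <name>.proto, it produces <name>_pb2.py and <name>_pb2_grpc.py):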

python -m grpc_tools.protoc -I<file_folder> --python_out=. --grpc_python_out=. <file_location>

Refer to server.py and client.py for examples of writing the gRPC server and client, respectively. There are several ways to optimize a gRPC server for maximum throughput, but they are beyond the scope of this tutorial.

Running the reference client produces output like:
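If you want to try the service before generating stubs, grpcio can also serve a method through generic handlers. This sketch serves Inference/Prediction that way. It relies on the fact that protobuf's well-known google.protobuf.BytesValue also has a single bytes field numbered 1, so it shares the wire encoding of PredictionRequest and PredictionResponse. A client built from the generated stubs should therefore be able to talk to it. Three things are assumptions here: the predict callable, the JSON encoding of the scores (chosen to match the shape of the sample output), and the function names. The repository's server.py may encode its response differently.

```python
import json
from concurrent import futures
from typing import Callable, Dict, Tuple

import grpc
from google.protobuf import wrappers_pb2

# The .proto declares no package, so the full method name is /<Service>/<Method>.
PREDICTION_METHOD = "/Inference/Prediction"


def make_server(
    predict: Callable[[bytes], Dict[str, float]], address: str = "[::]:8000"
) -> Tuple[grpc.Server, int]:
    """Start an Inference server without generated stubs; return (server, bound_port)."""

    def prediction(request: wrappers_pb2.BytesValue, context) -> wrappers_pb2.BytesValue:
        # request.value carries PredictionRequest.input; reply with JSON-encoded scores
        scores = predict(request.value)
        return wrappers_pb2.BytesValue(value=json.dumps(scores).encode("utf-8"))

    handler = grpc.method_handlers_generic_handler(
        "Inference",
        {
            "Prediction": grpc.unary_unary_rpc_method_handler(
                prediction,
                request_deserializer=wrappers_pb2.BytesValue.FromString,
                response_serializer=wrappers_pb2.BytesValue.SerializeToString,
            )
        },
    )
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    server.add_generic_rpc_handlers((handler,))
    port = server.add_insecure_port(address)  # port 0 lets the OS pick a free port
    server.start()
    return server, port


def predict_remote(target: str, image_bytes: bytes, timeout: float = 10.0) -> Dict[str, float]:
    """Call Inference/Prediction and decode the JSON scores."""
    with grpc.insecure_channel(target) as channel:
        call = channel.unary_unary(
            PREDICTION_METHOD,
            request_serializer=wrappers_pb2.BytesValue.SerializeToString,
            response_deserializer=wrappers_pb2.BytesValue.FromString,
        )
        response = call(wrappers_pb2.BytesValue(value=image_bytes), timeout=timeout)
    return json.loads(response.value)
```

In a real service, predict would decode the image bytes, run the scripted module from the TorchScript section, and return its label-to-probability dict. Call server.wait_for_termination() to keep the process alive.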

prediction: {"pizza": 0.44644981622695923, "potpie": 0.009865873493254185, "hot pot": 0.007180684246122837, "consomme": 0.005226321052759886, "spatula": 0.0047011710703372955}


Deployment

Reducing Docker Image Size

Refer to the minimal Dockerfile, which can be used to containerize and deploy this gRPC service.

One challenge with shipping fully baked Docker images for Python inference is their huge size, which makes them bulky and awkward for CPU deployments. For CPU-only inference, you can build the image with the CPU-specific wheels of PyTorch and TorchVision, which leave out the CUDA libraries bundled with the default Linux wheels. Refer to requirements.txt for links to the CPU-specific versions of PyTorch and TorchVision.
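CPU-only wheels are typically installed from PyTorch's own package index (for example, pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu); the exact pins live in the repository's requirements.txt. A cheap way to confirm which build actually landed in the image is to check torch.version.cuda, which is None for CPU-only builds. The helper below is a hypothetical sanity check you might run as a build step or at server startup.

```python
import torch


def assert_cpu_only_build() -> str:
    """Raise if the installed PyTorch was built with CUDA (or ROCm); else return its version."""
    cuda = torch.version.cuda                   # None for CPU-only builds
    hip = getattr(torch.version, "hip", None)   # set on ROCm builds, None otherwise
    if cuda is not None or hip is not None:
        raise RuntimeError(
            f"torch {torch.__version__} is a GPU build (cuda={cuda}, hip={hip}); "
            "install the CPU wheel to keep the image small"
        )
    # Linux CPU wheels from the PyTorch index typically carry a "+cpu" version suffix.
    return str(torch.__version__)
```

For example, a Dockerfile could run python -c "from check import assert_cpu_only_build; print(assert_cpu_only_build())" after installing requirements, assuming the helper is saved as check.py, so that an accidental GPU wheel fails the build.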

Bye Heroku, Hello Fly.io

Since Heroku announced the end of its free dynos, it was time for the internet to make peace with a new provider. Luckily, I moved from Heroku to Fly.io almost two years ago. Fly is fast, minimal, and easy to use, and its HTTP/2 support makes it possible to run a gRPC server. One can find a reference fly.toml config for a Fly app. A sample gRPC service running a quantized ResNet50 pretrained on ImageNet is running at pytorch-serving.fly.dev:8000 🥳

If you like the idea of pytorch-grpc-serving, check it out on GitHub. If you have an idea or a suggestion for improvement, feel free to contribute via issues or pull requests!
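Before pointing a client at a deployed service, it can help to check that a gRPC channel to the host actually becomes ready. The sketch below uses grpcio's channel_ready_future. Whether the demo endpoint expects TLS is an assumption, hence the use_tls switch, and the demo service may no longer be online. The function name grpc_endpoint_ready is hypothetical.

```python
import grpc


def grpc_endpoint_ready(target: str, use_tls: bool = False, timeout: float = 5.0) -> bool:
    """Return True if a gRPC channel to target becomes ready within timeout seconds."""
    if use_tls:
        # default root certificates; assumption: the endpoint terminates TLS itself
        channel = grpc.secure_channel(target, grpc.ssl_channel_credentials())
    else:
        channel = grpc.insecure_channel(target)
    try:
        grpc.channel_ready_future(channel).result(timeout=timeout)
        return True
    except grpc.FutureTimeoutError:
        # still not READY after timeout (refused, unresolvable, or TLS mismatch)
        return False
    finally:
        channel.close()
```

For example, grpc_endpoint_ready("pytorch-serving.fly.dev:8000", use_tls=True) tries the demo over TLS. A False result only means the channel never became ready, so try the other use_tls setting before concluding the service is down.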