PyTorch Mobile: Image Classification on iOS

PyTorch joins the mobile ML party alongside Core ML and TFLite

Anupam Chugh
Oct 29, 2019 · 7 min read

PyTorch is one of the most sought-after deep learning frameworks, with its own advantages over other widely used frameworks like TensorFlow (here's a great comparison of the two). Facebook recently released PyTorch 1.3 and added the missing piece of the puzzle: experimental mobile support for Android and iOS.

Up until now, the two most widely used mobile machine learning frameworks have been Apple's Core ML and Google's TensorFlow Lite, so PyTorch's entry into the mobile domain should be an interesting one. It will face a tough challenge from Core ML, but given its cross-platform support, PyTorch will likely carve out its own space.

PyTorch has already created its own niche thanks to its Pythonic approach to building models and its gentler learning curve. It also lets developers build dynamic computational graphs, which makes model debugging easier because the standard Python debugging tools work out of the box.

What’s New In PyTorch 1.3

  • PyTorch Mobile — An experimental, end-to-end workflow that takes models from Python training code to deployment on both Android and iOS. This is the feature that's exciting for mobile developers.
  • Named Tensors — By letting developers attach explicit names to tensor dimensions, PyTorch aims to help us write more readable and maintainable code.
  • Quantization Support — Model size and compute budgets are always constraints on mobile devices. Using FBGEMM and QNNPACK, PyTorch now supports 8-bit quantization on x86 and ARM CPUs (see the short example after this list).
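Here's a quick taste of the latter two features. This is just a minimal sketch against the PyTorch 1.3 APIs, with a toy model purely for illustration:

import torch
import torch.nn as nn

# Named tensors: dimensions carry explicit names that follow the data through ops
imgs = torch.randn(2, 3, 224, 224, names=('N', 'C', 'H', 'W'))
print(imgs.names)           # ('N', 'C', 'H', 'W')
print(imgs.sum('C').names)  # ('N', 'H', 'W')

# Dynamic quantization: store Linear weights as 8-bit integers to shrink the model
model = nn.Sequential(nn.Linear(224, 64), nn.ReLU(), nn.Linear(64, 10))
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized)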

Plan Of Action

Running a PyTorch model on iOS devices requires ticking the following checkboxes:

  • Converting the model to the TorchScript format (.pt) using a Python script.
  • Integrating the LibTorch (PyTorch C++) pod into our Xcode project.
  • Writing an Objective-C++ wrapper, exposed to Swift through a bridging header, to run PyTorch inference from the Swift codebase.

In the next few sections, we'll run image classification on iOS devices with a PyTorch model, using images captured from the camera or selected from the photos library.

For Android developers looking to implement PyTorch, just hop on over to this link for a look at working with image classification on that platform.

Setup

To start off, you need a Mac, Xcode 11, and a device with iOS 12 or above as a deployment target.

You need to install torchvision from pip using the command line, as shown below:

pip install torchvision

Converting Our Model

We need to convert our model to the TorchScript format (.pt) in order to run it with PyTorch Mobile; any standard CNN (ResNet, MobileNet, and so on) can be ported this way. Since I love my iPhone, I'll be using MobileNetV2, a model designed to run fast on mobile hardware while still providing great accuracy.
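Here's a minimal sketch of what that conversion script (convert_pytorch_model.py) might look like, assuming the pretrained torchvision MobileNetV2 and tracing with torch.jit.trace:

import torch
import torchvision

# Load a pretrained MobileNetV2 and switch it to inference mode
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()

# Trace the model with a dummy input of the shape the iOS app will send (1 x 3 x 224 x 224)
example = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example)

# Save the TorchScript model that will ship inside the app bundle
traced_model.save("mobilenet-v2.pt")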

Running python convert_pytorch_model.py will generate your "mobilenet-v2.pt" model. Drag it into your Xcode project and it's ready to ship!

Building Our PyTorch Mobile iOS Application

In this section, we’ll be creating our iOS Application that runs the PyTorch model on images taken from the camera or photos library and displays the label with the highest confidence on the screen.

Installing Dependencies

Create a new Xcode project; the Single View App template will do the job. Once that's done, add the following pod dependency to your Podfile (run pod init in your project directory to create one):

platform :ios, '12.0'
target 'PyTorchMobileiOS' do
  pod 'LibTorch', '~> 1.3.0'
end

Running pod install creates an .xcworkspace file in your project directory. Close any open Xcode sessions and relaunch the project using that file.

Now you should be able to add the framework from the Project Navigator -> General Tab, as shown below:

Screengrab from my Xcode project navigator General Tab.

Adding Labels For Our PyTorch Mobile Model

Our labels file is a plain-text file containing the 1,000 ImageNet class labels, one per line; the index the model predicts maps directly onto a line in this file.
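Add the labels file to your Xcode project and load it into a string array at runtime. Here's a minimal sketch, assuming the file is bundled as labels.txt:

import Foundation

// Load the ImageNet labels bundled with the app (the file name here is an assumption)
func loadLabels() -> [String] {
    guard let path = Bundle.main.path(forResource: "labels", ofType: "txt"),
          let contents = try? String(contentsOfFile: path, encoding: .utf8) else {
        return []
    }
    return contents
        .components(separatedBy: .newlines)
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
}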

In the next sections, we'll be using both Objective-C and Swift in our codebase. Objective-C is required to communicate with C++, while Swift handles the UI and drives the predictions.

Objective-C Code

Currently, PyTorch Mobile ships as a C++ library, so it cannot communicate directly with Swift. We need Objective-C (compiled as Objective-C++) to talk to the C++ code.

Swift and Objective-C languages can communicate with each other using a bridging header file.

Creating an Objective-C Bridging Header File

A bridging header file exposes Objective-C header files to Swift.

To generate one, simply drag in or create a new Objective-C file in your project; Xcode automatically prompts you to create a bridging header. The file is named <ProjectName>-Bridging-Header.h.

We've added #import “TorchModule.h” to our bridging header file. TorchModule.h declares the Objective-C wrapper class we'll call from Swift: an initializer that takes the model file's path, and a prediction method that takes an image buffer.

Two macros in that header are worth calling out: NS_SWIFT_NAME and NS_UNAVAILABLE.

  • NS_SWIFT_NAME lets us choose the Swift-facing name for an Objective-C method, so we can call the Objective-C methods from Swift using the friendlier names assigned in the macro.
  • NS_UNAVAILABLE tells the compiler not to expose a given class member (such as a default initializer) to Swift.

Such macros improve the Swift <> Objective-C interoperability tremendously.

Loading Our Model And Running Inference in Objective-C++

The Objective-C++ implementation of the TorchModule class is where we load the quantized model and then run inference on the image. The TorchModule instance is initialized from the Swift codebase, which we'll see shortly.

The initWithFileAtPath method is called from the Swift counterpart method. It passes the location of the model file.

The predictImage method receives the image's preprocessed pixel buffer and converts it into an input tensor of the required shape. We then iterate through the scores for each label and return the index with the highest score back to Swift as the predicted output.

Note: Setting AutoGradMode to false tells LibTorch that we only want to run inference with our model (no gradient tracking or training).

Now that the Objective-C part is done, let’s jump back to the Swift code—the easier one!

Swift Code

Let's build the UI programmatically. But before that, embed the ViewController inside a Navigation Controller.

Building the UI Programmatically

There’s nothing fancy in the UI here. We just need to set up a UIButton, UIImageView, and UILabel in our ViewController class.
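Here's a minimal sketch of such a layout. The view names and styling are my own choices rather than the exact original code:

import UIKit

class ViewController: UIViewController {

    // The picked image, the predicted label, and a button to choose a photo
    let imageView = UIImageView()
    let resultLabel = UILabel()
    let selectButton = UIButton(type: .system)

    override func viewDidLoad() {
        super.viewDidLoad()
        view.backgroundColor = .white
        title = "PyTorch Mobile"

        imageView.contentMode = .scaleAspectFit
        resultLabel.textAlignment = .center
        resultLabel.numberOfLines = 0
        resultLabel.text = "Pick an image to classify"
        selectButton.setTitle("Select Image", for: .normal)
        selectButton.addTarget(self, action: #selector(selectImageTapped), for: .touchUpInside)

        // Stack the three views vertically and pin them with Auto Layout
        let stack = UIStackView(arrangedSubviews: [imageView, resultLabel, selectButton])
        stack.axis = .vertical
        stack.spacing = 16
        stack.translatesAutoresizingMaskIntoConstraints = false
        view.addSubview(stack)
        NSLayoutConstraint.activate([
            stack.topAnchor.constraint(equalTo: view.safeAreaLayoutGuide.topAnchor, constant: 16),
            stack.leadingAnchor.constraint(equalTo: view.leadingAnchor, constant: 16),
            stack.trailingAnchor.constraint(equalTo: view.trailingAnchor, constant: -16),
            imageView.heightAnchor.constraint(equalToConstant: 300)
        ])
    }

    @objc func selectImageTapped() {
        // Implemented in the next section: present the camera / photo library picker
    }
}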

Configuring Button Action

Now we need to handle taps on the button. The idea is to show an image picker that lets the user pick an image from the camera or the photo library. For this, we'll present a UIAlertController with the actionSheet style, as sketched below.
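Here's a sketch that fills in the selectImageTapped stub from the earlier layout (the delegate conformance it relies on comes in the note below):

// Replaces the selectImageTapped stub from the UI sketch above
@objc func selectImageTapped() {
    let sheet = UIAlertController(title: "Select Image", message: nil, preferredStyle: .actionSheet)
    sheet.addAction(UIAlertAction(title: "Camera", style: .default) { _ in
        self.presentImagePicker(source: .camera)
    })
    sheet.addAction(UIAlertAction(title: "Photo Library", style: .default) { _ in
        self.presentImagePicker(source: .photoLibrary)
    })
    sheet.addAction(UIAlertAction(title: "Cancel", style: .cancel, handler: nil))
    present(sheet, animated: true)
}

func presentImagePicker(source: UIImagePickerController.SourceType) {
    // The camera isn't available on the simulator, so guard against missing sources
    guard UIImagePickerController.isSourceTypeAvailable(source) else { return }
    let picker = UIImagePickerController()
    picker.sourceType = source
    picker.delegate = self
    present(picker, animated: true)
}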

You need to add privacy usage descriptions to your Info.plist for both the camera and the photo library (the NSCameraUsageDescription and NSPhotoLibraryUsageDescription keys); without them, the app will crash the moment it tries to access either one.

Note: Your ViewController class needs to conform to the UIImagePickerControllerDelegate and UINavigationControllerDelegate protocols in order to receive the image from the camera/photos library.
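Here's a sketch of that conformance, which also hands the chosen image off to the classifier:

extension ViewController: UIImagePickerControllerDelegate, UINavigationControllerDelegate {

    func imagePickerController(_ picker: UIImagePickerController,
                               didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey: Any]) {
        picker.dismiss(animated: true)
        guard let image = info[.originalImage] as? UIImage else { return }
        imageView.image = image
        predictImage(image) // defined in the prediction section below
    }

    func imagePickerControllerDidCancel(_ picker: UIImagePickerController) {
        picker.dismiss(animated: true)
    }
}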

Now that we’ve set the UI for our application, it’s time to process the image to fit into the model input constraints and subsequently run the classifier. Let’s do that in the next section.

Preprocessing the Image

We'll add a small utility extension on UIImage for resizing and normalizing the input image.
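Here's a minimal sketch of what that extension might look like. The pixel-handling details are my own; the mean and standard deviation are the standard ImageNet values MobileNetV2 was trained with:

import UIKit

extension UIImage {

    // Resize the image to the model's expected input size (224 x 224)
    func resized(to newSize: CGSize) -> UIImage {
        let renderer = UIGraphicsImageRenderer(size: newSize)
        return renderer.image { _ in
            self.draw(in: CGRect(origin: .zero, size: newSize))
        }
    }

    // Convert the image into a Float32 buffer in channel-first (CHW) order,
    // normalized with the ImageNet mean and standard deviation
    func normalized() -> [Float32]? {
        guard let cgImage = self.cgImage else { return nil }
        let width = cgImage.width
        let height = cgImage.height

        // Redraw into a known RGBA byte layout so we can read the pixels directly
        var rawBytes = [UInt8](repeating: 0, count: width * height * 4)
        let drew = rawBytes.withUnsafeMutableBytes { ptr -> Bool in
            guard let context = CGContext(data: ptr.baseAddress,
                                          width: width,
                                          height: height,
                                          bitsPerComponent: 8,
                                          bytesPerRow: width * 4,
                                          space: CGColorSpaceCreateDeviceRGB(),
                                          bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) else {
                return false
            }
            context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
            return true
        }
        guard drew else { return nil }

        let mean: [Float32] = [0.485, 0.456, 0.406]
        let std: [Float32] = [0.229, 0.224, 0.225]
        var buffer = [Float32](repeating: 0, count: width * height * 3)
        for pixel in 0 ..< width * height {
            for channel in 0 ..< 3 {
                let value = Float32(rawBytes[pixel * 4 + channel]) / 255.0
                buffer[channel * width * height + pixel] = (value - mean[channel]) / std[channel]
            }
        }
        return buffer
    }
}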

The image is resized to the model's input size, which is 224 x 224. The normalized function then converts the resized image into a Float32 buffer. Now our input is ready for inference.

Predicting Image

Now we just need to call the Objective-C prediction function with the processed input image from the Swift code, as shown below:
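Here's a sketch of that flow inside the ViewController. I'm assuming the wrapper is exposed to Swift as init(fileAtPath:) and predict(image:) via NS_SWIFT_NAME, and that it returns the index of the top-scoring class as described earlier; adjust the names and return type to match your own TorchModule:

// Lazily load the TorchScript model bundled with the app
private lazy var module: TorchModule = {
    guard let modelPath = Bundle.main.path(forResource: "mobilenet-v2", ofType: "pt"),
          let module = TorchModule(fileAtPath: modelPath) else {
        fatalError("Could not load the TorchScript model")
    }
    return module
}()

// The ImageNet labels we loaded earlier
private lazy var labels: [String] = loadLabels()

func predictImage(_ image: UIImage) {
    // Resize and normalize the image into a Float32 CHW buffer (see the extension above)
    guard var pixelBuffer = image.resized(to: CGSize(width: 224, height: 224)).normalized() else {
        return
    }
    // Run inference through the Objective-C++ wrapper; it returns the winning class index
    guard let index = module.predict(image: UnsafeMutableRawPointer(&pixelBuffer))?.intValue,
          labels.indices.contains(index) else {
        return
    }
    resultLabel.text = labels[index]
}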

The predictImage function is where we resize and normalize the image and pass a pointer to the resulting buffer to the TorchModule Objective-C wrapper class we saw earlier.

In return, we display the predicted label on the screen, as shown in the illustration below:

A screengrab from the PyTorch iOS Application.

What’s Next

Moving forward, GPU support for PyTorch Mobile would be a major boost.

Also, an official Swift wrapper API is something developers will be keenly waiting for, so that bridging header files are no longer needed in their codebases.

That's it for this one. I hope you enjoyed it.

Editor’s Note: Heartbeat is a contributor-driven online publication and community dedicated to providing premier educational resources for data science, machine learning, and deep learning practitioners. We’re committed to supporting and inspiring developers and engineers from all walks of life.

Editorially independent, Heartbeat is sponsored and published by Comet, an MLOps platform that enables data scientists & ML teams to track, compare, explain, & optimize their experiments. We pay our contributors, and we don’t sell ads.

If you’d like to contribute, head on over to our call for contributors. You can also sign up to receive our weekly newsletters (Deep Learning Weekly and the Comet Newsletter), join us on Slack, and follow Comet on Twitter and LinkedIn for resources, events, and much more that will help you build better ML models, faster.
