Notes on throughput management for a TensorFlow image classification app
AI-based image analysis has come a long way in recent years. While image classification models can be trained from scratch, it is often more practical to retrain an already trained model. This is called transfer learning; read more about it in the link below. In this post, we will look at using a retrained TensorFlow model from Go and optimizing the data processing throughput.
Retraining an Image Classifier | TensorFlow Hub
TensorFlow allows exporting a trained model so it can be integrated into a variety of programming languages via its C API. The workflow discussed here is as follows:
- Retrain an image classifier in Python and export the TensorFlow model as a language-agnostic protocol buffer file
- Integrate the model in Go using the TensorFlow C library/API
- Run the application to read/decode images and feed them to the model, fetching classification labels as output
There are several ways we could go about setting up data processing within the application. In particular, the data/model parallelism approaches could look as follows:
- Read data and execute the model on one image at a time (sequential, no parallelism)
- Read data and execute the model on one batch of images at a time (data parallelism)
- Read a batch of images concurrently and dispatch the compute cycle on that batch concurrently as well. This way, disk I/O does not wait for compute to finish before reading the next batch of images. (data and model parallelism)
It is interesting to compare the performance of these modes of execution via system monitoring tools. The baseline prior to running the application is shown below.
Sequential processing: 1 image at a time (no parallelism)
Running one image at a time, without any data or model parallelism, leaves a lot of system resources underutilized. As seen in the chart above, only one or two CPU cores are stressed at a time. Consequently, this is the slowest configuration, taking roughly 4 minutes to process 600 images.
Data parallelism: 1 batch at a time
When 100 images are read and decoded concurrently and then fed to the TensorFlow model one batch at a time, we are able to process all 600 images in roughly 7 seconds. This is a huge improvement over sequential processing. The CPU profile is a short burst of activity, seen in the chart below between the timestamps of 20 and 30 seconds. However, the CPU cores are still not saturated, so there is room for further gains in throughput.
Data and model parallelism: everything concurrent
Finally, if we read images concurrently, pack them into a batch, and spawn the TensorFlow job concurrently as well, we are essentially doing everything concurrently. In other words, disk I/O for reading images no longer waits for TensorFlow compute to finish before reading the next set of images. This further improves overall throughput, and we are able to finish processing all 600 images in about 4 seconds. The CPU profile is also narrower and taller, seen around the 20-second timestamp in the chart below, and CPU utilization stays at roughly 80–100% for the duration of the compute.
Obviously, this is an optimization problem, and the throughput depends on the size of the input images, the amount of compute in the TensorFlow model, the available system resources, and other such parameters. The general takeaway, however, is that we want to saturate the available compute resources, and it is fairly easy to do so with a combination of Go and TensorFlow. Let’s briefly go over how the code was structured to run everything concurrently.
Overall code structure
The code is fairly simple and consists of a few functions that each do a specific task, plus a driver function that spawns these functions concurrently. The building blocks are as follows:
- A function to gather the list of filenames to operate on
- A function to read and decode images
- A function to run TensorFlow model on a batch of images
- A driver function to wrap these individual functions within Go’s concurrency primitives
Below is a high-level view of what these functions look like. The getFilenames function produces a channel of filenames (chan string). An image is decoded as a float32 multi-dimensional slice, and a chan byte channel communicates the output from the TensorFlow classification. The input to the TensorFlow classification is a batch of such decoded images.
The driver function uses the errgroup library, which allows spawning a goroutine that returns an error. This makes it very easy to launch a group of such goroutines and capture any error. Furthermore, it is possible to communicate cancellation to all goroutines via a shared context:
ctx, cancel := context.WithCancel(context.Background())
readEg, _ := errgroup.WithContext(ctx)
classifyEg, _ := errgroup.WithContext(ctx)
Read more about errgroup here:
Why You Should Use errgroup.WithContext() in Your Golang Server Handlers | FullStory
Here is a very nice intro to the context package as well.
So the pattern looks something like the following:
- Launch image decoding via an error group, and do so for a batch’s worth of images. Each run of the error group launches several read operations concurrently and then waits for all of them to finish. The final output is a batch of decoded, resized and normalized image data.
- The classification job is then spawned in another error group, and the whole process is repeated for the next batch
- Finally, we wait for all error groups to finish
I like the portability of TensorFlow, and its use with Go allows us to construct lightweight software with minimal dependencies. Go’s excellent concurrency primitives allow us to better utilize system resources and improve data processing throughput. I hope you find this useful. Stay tuned for further updates on containerization of TensorFlow-based Go binaries.