Federated Learning
Bringing Machine Learning to the edge with Kotlin and Android
Training a machine learning model requires data. The more we have, the better. However, data is not cheap and, more importantly, it can contain sensitive and personal information.
Recent developments in privacy, in the form of new laws such as GDPR and a growing awareness among users and citizens of the value of their data, are generating a need for techniques that enforce more privacy.
Although techniques such as anonymisation can greatly help with the privacy issue, the fact that all the data is sent to a central location to train the machine learning models is always a cause for concern.
This project is a proof of concept of how to set up a basic Federated Learning environment using an Android application as the edge device.
The Code
If you want to jump straight into the code, you can find it in the following repos:
- Android application
- Server
Components
The project is divided into three main parts:
- A server, written in Kotlin and using DL4J, that generates a model based on the Cifar-10 dataset
- An Android application, also written in Kotlin and using DL4J, that uses this model to classify images taken with the camera
- The Federated Learning setup, where the Android application is able to train the model with its local data and the server is able to update the shared model with the updates coming from the edge
The Model
The model is based on Cifar-10, a well-known dataset for classifying images into ten different classes.

The architecture of the model has been tuned to achieve a double purpose:
a) To have a not-so-bad performance
b) To allow it to be loaded and trained in an Android app
The selected architecture is a shallow Convolutional Neural Network with one convolutional layer and one dense layer. This proved to be enough to obtain a decent performance using 50 epochs and 10,000 samples while keeping the size of the model small.
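To make the description concrete, here is a sketch of what such a shallow network could look like in DL4J. The hyperparameters (kernel size, layer widths, learning rate, seed) are illustrative assumptions, not the project's exact values:

```kotlin
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Adam
import org.nd4j.linalg.lossfunctions.LossFunctions

// Shallow CNN for Cifar-10: one convolution + pooling, one dense layer, softmax output.
fun buildModel(): MultiLayerNetwork {
    val conf = NeuralNetConfiguration.Builder()
        .seed(42)                             // illustrative values throughout
        .updater(Adam(1e-3))
        .list()
        .layer(ConvolutionLayer.Builder(5, 5)
            .nIn(3)                           // 3 input channels (RGB)
            .nOut(32)
            .activation(Activation.RELU)
            .build())
        .layer(SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2).stride(2, 2)
            .build())
        .layer(DenseLayer.Builder()
            .nOut(128)
            .activation(Activation.RELU)
            .build())
        .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nOut(10)                         // the 10 Cifar-10 classes
            .activation(Activation.SOFTMAX)
            .build())
        .setInputType(InputType.convolutional(32, 32, 3)) // infers nIn between layers
        .build()
    return MultiLayerNetwork(conf).also { it.init() }
}
```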

(A note on the model size: The focus of this PoC is in Federated Learning. Better models can be trained with more layers and reducing the size of it by applying different techniques as quantify or by using structured or sketched updates. This is for another PoC!)
The code for training the model on the server side is located in the model module of the PhotoLabellerServer project.
Making predictions with the App
The app allows basic classification of the photos the user takes with the camera, using either the model embedded in the app itself or, when connected to the server, the latest version of the shared model.

The app is structured in modules: the app module contains the Android-specific classes, the trainer module the Deeplearning4j-related classes, and the base module the interactors and domain objects.
The implementation of the Trainer, the object in charge of making predictions and training using DL4J, invokes the predict function to obtain the classification of the image.
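A minimal sketch of what such a predict step looks like with DL4J. The classLabels list and the assumption that the image arrives already preprocessed as an INDArray of shape [1, 3, 32, 32] are mine, not the project's exact code:

```kotlin
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.api.ndarray.INDArray

// The ten Cifar-10 classes; the real app loads its labels alongside the model.
val classLabels = listOf("airplane", "automobile", "bird", "cat", "deer",
                         "dog", "frog", "horse", "ship", "truck")

// Run the network forward and return the most likely class with its probability.
fun classify(model: MultiLayerNetwork, image: INDArray): Pair<String, Float> {
    val output = model.output(image)        // shape [1, 10] of softmax probabilities
    val index = output.argMax(1).getInt(0)  // column with the highest probability
    return classLabels[index] to output.getFloat(0, index)
}
```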
The Federated Learning setup
Federated Learning turns the update of Machine Learning models upside-down by allowing the devices on the edge to participate in the training.
Instead of sending the client's data to a centralised location, Federated Learning sends the model to the devices participating in the federation. The model is then re-trained (using Transfer Learning) with the local data.
And the data, your data, never leaves the device, be it your phone, your laptop or your IoT gadget.

The server opens a “round of training” during which the clients can send their model updates to the server.
Client Side. Training on the edge
Our Android app decides when to participate in the training of the shared model. It performs a Transfer Learning operation using the model it already has, or the one on the server if it’s newer. The update made to the model is then sent to the server.
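In outline, the client-side step could look like the following DL4J sketch. The two-epoch default, the helper name trainLocally and the choice to ship the flattened parameters as the update are assumptions for illustration:

```kotlin
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.dataset.DataSet

// Re-train locally for a couple of epochs on a handful of samples, then return
// the new weights to be uploaded to the server. Epoch and sample counts are kept
// tiny on purpose: this is all the device can afford before running out of memory.
fun trainLocally(model: MultiLayerNetwork, localData: DataSet, epochs: Int = 2): FloatArray {
    repeat(epochs) { model.fit(localData) }
    // The update sent to the server: here, simply the flattened parameter vector.
    return model.params().data().asFloat()
}
```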
Server Side. Averaging and updating the model
Once the round is closed, the server updates the shared model by applying Federated Averaging to the received updates.
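The idea behind Federated Averaging can be sketched in plain Kotlin, independently of DL4J: each client's flattened weight vector is weighted by the number of samples it was trained on. Function and parameter names here are illustrative:

```kotlin
// Federated Averaging over flattened weight vectors: each client's update
// contributes in proportion to the number of samples it was trained on.
fun federatedAverage(updates: List<FloatArray>, sampleCounts: List<Int>): FloatArray {
    require(updates.isNotEmpty() && updates.size == sampleCounts.size)
    val total = sampleCounts.sum().toFloat()
    val avg = FloatArray(updates.first().size)
    for ((update, count) in updates.zip(sampleCounts)) {
        val weight = count / total
        for (i in avg.indices) avg[i] += weight * update[i]
    }
    return avg
}

fun main() {
    // Two clients: one trained on 10 samples, the other on 30.
    val a = floatArrayOf(1f, 2f)
    val b = floatArrayOf(5f, 6f)
    println(federatedAverage(listOf(a, b), listOf(10, 30)).toList())  // [4.0, 5.0]
}
```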
The server also implements a simple REST API that is used by the clients.
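One way such an API could be laid out, sketched here with Ktor; the route names, payload format and in-memory storage are assumptions for illustration, not the project's actual endpoints:

```kotlin
import io.ktor.server.application.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*

// In-memory stand-ins for the real round/model storage (assumptions).
private var currentModel = ByteArray(0)
private val pendingUpdates = mutableListOf<ByteArray>()
private var roundOpen = true

fun main() {
    embeddedServer(Netty, port = 8080) {
        routing {
            get("/model") {            // clients fetch the latest shared model
                call.respondBytes(currentModel)
            }
            get("/round") {            // is a round of training currently open?
                call.respondText(if (roundOpen) "open" else "closed")
            }
            post("/updates") {         // clients push their local model updates
                pendingUpdates.add(call.receive<ByteArray>())
                call.respondText("accepted")
            }
        }
    }.start(wait = true)
}
```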
Note that the client, the Android app, is implemented in a way that would require little effort to port it to some other Kotlin platform.
Notes
Performance
Doing any kind of operation with images in an Android app is always demanding in terms of the computation the device must perform. Training a model on images multiplies that effort several times.
This means the transfer learning phase done in the Android app is quite short: a couple of epochs with just a few samples. That is the most the app can do before running out of memory! The total number of parameters is around 450k, which is a lot for the memory available to an app.
However, models using other types of data run smoothly. A previous version of the Federated Learning setup used a diabetes dataset with just a few features. That could be trained with more epochs (I actually didn’t find the limit, as I was obtaining the desired performance before reaching an OOM) and more data points.
The app is architected so that you can try other models and datasets. Feel free to reuse it for your research, and let me know if I can help.
Thanks for reading!