Conducting Federated Learning Research with FedJAX
ToF
Ten or fewer takeaways on using FedJAX for federated learning research
The idea behind the Ten or Fewer blog series (ToF for short, pronounced like “tuff”) is to efficiently and effectively share the main concepts, takeaways, opinions, or the like around a specific topic. Specifically, I aim to keep it short by having only ten or fewer relevant things to say.
Affiliation disclosure:
I lead the Machine Learning Science team at integrate.ai, where we are focused on making federated learning (FL) more accessible. I am not affiliated with FedJAX, nor have I contributed to this open-source project in the past.
Accompanying code: https://github.com/integrateai/openlab/tree/main/fedjax_tutorial
(0) (An FL primer that you can skip and refer back to whenever necessary.) FL allows for building machine learning models without first having to bring data to a central location. In a typical FL setting, there are clients who each have their own dataset housed on their own individual nodes, and a central server that “leverages” these non-centralized datasets to build a global ML model (commonly referred to as the server model), which is then shared by all of the clients. Importantly, the datasets never come together to form one massive training dataset; rather, they are kept on their respective client nodes, and training is facilitated by communicating only model parameters (or their updates) between each client node and the central server. The basic steps of the FL framework are provided below (adapted from Ro, Suresh, and Wu (2021)).
Note that FL training is done in multiple rounds, where in each round the following steps are taken (a minimal sketch of one round follows the list):
1. The server selects a set of clients to participate in the current round;
2. The server communicates the current parameters of the global server model to the participating clients;
3. Each client creates a local copy of the server model from the snapshot provided in step 2 and updates it by training on its local data;
4. After performing its individual client update, each client communicates its model parameter updates back to the server;
5. The server takes the collection of client model parameter updates and aggregates them (e.g., via averaging);
6. The server updates the server model via the aggregated model parameter update computed in the previous step.
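To make these six steps concrete, here is a minimal, framework-free sketch of one FedAvg-style round. Everything in it (the linear model, local_train, the dataset format) is a hypothetical stand-in for illustration, not FedJAX code; FedJAX's actual interfaces are covered below.

```python
import numpy as np

# Hypothetical stand-ins: a "model" whose parameters are a single NumPy
# vector, and client datasets that are lists of (x, y) example pairs.

def local_train(params, dataset, lr=0.01):
    # Step 3: one pass of SGD on a squared-error loss, on local data only.
    for x, y in dataset:
        grad = 2.0 * (x @ params - y) * x  # gradient of (x·params - y)^2
        params = params - lr * grad
    return params

def run_round(server_params, client_datasets, num_clients_per_round, rng):
    # Step 1: the server selects a subset of clients for this round.
    selected = rng.choice(len(client_datasets),
                          size=num_clients_per_round, replace=False)
    updates, weights = [], []
    for i in selected:
        # Steps 2-3: the client trains a local copy of the server params.
        local_params = local_train(server_params.copy(), client_datasets[i])
        # Step 4: only the parameter update leaves the client, never the data.
        updates.append(local_params - server_params)
        weights.append(len(client_datasets[i]))
    # Step 5: aggregate the client updates, weighted by local dataset size.
    avg_update = np.average(updates, axis=0, weights=weights)
    # Step 6: apply the aggregated update to the server model.
    return server_params + avg_update

# Usage: FL training runs over many such rounds.
rng = np.random.default_rng(0)
clients = [[(rng.standard_normal(3), rng.standard_normal())
            for _ in range(20)] for _ in range(100)]
params = np.zeros(3)
for _ in range(50):
    params = run_round(params, clients, num_clients_per_round=10, rng=rng)
```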
(1) FedJAX is a library for conducting FL research. FedJAX aims to shorten the cycle time of FL research (e.g., running FL experiments) as well as to make it more standardized. It does so in two ways. First, since actual server-client communication is unnecessary when conducting algorithmic research on FL models, FedJAX simulations essentially skip steps 2 and 4 listed above. That’s because with FedJAX, the multi-client FL setting is only simulated on a single machine; as such, FedJAX should not be used for building and deploying an actual FL solution. Second, FedJAX provides convenient interfaces for the core elements of an FL framework, such as FederatedData, ClientDataset, Models, ClientSampler, and FederatedAlgorithm, which call functions for handling both client model updates and server model updates.
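To see how those interfaces fit together, here is a condensed sketch of federated averaging on EMNIST, closely following the fed_avg example in the FedJAX repo (API names as of v0.0.7; the hyperparameters and round count here are illustrative, not tuned):

```python
import jax
import fedjax

# Federated EMNIST data (~3,400 clients) and a matching CNN Model,
# both supplied by FedJAX.
train_fd, test_fd = fedjax.datasets.emnist.load_data(only_digits=False)
model = fedjax.models.emnist.create_conv_model(only_digits=False)

# Gradient of the model's training loss, plus client/server optimizers.
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
server_optimizer = fedjax.optimizers.sgd(learning_rate=1.0)
batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=20)

algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn, client_optimizer, server_optimizer, batch_hparams)
state = algorithm.init(model.init(jax.random.PRNGKey(0)))

# Steps 1, 3, 5, and 6 from the primer; steps 2 and 4 (the actual
# communication) are skipped entirely in simulation.
sampler = fedjax.client_samplers.UniformGetClientSampler(
    fd=train_fd, num_clients=10, seed=0)
for round_num in range(100):
    clients = sampler.sample()                             # step 1
    state, diagnostics = algorithm.apply(state, clients)   # steps 3, 5, 6
```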
(2) Getting going with FedJAX is not bad at all. I found it relatively straightforward to ramp up on FedJAX to the level where I could carry out vanilla, horizontal FL on a dataset. The docs as well as the official code, despite still being in their early stages at the time of writing, are quite helpful. As is typically the case when learning any new library, there is a bit of a learning curve, so it took me about a week to learn enough of the library to complete two examples: i) an example using the EMNIST dataset split into 3,400 clients; and ii) an example using the CIFAR-10 dataset split into 100 clients. For the EMNIST example, the dataset and the CNN model that I used were both supplied by the FedJAX library. In contrast, for the CIFAR-10 example, I created a custom dataset as well as a custom CNN model using the helper functions provided by FedJAX. You can find the code for these two examples in the repo linked at the top of this post and execute them for yourselves to reproduce the runtimes and final server model accuracies!
(3) Prior familiarity with JAX is not necessary to get started, but it becomes important for more serious research using FedJAX. I had not worked with JAX before sinking my teeth into FedJAX, yet I didn’t find that to be a major obstacle in learning the FedJAX interfaces and getting vanilla FL up and running on the two aforementioned examples. However, when looking more deeply into the source code, I definitely started to sense that I could benefit from more familiarity with JAX. And so, I feel that if one wants to do more serious FL research using FedJAX, then it would certainly be worth one’s time to build a good level of familiarity with JAX, and perhaps with other JAX-based deep learning libraries such as Haiku as well.
The next three items cover three of the many core classes offered in the FedJAX (v0.0.7) library that enable standardized FL simulations.
(4) FederatedData is a data structure that contains metadata about the client datasets and convenient methods for performing tasks in the simulated FL training. Since FederatedData can be thought of as providing a mapping between client_ids and ClientDatasets, accessing the array-like data structures representing the actual observations takes a bit more work, which is somewhat reflective of an actual FL setting, where accessing data from a client node is challenging, or maybe even impossible without intervention, due to privacy considerations. Finally, FedJAX conveniently provides functionality for creating a small custom FederatedData from NumPy arrays, which is what I used to create the CIFAR-10 dataset in my second example.
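For instance, fedjax.InMemoryFederatedData builds a small FederatedData from a mapping of client ids to per-client NumPy arrays. The toy data below is entirely hypothetical; only the class and method names come from FedJAX:

```python
import numpy as np
import fedjax

# Toy data: three clients, each holding a handful of labeled examples.
# A client's examples are a dict of equal-length NumPy arrays.
rng = np.random.default_rng(0)
client_id_to_examples = {
    b'client_0': {'x': rng.standard_normal((5, 2)).astype(np.float32),
                  'y': rng.integers(0, 2, size=5)},
    b'client_1': {'x': rng.standard_normal((8, 2)).astype(np.float32),
                  'y': rng.integers(0, 2, size=8)},
    b'client_2': {'x': rng.standard_normal((3, 2)).astype(np.float32),
                  'y': rng.integers(0, 2, size=3)},
}
federated_data = fedjax.InMemoryFederatedData(client_id_to_examples)

print(federated_data.num_clients())  # 3
# FederatedData maps client_ids to ClientDatasets; the underlying
# arrays are only reachable through a ClientDataset.
for client_id, client_dataset in federated_data.clients():
    print(client_id, len(client_dataset))
```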
(5) Model is a FedJAX class for specifying the server ML model structure and loss function. It’s important to note that the Model class is stateless, meaning there is no params attribute that gets updated during training after each epoch/round. Instead, Model supplies its information to jax.grad(), which automatically computes the gradients of the associated objective function and passes them to a chosen Optimizer, which in turn returns a set of params that fits the Model’s structure. Generating predictions on examples can be done by calling apply_for_eval() (or apply_for_train() if a random key is needed to inject randomness for something like dropout). FedJAX also provides helper functions for creating a Model from other JAX-based deep learning libraries such as Haiku or jax.example_libraries.stax.
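As a sketch of what that looks like, here is a small Haiku-backed Model built with fedjax.create_model_from_haiku, following the pattern in the FedJAX docs (as of v0.0.7). The tiny MLP architecture and the sample batch shapes are my own placeholders:

```python
import haiku as hk
import jax
import numpy as np
import fedjax

def forward_pass(batch):
    # A deliberately tiny MLP for 10-class classification.
    network = hk.Sequential([hk.Linear(64), jax.nn.relu, hk.Linear(10)])
    return network(batch['x'])

# A sample batch fixes the input shapes so params can be initialized.
sample_batch = {'x': np.zeros((1, 784), dtype=np.float32),
                'y': np.zeros((1,), dtype=np.int32)}

model = fedjax.create_model_from_haiku(
    transformed_forward_pass=hk.transform(forward_pass),
    sample_batch=sample_batch,
    train_loss=lambda batch, preds: fedjax.metrics.unreduced_cross_entropy_loss(
        batch['y'], preds),
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})

# Model is stateless: init() returns params, it does not store them.
params = model.init(jax.random.PRNGKey(0))
preds = model.apply_for_eval(params, sample_batch)  # logits, shape (1, 10)
```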
(6) FederatedAlgorithm is a class that facilitates the FL training (i.e., the client and server updates, or steps 3 and 5 listed at the top of this blog) of a Model on a FederatedData. Building a custom FederatedAlgorithm requires defining functions for performing the client model updates as well as the aggregation and server model update, all of which should be encapsulated in its apply() method. There is only one other method belonging to this class, namely init(), which initializes the state of the server model. FedJAX also aims to make client updates more efficient by leveraging available accelerators, and so you’re required to define the client update function in a specific manner (i.e., through three functions: client_init, client_step, and client_final), as sketched below.
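Here is a hedged sketch of that three-function split for a FedAvg-style client update, roughly following the advanced fed_avg example in the FedJAX repo (again, API names as of v0.0.7; the learning rate is illustrative):

```python
import jax
import fedjax

# A Model, its loss gradient, and a local optimizer to drive client steps.
model = fedjax.models.emnist.create_conv_model(only_digits=False)
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(learning_rate=0.1)

def client_init(shared_input, client_rng):
    # Each client starts from a snapshot of the current server params.
    params = shared_input['params']
    return {'params': params,
            'opt_state': client_optimizer.init(params),
            'rng': client_rng}

def client_step(client_step_state, batch):
    # One local optimizer step on one batch of this client's data.
    rng, use_rng = jax.random.split(client_step_state['rng'])
    grads = grad_fn(client_step_state['params'], batch, use_rng)
    opt_state, params = client_optimizer.apply(
        grads, client_step_state['opt_state'], client_step_state['params'])
    return {'params': params, 'opt_state': opt_state, 'rng': rng}

def client_final(shared_input, client_step_state):
    # Report only the parameter delta back for server-side aggregation.
    return jax.tree_util.tree_map(lambda a, b: a - b,
                                  shared_input['params'],
                                  client_step_state['params'])

# for_each_client turns the triple into an efficient, jit-compiled map
# over many clients, using any available accelerators.
for_each_client_fn = fedjax.for_each_client(
    client_init, client_step, client_final)
```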
(7) Overall, FedJAX shows a lot of promise. The premise of speeding up and standardizing FL research is a good one. The interfaces that FedJAX offers at the time of writing are useful and convenient, and they represent a good basis for doing more serious research in FL. In the near future, I would hope to see more datasets and benchmarks contributed to this project to improve repeatability and level-setting across the field. Additionally, more modules and examples that go beyond vanilla FL would make this project even more relevant, such as methods for differential privacy, non-IID client datasets, vertical FL, and maybe even non-SGD types of client/server models.