refactor(examples) Update quickstart-jax example (#4121)
jafermarq authored Oct 23, 2024
1 parent 7c5a207 commit 76e1c28
Showing 11 changed files with 338 additions and 222 deletions.
90 changes: 36 additions & 54 deletions examples/quickstart-jax/README.md
@@ -1,85 +1,67 @@
 ---
 tags: [quickstart, linear regression]
 dataset: [Synthetic]
-framework: [JAX]
+framework: [JAX, FLAX]
 ---

-# JAX: From Centralized To Federated
+# Federated Learning with JAX and Flower (Quickstart Example)

-This example demonstrates how an already existing centralized JAX-based machine learning project can be federated with Flower.
-
-This introductory example for Flower uses JAX, but you're not required to be a JAX expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing JAX project.
+This introductory example to Flower uses JAX, but deep knowledge of JAX is not required to run it. Some familiarity will, however, help you understand how to adapt Flower to your use case. The example uses [FLAX](https://flax.readthedocs.io/en/latest/index.html) to define and train a small CNN model, and [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition, and preprocess the MNIST dataset.

-## Project Setup
+## Set up the project
+
+### Clone the project

-Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+Start by cloning the example project:

 ```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-jax . && rm -rf flower && cd quickstart-jax
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+&& mv _tmp/examples/quickstart-jax . \
+&& rm -rf _tmp \
+&& cd quickstart-jax
 ```

-This will create a new directory called `quickstart-jax`, containing the following files:
+This will create a new directory called `quickstart-jax` with the following structure:

 ```shell
--- pyproject.toml
--- requirements.txt
--- jax_training.py
--- client.py
--- server.py
--- README.md
+quickstart-jax
+├── jaxexample
+│   ├── __init__.py
+│   ├── client_app.py   # Defines your ClientApp
+│   ├── server_app.py   # Defines your ServerApp
+│   └── task.py         # Defines your model, training and data loading
+├── pyproject.toml      # Project metadata like dependencies and configs
+└── README.md
 ```

-### Installing Dependencies
+### Install dependencies and project

-Project dependencies (such as `jax` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.
-
-#### Poetry
-
-```shell
-poetry install
-poetry shell
-```
-
-Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:
-
-```shell
-poetry run python3 -c "import flwr"
-```
-
-If you don't see any errors you're good to go!
-
-#### pip
-
-Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.
-
-```shell
-pip install -r requirements.txt
+Install the dependencies defined in `pyproject.toml` as well as the `jaxexample` package.
+
+```bash
+pip install -e .
 ```

-## Run JAX Federated
+## Run the project

-This JAX example is based on the [Linear Regression with JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression problem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings being done while calling the dataset function. Please check out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set.
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode, as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

-The only things we need are a simple Flower server (in `server.py`) and a Flower client (in `client.py`). The Flower client basically takes model and training code and tells Flower how to call it.
+### Run with the Simulation Engine

-Start the server in a terminal as follows:
-
-```shell
-python3 server.py
+```bash
+flwr run .
 ```

-Now that the server is running and waiting for clients, we can start two clients that will participate in the federated learning process. To do so simply open two more terminal windows and run the following commands.
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

-Start client 1 in the first terminal:
-
-```shell
-python3 client.py
+```bash
+flwr run . --run-config "num-server-rounds=5 batch-size=32"
 ```

-Start client 2 in the second terminal:
-
-```shell
-python3 client.py
-```
+> \[!TIP\]
+> For a more detailed walk-through, check out our [quickstart JAX tutorial](https://flower.ai/docs/framework/tutorial-quickstart-jax.html)

-You are now training a JAX-based linear regression model, federated across two clients. The setup is of course simplified since both clients hold a similar dataset, but you can now continue with your own explorations. How about changing from a linear regression to a more sophisticated model? How about adding more clients?
+### Run with the Deployment Engine
+
+> \[!NOTE\]
+> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.
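
The `--run-config` overrides above target defaults declared in `pyproject.toml`, whose diff did not render on this page. Below is a rough sketch of the Flower-specific sections: the component paths follow from the app files shown further down, and the config key names match those read via `context.run_config` (`num-server-rounds` in `server_app.py`; `batch-size` and `learning-rate` in `client_app.py`), but the default values are assumptions rather than values from the commit.

```toml
# Hypothetical sketch of the Flower sections in pyproject.toml.
# Only the key names are implied by the code; the values are illustrative.
[tool.flwr.app.components]
serverapp = "jaxexample.server_app:app"
clientapp = "jaxexample.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
batch-size = 64
learning-rate = 0.1

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10
```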
54 changes: 0 additions & 54 deletions examples/quickstart-jax/client.py

This file was deleted.

74 changes: 0 additions & 74 deletions examples/quickstart-jax/jax_training.py

This file was deleted.

1 change: 1 addition & 0 deletions examples/quickstart-jax/jaxexample/__init__.py
@@ -0,0 +1 @@
"""jaxexample: A Flower / JAX app."""
66 changes: 66 additions & 0 deletions examples/quickstart-jax/jaxexample/client_app.py
@@ -0,0 +1,66 @@
"""jaxexample: A Flower / JAX app."""

import numpy as np
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

from jaxexample.task import (
apply_model,
create_train_state,
get_params,
load_data,
set_params,
train,
)


# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, train_state, trainset, testset):
self.train_state = train_state
self.trainset, self.testset = trainset, testset

def fit(self, parameters, config):
self.train_state = set_params(self.train_state, parameters)
self.train_state, loss, acc = train(self.train_state, self.trainset)
params = get_params(self.train_state.params)
return (
params,
len(self.trainset),
{"train_acc": float(acc), "train_loss": float(loss)},
)

def evaluate(self, parameters, config):
self.train_state = set_params(self.train_state, parameters)

losses = []
accs = []
for batch in self.testset:
_, loss, accuracy = apply_model(
self.train_state, batch["image"], batch["label"]
)
losses.append(float(loss))
accs.append(float(accuracy))

return np.mean(losses), len(self.testset), {"accuracy": np.mean(accs)}


def client_fn(context: Context):

num_partitions = context.node_config["num-partitions"]
partition_id = context.node_config["partition-id"]
batch_size = context.run_config["batch-size"]
trainset, testset = load_data(partition_id, num_partitions, batch_size)

# Create train state object (model + optimizer)
lr = context.run_config["learning-rate"]
train_state = create_train_state(lr)

# Return Client instance
return FlowerClient(train_state, trainset, testset).to_client()


# Flower ClientApp
app = ClientApp(
client_fn,
)
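
Note that `fit` returns a three-tuple: the updated parameters as a list of NumPy arrays, the number of local examples used (here `len(self.trainset)`, which counts batches rather than samples), and a dict of metrics. Likewise, `evaluate` returns `(loss, num_examples, metrics)`. This is the contract Flower's `NumPyClient` expects, and the metrics dicts are what `weighted_average` in `server_app.py` aggregates.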
47 changes: 47 additions & 0 deletions examples/quickstart-jax/jaxexample/server_app.py
@@ -0,0 +1,47 @@
"""jaxexample: A Flower / JAX app."""

from typing import List, Tuple

from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from jax import random

from jaxexample.task import create_model, get_params


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Initialize global model
rng = random.PRNGKey(0)
rng, _ = random.split(rng)
_, model_params = create_model(rng)
params = get_params(model_params)
initial_parameters = ndarrays_to_parameters(params)

# Define strategy
strategy = FedAvg(
fraction_fit=0.4,
fraction_evaluate=0.5,
evaluate_metrics_aggregation_fn=weighted_average,
initial_parameters=initial_parameters,
)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
(The remaining file diffs, among them `jaxexample/task.py` and the updated `pyproject.toml`, did not load on this page.)
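
Since `jaxexample/task.py` is among the diffs that did not load, the following is a minimal sketch of what its helpers could look like, reconstructed only from the names and return shapes used in `client_app.py` and `server_app.py`. The architecture (a small Flax `linen` CNN), the optimizer (`optax.sgd`), and the IID MNIST partitioning via `flwr-datasets` are assumptions, not the actual contents of commit 76e1c28.

```python
"""Hypothetical sketch of jaxexample/task.py (not the file from this commit)."""

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.training.train_state import TrainState
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner


class CNN(nn.Module):
    """A small CNN for 28x28 grayscale images (assumed architecture)."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Conv(features=16, kernel_size=(3, 3))(x))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten
        return nn.Dense(features=10)(x)


def create_model(rng):
    """Instantiate the model and initialize its parameter pytree."""
    model = CNN()
    params = model.init(rng, jnp.ones((1, 28, 28, 1)))["params"]
    return model, params


def create_train_state(learning_rate: float):
    """Bundle model, parameters, and optimizer into a Flax TrainState."""
    model, params = create_model(jax.random.PRNGKey(0))
    tx = optax.sgd(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)


def get_params(params):
    """Flatten the parameter pytree into a list of NumPy arrays (for Flower)."""
    return [np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(params)]


def set_params(train_state, ndarrays):
    """Rebuild the parameter pytree from a list of NumPy arrays."""
    treedef = jax.tree_util.tree_structure(train_state.params)
    params = jax.tree_util.tree_unflatten(treedef, [jnp.asarray(a) for a in ndarrays])
    return train_state.replace(params=params)


@jax.jit
def apply_model(train_state, images, labels):
    """Return gradients, loss, and accuracy for a single batch."""

    def loss_fn(params):
        logits = train_state.apply_fn({"params": params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        return optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean(), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


def train(train_state, trainset):
    """Train for one local epoch; return the updated state and mean loss/accuracy."""
    losses, accs = [], []
    for batch in trainset:
        grads, loss, acc = apply_model(train_state, batch["image"], batch["label"])
        train_state = train_state.apply_gradients(grads=grads)
        losses.append(float(loss))
        accs.append(float(acc))
    return train_state, np.mean(losses), np.mean(accs)


def load_data(partition_id: int, num_partitions: int, batch_size: int):
    """Load one IID partition of MNIST and return lists of train/test batches."""
    partitioner = IidPartitioner(num_partitions=num_partitions)
    fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
    split = fds.load_partition(partition_id, "train").train_test_split(
        test_size=0.2, seed=42
    )

    def batches(ds):
        ds = ds.with_format("numpy")
        images = (ds["image"].astype(np.float32) / 255.0)[..., np.newaxis]  # Scale to [0, 1]
        labels = ds["label"]
        return [
            {"image": images[i : i + batch_size], "label": labels[i : i + batch_size]}
            for i in range(0, len(labels), batch_size)
        ]

    return batches(split["train"]), batches(split["test"])
```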
