
Add flwr run to command line interface #3049

Merged 25 commits on Mar 6, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ cryptography = "^42.0.4"
pycryptodome = "^3.18.0"
iterators = "^0.0.2"
typer = { version = "^0.9.0", extras=["all"] }
tomli = "^2.0.1"
# Optional dependencies (VCE)
ray = { version = "==2.6.3", optional = true }
pydantic = { version = "<2.0.0", optional = true }
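The newly added tomli dependency hints that the CLI parses a TOML project configuration (see the flower.toml template further down in this diff). Below is a minimal sketch of how such a file could be read; the load_flower_toml helper and the hard-coded file name are assumptions for illustration, not code from this PR.

# Sketch only: reading a flower.toml with the tomli dependency added above.
import tomli

def load_flower_toml(path: str = "flower.toml") -> dict:
    """Parse a flower.toml file and return its contents as a dict."""
    with open(path, "rb") as f:  # tomli requires a binary file object
        return tomli.load(f)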
2 changes: 2 additions & 0 deletions src/py/flwr/cli/app.py
@@ -18,6 +18,7 @@

from .example import example
from .new import new
from .run import run

app = typer.Typer(
help=typer.style(
@@ -30,6 +31,7 @@

app.command()(new)
app.command()(example)
app.command()(run)

if __name__ == "__main__":
app()
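For context, Typer registers plain Python functions as subcommands, which is why app.command()(run) above is all that is needed once run is importable. A hedged sketch of what a minimal run command could look like follows; the body is a placeholder and not the actual implementation added under src/py/flwr/cli/run/.

# Illustrative Typer command; placeholder body, not the real `flwr run`.
import typer

def run() -> None:
    """Run Flower project."""
    typer.secho("Loading project configuration...", fg=typer.colors.GREEN)

# Registered on the CLI app exactly like `new` and `example`:
# app.command()(run)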
19 changes: 10 additions & 9 deletions src/py/flwr/cli/new/new.py
@@ -95,6 +95,8 @@ def new(
if value == framework_value
]
framework_str = selected_value[0]

framework_str = framework_str.lower()

# Set project directory path
cwd = os.getcwd()
@@ -106,18 +108,17 @@
"README.md": {
"template": "app/README.md.tpl",
},
"requirements.txt": {
"template": f"app/requirements.{framework_str.lower()}.txt.tpl"
},
"requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"},
"flower.toml": {"template": "app/flower.toml.tpl"},
f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{pnl}/server.py": {
"template": f"app/code/server.{framework_str.lower()}.py.tpl"
},
f"{pnl}/client.py": {
"template": f"app/code/client.{framework_str.lower()}.py.tpl"
},
f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"},
f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"},
}

# If the framework is MlFramework.PYTORCH, additionally generate the utils.py file
if framework_str == MlFramework.PYTORCH.value.lower():
files[f"{pnl}/utils.py"] = {"template": f"app/code/utils.{framework_str}.py.tpl"}

context = {"project_name": project_name}

for file_path, value in files.items():
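The templates referenced in the files dict use $-style placeholders such as $project_name, which matches Python's string.Template syntax. A hedged sketch of how a single entry could be rendered is shown below, assuming the template text has already been read from disk; the render helper is illustrative and not the function flwr new actually calls.

# Illustrative $-style template rendering (render() is not part of this PR).
from string import Template

def render(template_text: str, context: dict) -> str:
    """Substitute placeholders such as $project_name in a template string."""
    return Template(template_text).substitute(**context)

rendered = render('"""$project_name: A Flower / PyTorch app."""', {"project_name": "fedexample"})
# rendered == '"""fedexample: A Flower / PyTorch app."""'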
128 changes: 128 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
@@ -1 +1,129 @@
"""$project_name: A Flower / PyTorch app."""

from collections import OrderedDict
from typing import Dict, Tuple, List

import torch
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import Scalar

from flwr_datasets import FederatedDataset

from utils import Net, train, test, apply_transforms

NUM_CLIENTS = 100
NUM_ROUNDS = 10


# Flower client, adapted from the PyTorch quickstart example
class FlowerClient(fl.client.NumPyClient):
def __init__(self, trainset, valset):
self.trainset = trainset
self.valset = valset

# Instantiate model
self.model = Net()

# Determine device
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model.to(self.device) # send model to device

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

def fit(self, parameters, config):
set_params(self.model, parameters)

# Read from config
batch, epochs = config["batch_size"], config["epochs"]

# Construct dataloader
trainloader = DataLoader(self.trainset, batch_size=batch, shuffle=True)

# Define optimizer
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
# Train
train(self.model, trainloader, optimizer, epochs=epochs, device=self.device)

# Return local model and statistics
return self.get_parameters({}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
set_params(self.model, parameters)

# Construct dataloader
valloader = DataLoader(self.valset, batch_size=64)

# Evaluate
loss, accuracy = test(self.model, valloader, device=self.device)

# Return statistics
return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)}


def get_client_fn(dataset: FederatedDataset):
"""Return a function to construct a client.

The VirtualClientEngine will execute this function whenever a client is sampled by
the strategy to participate.
"""

def client_fn(cid: str) -> fl.client.Client:
"""Construct a FlowerClient with its own dataset partition."""

# Let's get the partition corresponding to the i-th client
client_dataset = dataset.load_partition(int(cid), "train")

# Now let's split it into train (90%) and validation (10%)
client_dataset_splits = client_dataset.train_test_split(test_size=0.1)

trainset = client_dataset_splits["train"]
valset = client_dataset_splits["test"]

# Now we apply the transform to each batch.
trainset = trainset.with_transform(apply_transforms)
valset = valset.with_transform(apply_transforms)

# Create and return client
return FlowerClient(trainset, valset).to_client()

return client_fn


def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
config = {
"epochs": 1, # Number of local epochs done by clients
"batch_size": 32, # Batch size to use by clients during fit()
}
return config


def set_params(model: torch.nn.Module, params: fl.common.NDArrays):
"""Set model weights from a list of NumPy ndarrays."""
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by
the client's evaluate() method."""
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Download MNIST dataset and partition it
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})

# ClientApp for Flower-Next
app = fl.client.ClientApp(
client_fn=get_client_fn(mnist_fds),
)
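As a quick worked example of the weighted_average aggregation defined in this template (illustration only, not part of the diff): two clients reporting accuracies of 0.5 and 0.75 on 10 and 30 evaluation examples yield (10 * 0.5 + 30 * 0.75) / 40 = 0.6875.

# Worked example for weighted_average; values chosen to be exact in floating point.
metrics = [(10, {"accuracy": 0.5}), (30, {"accuracy": 0.75})]
assert weighted_average(metrics) == {"accuracy": 0.6875}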
129 changes: 129 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl
@@ -1 +1,130 @@
"""$project_name: A Flower / PyTorch app."""

"""$project_name: A Flower / PyTorch app."""

from collections import OrderedDict
from typing import Dict, Tuple, List

import torch
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import Scalar

from datasets import Dataset
from datasets.utils.logging import disable_progress_bar
from flwr_datasets import FederatedDataset

from utils import Net, test, apply_transforms

NUM_CLIENTS = 100
NUM_ROUNDS = 10


def get_client_fn(dataset: FederatedDataset):
"""Return a function to construct a client.

The VirtualClientEngine will execute this function whenever a client is sampled by
the strategy to participate.
"""

def client_fn(cid: str) -> fl.client.Client:
"""Construct a FlowerClient with its own dataset partition."""

# Let's get the partition corresponding to the i-th client
client_dataset = dataset.load_partition(int(cid), "train")

# Now let's split it into train (90%) and validation (10%)
client_dataset_splits = client_dataset.train_test_split(test_size=0.1)

trainset = client_dataset_splits["train"]
valset = client_dataset_splits["test"]

# Now we apply the transform to each batch.
trainset = trainset.with_transform(apply_transforms)
valset = valset.with_transform(apply_transforms)

# Create and return client
return FlowerClient(trainset, valset).to_client()

return client_fn


def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
config = {
"epochs": 1, # Number of local epochs done by clients
"batch_size": 32, # Batch size to use by clients during fit()
}
return config


def set_params(model: torch.nn.Module, params: fl.common.NDArrays):
"""Set model weights from a list of NumPy ndarrays."""
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by
the client's evaluate() method."""
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


def get_evaluate_fn(
centralized_testset: Dataset,
):
"""Return an evaluation function for centralized evaluation."""

def evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
"""Use the entire CIFAR-10 test set for evaluation."""

# Determine device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Net()
set_params(model, parameters)
model.to(device)

# Apply transform to dataset
testset = centralized_testset.with_transform(apply_transforms)

# Disable tqdm for dataset preprocessing
disable_progress_bar()

testloader = DataLoader(testset, batch_size=50)
loss, accuracy = test(model, testloader, device=device)

return loss, {"accuracy": accuracy}

return evaluate


# Download MNIST dataset and partition it
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
centralized_testset = mnist_fds.load_full("test")

# Configure the strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.1, # Sample 10% of available clients for training
fraction_evaluate=0.05, # Sample 5% of available clients for evaluation
min_available_clients=10,
on_fit_config_fn=fit_config,
evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics
evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function
)

# ServerApp for Flower-Next
app = fl.server.ServerApp(
config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
strategy=strategy,
)
64 changes: 64 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/utils.pytorch.py.tpl
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import ToTensor, Normalize, Compose


# transformation to convert images to tensors and apply normalization
def apply_transforms(batch):
transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
batch["image"] = [transforms(img) for img in batch["image"]]
return batch


# Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')
class Net(nn.Module):
def __init__(self, num_classes: int = 10) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


# borrowed from the PyTorch quickstart example
def train(net, trainloader, optim, epochs, device: str):
"""Train the network on the training set."""
criterion = torch.nn.CrossEntropyLoss()
net.train()
for _ in range(epochs):
for batch in trainloader:
images, labels = batch["image"].to(device), batch["label"].to(device)
optim.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optim.step()


# borrowed from the PyTorch quickstart example
def test(net, testloader, device: str):
"""Validate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data["image"].to(device), data["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy
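A quick sanity check of the Net architecture above (not part of the template): an MNIST-shaped 1x28x28 input passes through two 5x5 convolutions, each followed by 2x2 max pooling, leaving a 16x4x4 feature map, which is why the first linear layer takes 16 * 4 * 4 input features.

# Sanity check, assuming Net from the template above is in scope.
import torch

net = Net()
dummy = torch.randn(1, 1, 28, 28)  # one grayscale 28x28 image
logits = net(dummy)
assert logits.shape == (1, 10)     # one logit per MNIST class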
8 changes: 7 additions & 1 deletion src/py/flwr/cli/new/templates/app/flower.toml.tpl
@@ -5,6 +5,12 @@ description = ""
license = "Apache-2.0"
authors = ["The Flower Authors <[email protected]>"]

[components]
[flower.components]
serverapp = "$project_name.server:app"
clientapp = "$project_name.client:app"

[flower.engine]
name = "simulation" # optional

[flower.engine.simulation.super-node]
count = 10 # optional
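The serverapp and clientapp values follow a "module:attribute" convention pointing at the ServerApp and ClientApp objects defined in the generated server.py and client.py. A hedged sketch of how such a reference could be resolved in plain Python is shown below; load_object is an assumed helper for illustration, not Flower's actual loader.

# Illustrative resolution of a "module:attribute" reference.
from importlib import import_module

def load_object(ref: str):
    """Resolve e.g. "myproject.server:app" to the `app` object."""
    module_name, attr = ref.split(":")
    return getattr(import_module(module_name), attr)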