Skip to content

Commit

Permalink
start with gans
Browse files Browse the repository at this point in the history
  • Loading branch information
nata1y committed May 10, 2021
1 parent 95a71aa commit a103d67
Show file tree
Hide file tree
Showing 15 changed files with 1,818 additions and 6 deletions.
3 changes: 1 addition & 2 deletions examples/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def create_federator(compute, project, zone, name_template, rank, region, machin
zone=zone,
body=client_config).execute()


def create_client(compute, project, zone, name_template, rank, world_size, host, nic, region, machine_image):
machine_type = f'zones/{zone}/machineTypes/g1-small'
instance_name = name_template.format(rank=rank)
Expand Down Expand Up @@ -205,7 +206,6 @@ def wait_for_operation(compute, project, zone, operation):
# [END wait_for_operation]



if __name__ == "__main__":

parser = argparse.ArgumentParser(description='Create VMs in GCP for Federated Learning')
Expand Down Expand Up @@ -266,4 +266,3 @@ def wait_for_operation(compute, project, zone, operation):
wait_for_operation(compute, project_name, zone_name, operation['name'])

print("""Now login via ssh into the federator VM and start the experiment.""")

7 changes: 5 additions & 2 deletions fltk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import argparse

import torch.multiprocessing as mp
from fltk.federator import Federator
from fltk.federator_fegan import Federator
from fltk.launch import run_single, run_spawn
from fltk.util.base_config import BareConfig

logging.basicConfig(level=logging.DEBUG)


def add_default_arguments(parser):
parser.add_argument('--world_size', type=str, default=None,
help='Number of entities in the world. This is the number of clients + 1')


def main():
parser = argparse.ArgumentParser(description='Experiment launcher for the Federated Learning Testbed')

Expand Down Expand Up @@ -78,5 +80,6 @@ def main():
else:
run_spawn(cfg)


if __name__ == "__main__":
main()
main()
317 changes: 317 additions & 0 deletions fltk/client_fegan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
import copy
import datetime
import os
import random
import time
from dataclasses import dataclass
from typing import List

import torch
from torch.distributed import rpc
from torch.autograd import Variable
import logging
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from torch.nn import MSELoss

from fltk.schedulers import MinCapableStepLR
from fltk.util.arguments import Arguments
from fltk.util.fed_avg import average_nn_parameters
from fltk.util.weight_init import *
from fltk.util.log import FLLogger
from fltk.nets.md_gan import *

import yaml

from fltk.util.results import EpochData

logging.basicConfig(level=logging.DEBUG)


def _call_method(method, rref, *args, **kwargs):
"""helper for _remote_method()"""
return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
"""
executes method(*args, **kwargs) on the from the machine that owns rref
very similar to rref.remote().method(*args, **kwargs), but method() doesn't have to be in the remote scope
"""
args = [method, rref] + list(args)
return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


def _remote_method_async(method, rref, *args, **kwargs):
args = [method, rref] + list(args)
return rpc.rpc_async(rref.owner(), _call_method, args=args, kwargs=kwargs)


def average_models(model):
return model


class Client:
counter = 0
finished_init = False
dataset = None
epoch_results: List[EpochData] = []
epoch_counter = 0

def __init__(self, id, log_rref, rank, world_size, config=None):
logging.info(f'Welcome to client {id}')
self.id = id
self.log_rref = log_rref
self.rank = rank
self.world_size = world_size
# self.args = Arguments(logging)
self.args = config
self.args.init_logger(logging)
self.device = self.init_device()
self.loss_function = self.args.get_loss_function()()
self.optimizer_generator = torch.optim.Adam(self.discriminator.parameters(),
lr=self.args.get_learning_rate(),
betas=(self.args.b1(), self.args.b2()))
self.optimizer_discriminator = torch.optim.Adam(self.discriminator.parameters(),
lr=self.args.get_learning_rate(),
betas=(self.args.b1(), self.args.b2()))
self.scheduler = MinCapableStepLR(self.args.get_logger(), self.optimizer,
self.args.get_scheduler_step_size(),
self.args.get_scheduler_gamma(),
self.args.get_min_lr())

def init_device(self):
if self.args.cuda and torch.cuda.is_available():
return torch.device("cuda:0")
else:
return torch.device("cpu")

def ping(self):
return 'pong'

def rpc_test(self):
sleep_time = random.randint(1, 5)
time.sleep(sleep_time)
self.local_log(f'sleep for {sleep_time} seconds')
self.counter += 1
log_line = f'Number of times called: {self.counter}'
self.local_log(log_line)
self.remote_log(log_line)

def remote_log(self, message):
_remote_method_async(FLLogger.log, self.log_rref, self.id, message, time.time())

def local_log(self, message):
logging.info(f'[{self.id}: {time.time()}]: {message}')

def set_configuration(self, config: str):
yaml_config = yaml.safe_load(config)

def init(self):
pass

def init_dataloader(self, ):
self.args.distributed = True
self.args.rank = self.rank
self.args.world_size = self.world_size
# self.dataset = DistCIFAR10Dataset(self.args)
self.dataset = self.args.DistDatasets[self.args.dataset_name](self.args)
self.finished_init = True

self.discriminator.apply(weights_init_normal)
logging.info('Done with init')

def is_ready(self):
return self.finished_init

def load_model_from_file(self, model_file_path):
model_class = self.args.get_net()
default_model_path = os.path.join(self.args.get_default_model_folder_path(), model_class.__name__ + ".model")
return self.load_model_from_file(default_model_path)

def load_default_model(self):
"""
Load a model from default model file.
This is used to ensure consistent default model behavior.
"""
model_class = self.args.get_net()
default_model_path = os.path.join(self.args.get_default_model_folder_path(), model_class.__name__ + ".model")

return self.load_model_from_file(default_model_path)

def load_model_from_file(self, model_file_path):
"""
Load a model from a file.
:param model_file_path: string
"""
model_class = self.args.get_net()
model = model_class()

if os.path.exists(model_file_path):
try:
model.load_state_dict(torch.load(model_file_path))
except:
self.args.get_logger().warning(
"Couldn't load model. Attempting to map CUDA tensors to CPU to solve error.")

model.load_state_dict(torch.load(model_file_path, map_location=torch.device('cpu')))
else:
self.args.get_logger().warning("Could not find model: {}".format(model_file_path))

return model

def get_client_index(self):
"""
Returns the client index.
"""
return self.client_idx

def train(self, epoch, net):
"""
:param epoch: Current epoch #
:type epoch: int
"""
generator, discriminator = net

# save model
if self.args.should_save_model(epoch):
self.save_model(epoch, self.args.get_epoch_save_start_suffix())

running_loss = 0.0
final_running_loss = 0.0
if self.args.distributed:
self.dataset.train_sampler.set_epoch(epoch)

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 1):
inputs, labels, fake = inputs.to(self.device), labels.to(self.device), torch.zeros(inputs.shape[0.0])

self.optimizer_generator.zero_grad()

noise = Variable(torch.FloatTensor(np.random.normal(0, 1, (inputs.shape[0], 100))))
generated_imgs = generator(noise)
d_generator = discriminator(generated_imgs)
generator_loss = MSELoss(d_generator, labels)
generator_loss.backwards()

self.optimizer_generator.step()

self.optimizer_discriminator.zero_grad()

real_loss = self.loss_function(discriminator(inputs), labels)
fake_loss = self.loss_function(discriminator(generated_imgs.detach()), fake)
discriminator_loss = 0.5 * (real_loss + fake_loss)
discriminator_loss.backward()

self.optimizer_discriminator.step()

# TODO: fix to min-max loss
running_loss += generator_loss.item() + discriminator_loss.items()
if i % self.args.get_log_interval() == 0:
self.args.get_logger().info(
'[%d, %5d] loss: %.3f' % (epoch, i, running_loss / self.args.get_log_interval()))
final_running_loss = running_loss / self.args.get_log_interval()
running_loss = 0.0

self.scheduler.step()

# save model
if self.args.should_save_model(epoch):
self.save_model(epoch, self.args.get_epoch_save_end_suffix())

return final_running_loss, (generator, discriminator)

def test(self):
self.net.eval()

correct = 0
total = 0
targets_ = []
pred_ = []
loss = 0.0
with torch.no_grad():
for (images, labels) in self.dataset.get_test_loader():
images, labels = images.to(self.device), labels.to(self.device)

outputs = self.net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

targets_.extend(labels.cpu().view_as(predicted).numpy())
pred_.extend(predicted.cpu().numpy())

loss += self.loss_function(outputs, labels).item()

accuracy = 100 * correct / total
confusion_mat = confusion_matrix(targets_, pred_)

class_precision = self.calculate_class_precision(confusion_mat)
class_recall = self.calculate_class_recall(confusion_mat)

self.args.get_logger().debug('Test set: Accuracy: {}/{} ({:.0f}%)'.format(correct, total, accuracy))
self.args.get_logger().debug('Test set: Loss: {}'.format(loss))
self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_))
self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat))
self.args.get_logger().debug("Class precision: {}".format(str(class_precision)))
self.args.get_logger().debug("Class recall: {}".format(str(class_recall)))

return accuracy, loss, class_precision, class_recall

def run_epochs(self, num_epoch, net):
start_time_train = datetime.datetime.now()
loss = None

for e in range(num_epoch):
loss, net = self.train(self.epoch_counter, net)
self.epoch_counter += 1
elapsed_time_train = datetime.datetime.now() - start_time_train
train_time_ms = int(elapsed_time_train.total_seconds() * 1000)

start_time_test = datetime.datetime.now()
accuracy, test_loss, class_precision, class_recall = self.test()
elapsed_time_test = datetime.datetime.now() - start_time_test
test_time_ms = int(elapsed_time_test.total_seconds() * 1000)

data = EpochData(self.epoch_counter, train_time_ms, test_time_ms, loss,
accuracy, test_loss, class_precision, class_recall, client_id=self.id)
self.epoch_results.append(data)

# # Copy GPU tensors to CPU
# for k, v in net.items():
# weights[k] = v.cpu()
return data, net

def save_model(self, epoch, suffix):
"""
Saves the model if necessary.
"""
self.args.get_logger().debug("Saving model to flat file storage. Save #{}", epoch)

if not os.path.exists(self.args.get_save_model_folder_path()):
os.mkdir(self.args.get_save_model_folder_path())

full_save_path = os.path.join(self.args.get_save_model_folder_path(),
"model_" + str(self.client_idx) + "_" + str(epoch) + "_" + suffix + ".model")
torch.save(self.get_nn_parameters(), full_save_path)

def calculate_class_precision(self, confusion_mat):
"""
Calculates the precision for each class from a confusion matrix.
"""
return np.diagonal(confusion_mat) / np.sum(confusion_mat, axis=0)

def calculate_class_recall(self, confusion_mat):
"""
Calculates the recall for each class from a confusion matrix.
"""
return np.diagonal(confusion_mat) / np.sum(confusion_mat, axis=1)

def get_client_datasize(self):
return len(self.dataset.get_train_sampler())

def __del__(self):
print(f'Client {self.id} is stopping')
Loading

0 comments on commit a103d67

Please sign in to comment.