Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions FL.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@

#include <atomic>
#include <chrono>
#include <cmath>
#include <iostream>
#include <mutex>
#include <numeric>
#include <random>
#include <thread>
#include <vector>
static std::mt19937 rng(std::random_device{}());
static std::mutex rng_mtx;

/**
* �����˽�����ࣺ���Ӹ�˹���������-�����˽
*/
class DifferentialPrivacy {
private:
double epsilon; // ��˽Ԥ��
double sensitivity; // �ݶ����ж�
public:
DifferentialPrivacy(double eps = 1.0, double sens = 0.1)
: epsilon(eps), sensitivity(sens) {}

// ���Ӹ�˹����
double add_gaussian_noise(double value) {
std::lock_guard<std::mutex> lock(rng_mtx);
std::normal_distribution<double> dist(0.0, sensitivity / epsilon);
return value + dist(rng);
}
};

/**
* ����ѧϰ�ͻ����ࣺģ�ⵥһ���������뱾��ѵ��
*/
class FLClient {
private:
int client_id;
std::vector<std::vector<double>> X; // �������� [������, ������]
std::vector<int> y; // ��ǩ��0/1��
std::vector<double> local_weights; // ����ģ��Ȩ��
double lr; // ѧϰ��
DifferentialPrivacy dp; // �����˽ʵ��

// Sigmoid�����
double sigmoid(double z) {
return 1.0 / (1.0 + exp(-std::clamp(z, -500.0, 500.0))); // ��ֹ���
}

// ����Ӧѧϰ�ʣ����������˥����
double adaptive_lr(int iter) { return lr / (1.0 + 0.01 * iter); }

public:
FLClient(int id, const std::vector<std::vector<double>>& features,
const std::vector<int>& labels, double learning_rate = 0.01)
: client_id(id),
X(features),
y(labels),
lr(learning_rate),
dp(1.0, 0.1) {
// ��ʼ��Ȩ�أ���ƫ���Ȩ��ά��=������+1��
local_weights.resize(X[0].size() + 1, 0.0);
std::lock_guard<std::mutex> lock(rng_mtx);
std::uniform_real_distribution<double> dist(-0.1, 0.1);
for (auto& w : local_weights) w = dist(rng);
}

// ���ص���ѵ�������ش������˽���ݶ�
std::vector<double> local_train(int iter) {
std::vector<double> grad(local_weights.size(), 0.0);
int n_samples = X.size();
double current_lr = adaptive_lr(iter);

// ���������ݶȣ�ȫ��������
for (int i = 0; i < n_samples; ++i) {
// ����Ԥ��ֵ��w0*1 + w1*x1 + w2*x2 + ...
double y_pred = local_weights[0]; // ƫ����
for (int j = 0; j < X[i].size(); ++j) {
y_pred += local_weights[j + 1] * X[i][j];
}
y_pred = sigmoid(y_pred);

// �ݶȼ��㣨�߼��ع齻������ʧ��
double error = y_pred - y[i];
grad[0] += error; // ƫ�����ݶ�
for (int j = 0; j < X[i].size(); ++j) {
grad[j + 1] += error * X[i][j];
}
}

// �ݶȹ�һ��+�����˽
for (auto& g : grad) {
g /= n_samples; // ƽ���ݶ�
g = dp.add_gaussian_noise(g); // ��������
// �ݶȲü�����ֹ�ݶȱ�ը��
g = std::clamp(g, -1.0, 1.0);
}

// ���±���Ȩ��
for (int j = 0; j < local_weights.size(); ++j) {
local_weights[j] -= current_lr * grad[j];
}

return grad;
}

int get_id() const { return client_id; }
const std::vector<double>& get_local_weights() const {
return local_weights;
}
};

/**
* ����ѧϰ����ˣ��ۺϿͻ����ݶȣ�����ȫ��ģ��
*/
class FLServer {
private:
std::vector<double> global_weights; // ȫ��ģ��Ȩ��
std::atomic<int> received_clients; // �ѽ����ݶȵĿͻ�����
std::mutex mtx; // �ۺ���
int total_clients; // �ܿͻ�����
double aggregation_ratio; // �첽�ۺϱ�������0.8��ʾ80%�ͻ�����Ӧ���ɾۺϣ�

public:
FLServer(int feature_dim, int total_clients_num, double agg_ratio = 0.8)
: total_clients(total_clients_num),
aggregation_ratio(agg_ratio),
received_clients(0) {
// ��ʼ��ȫ��Ȩ��
global_weights.resize(feature_dim + 1, 0.0);
std::lock_guard<std::mutex> lock(rng_mtx);
std::uniform_real_distribution<double> dist(-0.1, 0.1);
for (auto& w : global_weights) w = dist(rng);
}

// ���տͻ����ݶȲ��첽�ۺ�
void receive_and_aggregate(const std::vector<double>& client_grad,
int client_id) {
std::lock_guard<std::mutex> lock(mtx);
received_clients++;

// �ݶȾۺϣ���Ȩƽ������Ϊ��Ȩ��
double agg_weight = 1.0 / total_clients;
for (int j = 0; j < global_weights.size(); ++j) {
global_weights[j] -= agg_weight * client_grad[j];
}

std::cout << "[Server] ���տͻ���" << client_id << "�ݶȣ��ѽ��գ�"
<< received_clients << "/" << total_clients << std::endl;

// �첽�ۺϴ������ﵽ�ۺϱ��������ü���
if (received_clients >= total_clients * aggregation_ratio) {
std::cout << "[Server] �ﵽ�ۺ���ֵ������ȫ��ģ��" << std::endl;
received_clients = 0;
}
}

// ��ͻ����·�ȫ��Ȩ��
std::vector<double> get_global_weights() const {
std::lock_guard<std::mutex> lock(mtx);
return global_weights;
}

// ģ��������׼ȷ�ʣ�
double evaluate(const std::vector<std::vector<double>>& X_test,
const std::vector<int>& y_test) {
std::vector<double> weights = get_global_weights();
int correct = 0;
auto sigmoid = [](double z) {
return 1.0 / (1.0 + exp(-std::clamp(z, -500.0, 500.0)));
};

for (int i = 0; i < X_test.size(); ++i) {
double y_pred = weights[0];
for (int j = 0; j < X_test[i].size(); ++j) {
y_pred += weights[j + 1] * X_test[i][j];
}
y_pred = sigmoid(y_pred);
correct += (y_pred >= 0.5) == (y_test[i] == 1) ? 1 : 0;
}
return static_cast<double>(correct) / X_test.size();
}
};

/**
* ����ģ�����ݼ�������������
*/
std::pair<std::vector<std::vector<double>>, std::vector<int>> generate_sim_data(
int n_samples, int n_features) {
std::vector<std::vector<double>> X(n_samples,
std::vector<double>(n_features));
std::vector<int> y(n_samples);

// ������������̬�ֲ���
std::lock_guard<std::mutex> lock(rng_mtx);
std::normal_distribution<double> x_dist(0.0, 1.0);
std::bernoulli_distribution y_dist(0.5);

for (int i = 0; i < n_samples; ++i) {
for (int j = 0; j < n_features; ++j) {
X[i][j] = x_dist(rng);
}
y[i] = y_dist(rng);
}
return {X, y};
}

/**
* �ͻ���ѵ���̺߳���
*/
void client_train_thread(FLClient& client, FLServer& server, int epochs) {
for (int e = 0; e < epochs; ++e) {
// ����ѵ��
std::vector<double> grad = client.local_train(e);
// �ϱ��ݶȵ������
server.receive_and_aggregate(grad, client.get_id());
// ��ȡ����ȫ��Ȩ�أ�ģ���첽ͨ�ţ�
client = FLClient(client.get_id(), client.X, client.y,
0.01); // �򻯣����³�ʼ���ͻ���Ȩ��Ϊȫ��Ȩ��
// ģ�������ӳ�
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}

int main() {
// 1. ����ʵ�����
const int n_clients = 5; // ����ͻ�����
const int n_features = 5; // ������
const int n_samples_per_client = 100; // ÿ���ͻ���������
const int epochs = 20; // ѵ������

// 2. ����ģ�����ݣ��ͻ���+���Լ���
std::vector<FLClient> clients;
for (int i = 0; i < n_clients; ++i) {
auto [X, y] = generate_sim_data(n_samples_per_client, n_features);
clients.emplace_back(i, X, y, 0.01);
}
// ���Լ�
auto [X_test, y_test] = generate_sim_data(200, n_features);

// 3. ��ʼ�������
FLServer server(n_features, n_clients, 0.8);

// 4. �����ͻ���ѵ���߳�
std::vector<std::thread> client_threads;
for (auto& client : clients) {
client_threads.emplace_back(client_train_thread, std::ref(client),
std::ref(server), epochs);
}

// 5. �ȴ����пͻ���ѵ�����
for (auto& t : client_threads) {
t.join();
}

// 6. �����
double accuracy = server.evaluate(X_test, y_test);
std::cout << "\n[Final] ȫ��ģ�Ͳ���׼ȷ�ʣ�" << accuracy << std::endl;

return 0;
}