Skip to content

Commit 95801a9

Browse files
committed
Worker API changes; updated example
1 parent 793fc32 commit 95801a9

28 files changed

+537
-403
lines changed

.babelrc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
{
22
"presets": ["@babel/preset-env"],
3-
"plugins": ["@babel/plugin-proposal-class-properties"]
3+
"plugins": [
4+
"@babel/plugin-proposal-class-properties"
5+
],
6+
"env": {
7+
"test": {
8+
"plugins": [
9+
"@babel/plugin-transform-runtime",
10+
"@babel/plugin-proposal-class-properties"
11+
]
12+
}
13+
}
414
}
1.19 MB
Binary file not shown.

examples/with-grid/data/tp_ops.pb

3.15 KB
Binary file not shown.

examples/with-grid/index.html

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
<!-- NOTE: TFJS version must match with one in package-lock.json -->
4343
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
4444
<script src="https://webrtc.github.io/adapter/adapter-latest.js"></script>
45+
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
4546
</head>
4647
<body>
4748
<img
@@ -72,9 +73,22 @@ <h1>syft.js/grid.js testing</h1>
7273
>.
7374
</p>
7475
<input type="text" id="grid-server" value="ws://localhost:3000" />
75-
<input type="text" id="protocol" value="10000000013" />
76-
<button id="connect">Connect to grid.js server</button>
76+
<!-- <input type="text" id="protocol" value="10000000013" />-->
77+
<!-- <button id="connect">Connect to grid.js server</button>-->
7778
<button id="start">Start FL Worker</button>
79+
80+
<div id="fl-training" style="display: none">
81+
<div style="display: table-row">
82+
<div style="display: table-cell">
83+
<div id="loss_graph"></div>
84+
</div>
85+
86+
<div style="display: table-cell">
87+
<div id="acc_graph"></div>
88+
</div>
89+
</div>
90+
</div>
91+
7892
<div id="app">
7993
<button id="disconnect">Disconnect</button>
8094
<p id="identity"></p>

examples/with-grid/index.js

Lines changed: 144 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {
2424
} from './_helpers';
2525

2626
// In the real world: import syft from 'syft.js';
27-
import { Syft, GridAPIClient, SyftWorker } from '../../src';
27+
import { Syft } from '../../src';
2828
import { MnistData } from './mnist';
2929

3030
const gridServer = document.getElementById('grid-server');
@@ -38,6 +38,7 @@ const submitButton = document.getElementById('message-send');
3838

3939
appContainer.style.display = 'none';
4040

41+
/*
4142
connectButton.onclick = () => {
4243
appContainer.style.display = 'block';
4344
gridServer.style.display = 'none';
@@ -46,72 +47,170 @@ connectButton.onclick = () => {
4647
4748
startSyft(gridServer.value, protocol.value);
4849
};
50+
*/
4951

5052
startButton.onclick = () => {
53+
setFLUI();
5154
startFL(gridServer.value, 'model-id');
5255
};
5356

54-
const trainFLModel = async ({ job, model, clientConfig }) => {
55-
if (!job.plans.hasOwnProperty('training_plan')) {
56-
// no training plan, nothing to do
57-
return job.done();
58-
}
59-
60-
// load data
61-
console.log('Loading data...');
62-
const mnist = new MnistData();
63-
await mnist.load();
64-
const data = mnist.getTrainData();
65-
console.log('Data loaded');
66-
57+
const executeFLTrainingJob = async ({
58+
data,
59+
targets,
60+
job,
61+
model,
62+
clientConfig,
63+
callbacks
64+
}) => {
6765
const batchSize = clientConfig.batch_size;
68-
const batches = Math.ceil(data.xs.shape[0] / batchSize);
66+
const lr = clientConfig.lr;
67+
const numBatches = Math.ceil(data.shape[0] / batchSize);
6968
const maxEpochs = clientConfig.max_epochs || 1;
70-
const maxUpdates = clientConfig.max_updates || maxEpochs * batches;
71-
69+
const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches;
7270
// set the lowest cap
73-
const updates = Math.min(maxUpdates, maxEpochs * batches);
74-
75-
for (let update = 0, batch = 0, epoch = 0; update < updates; update++) {
76-
const chunkSize = Math.min(batchSize, data.xs.shape[0] - batch * batchSize);
77-
const X_batch = data.xs.slice(batch * batchSize, chunkSize);
78-
const y_batch = data.labels.slice(batch * batchSize, chunkSize);
79-
console.log(
80-
`Epoch: ${epoch}, Batch: ${batch}: execute plan with`,
81-
model,
82-
X_batch,
83-
y_batch,
84-
clientConfig
71+
const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches);
72+
73+
// Copy original model params.
74+
let modelParams = [];
75+
for (let param of model.params) {
76+
modelParams.push(param.clone());
77+
}
78+
79+
for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
80+
const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize);
81+
const dataBatch = data.slice(batch * batchSize, chunkSize);
82+
const targetBatch = targets.slice(batch * batchSize, chunkSize);
83+
84+
let [loss, acc, ...newModelParams] = await job.plans[
85+
'training_plan'
86+
].execute(
87+
job.worker,
88+
dataBatch,
89+
targetBatch,
90+
chunkSize,
91+
lr,
92+
...modelParams
8593
);
86-
// TODO plan execution
87-
// job.plans['training_plan'].execute();
88-
if (++batch === batches) {
89-
// full epoch
94+
95+
// Use updated model params in the next cycle.
96+
for (let i = 0; i < modelParams.length; i++) {
97+
modelParams[i].dispose();
98+
modelParams[i] = newModelParams[i];
99+
}
100+
101+
if (typeof callbacks.onBatchEnd === 'function') {
102+
callbacks.onBatchEnd({
103+
update,
104+
batch,
105+
epoch,
106+
accuracy: (await acc.data())[0],
107+
loss: (await loss.data())[0]
108+
});
109+
}
110+
111+
batch++;
112+
// check if we're out of batches (end of epoch)
113+
if (batch === numBatches) {
114+
if (typeof callbacks.onEpochEnd === 'function') {
115+
callbacks.onEpochEnd({ update, batch, epoch, model });
116+
}
90117
batch = 0;
91118
epoch++;
92-
console.log('Starting new epoch!');
93119
}
120+
121+
// free GPU memory
122+
acc.dispose();
123+
loss.dispose();
124+
dataBatch.dispose();
125+
targetBatch.dispose();
126+
}
127+
128+
// TODO protocol execution
129+
// job.protocols['secure_aggregation'].execute();
130+
131+
// Calc model diffs
132+
const modelDiff = [];
133+
for (let i = 0; i < modelParams.length; i++) {
134+
modelDiff.push(model.params[i].sub(modelParams[i]));
94135
}
95136

96-
if (job.protocols['secure_aggregation']) {
97-
// TODO protocol execution
98-
await job.report();
99-
} else {
100-
await job.report();
137+
// report
138+
await job.report(modelDiff);
139+
140+
if (typeof callbacks.onDone === 'function') {
141+
callbacks.onDone();
101142
}
102143
};
103144

104145
const startFL = async (url, modelId) => {
105-
const gridClient = new GridAPIClient({ url });
106-
const worker = await SyftWorker.create({ gridClient });
107-
const job = worker.newJob({ modelId });
146+
const worker = new Syft({ url, verbose: true });
147+
const job = await worker.newJob({ modelId });
108148
job.start();
109-
job.on('ready', trainFLModel);
110-
job.on('done', () => {
111-
console.log('done with the job!');
149+
job.on('ready', async ({ model, clientConfig }) => {
150+
// load data
151+
console.log('Loading data...');
152+
const mnist = new MnistData();
153+
await mnist.load();
154+
const data = mnist.getTrainData();
155+
console.log('Data loaded');
156+
157+
// train
158+
executeFLTrainingJob({
159+
model,
160+
data: data.xs,
161+
targets: data.labels,
162+
job,
163+
clientConfig,
164+
callbacks: {
165+
onBatchEnd: async ({ epoch, batch, accuracy, loss }) => {
166+
console.log(
167+
`Epoch: ${epoch}, Batch: ${batch}, Accuracy: ${accuracy}, Loss: ${loss}`
168+
);
169+
Plotly.extendTraces('loss_graph', { y: [[loss]] }, [0]);
170+
Plotly.extendTraces('acc_graph', { y: [[accuracy]] }, [0]);
171+
await tf.nextFrame();
172+
},
173+
onEpochEnd: ({ epoch }) => {
174+
console.log(`Epoch ${epoch} ended!`);
175+
},
176+
onDone: () => {
177+
console.log(`Job is done!`);
178+
}
179+
}
180+
});
112181
});
113182
};
114183

184+
const setFLUI = () => {
185+
Plotly.newPlot(
186+
'loss_graph',
187+
[
188+
{
189+
y: [],
190+
mode: 'lines',
191+
line: { color: '#80CAF6' }
192+
}
193+
],
194+
{ title: 'Train Loss', showlegend: false },
195+
{ staticPlot: true }
196+
);
197+
198+
Plotly.newPlot(
199+
'acc_graph',
200+
[
201+
{
202+
y: [],
203+
mode: 'lines',
204+
line: { color: '#80CAF6' }
205+
}
206+
],
207+
{ title: 'Train Accuracy', showlegend: false },
208+
{ staticPlot: true }
209+
);
210+
211+
document.getElementById('fl-training').style.display = 'table';
212+
};
213+
115214
const startSyft = (url, protocolId) => {
116215
const workerId = getQueryVariable('worker_id');
117216
const scopeId = getQueryVariable('scope_id');

examples/with-grid/mnist.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import * as tf from '@tensorflow/tfjs';
18+
import * as tf from '@tensorflow/tfjs-core';
1919

2020
export const IMAGE_H = 28;
2121
export const IMAGE_W = 28;

examples/with-grid/webpack.config.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ module.exports = (env, argv) => ({
3535
},
3636
plugins: [new HtmlWebpackPlugin({ template: './index.html' })],
3737
externals: {
38-
'@tensorflow/tfjs-core': 'tf',
39-
'@tensorflow/tfjs': 'tf'
38+
'@tensorflow/tfjs-core': 'tf'
4039
}
4140
});

package-lock.json

Lines changed: 49 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@
6767
"devDependencies": {
6868
"@babel/core": "^7.7.7",
6969
"@babel/plugin-proposal-class-properties": "^7.7.4",
70+
"@babel/plugin-transform-runtime": "^7.8.3",
7071
"@babel/preset-env": "^7.7.7",
72+
"@babel/runtime": "^7.8.4",
7173
"@joseph184/rollup-plugin-node-builtins": "^2.1.4",
7274
"@tensorflow/tfjs-core": "^1.2.5",
7375
"auto-changelog": "^1.16.2",

src/_constants.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ export const SOCKET_PING = 'socket-ping';
44

55
// Grid
66
export const GET_PROTOCOL = 'get-protocol';
7+
export const CYCLE_STATUS_ACCEPTED = 'accepted';
8+
export const CYCLE_STATUS_REJECTED = 'rejected';
79

810
// WebRTC
911
export const WEBRTC_JOIN_ROOM = 'webrtc: join-room';

0 commit comments

Comments
 (0)