@@ -24,7 +24,7 @@ import {
 } from './_helpers';
 
 // In the real world: import syft from 'syft.js';
-import { Syft, GridAPIClient, SyftWorker } from '../../src';
+import { Syft } from '../../src';
 import { MnistData } from './mnist';
 
 const gridServer = document.getElementById('grid-server');
@@ -38,6 +38,7 @@ const submitButton = document.getElementById('message-send');
 
 appContainer.style.display = 'none';
 
+/*
 connectButton.onclick = () => {
   appContainer.style.display = 'block';
   gridServer.style.display = 'none';
@@ -46,72 +47,170 @@ connectButton.onclick = () => {
 
   startSyft(gridServer.value, protocol.value);
 };
+*/
 
 startButton.onclick = () => {
+  setFLUI();
   startFL(gridServer.value, 'model-id');
 };
 
-const trainFLModel = async ({ job, model, clientConfig }) => {
-  if (!job.plans.hasOwnProperty('training_plan')) {
-    // no training plan, nothing to do
-    return job.done();
-  }
-
-  // load data
-  console.log('Loading data...');
-  const mnist = new MnistData();
-  await mnist.load();
-  const data = mnist.getTrainData();
-  console.log('Data loaded');
-
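+// Executes the FL training plan batch by batch, invokes the supplied
+// callbacks, and reports the resulting model diff when training is done.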
+const executeFLTrainingJob = async ({
+  data,
+  targets,
+  job,
+  model,
+  clientConfig,
+  callbacks
+}) => {
   const batchSize = clientConfig.batch_size;
-  const batches = Math.ceil(data.xs.shape[0] / batchSize);
+  const lr = clientConfig.lr;
+  const numBatches = Math.ceil(data.shape[0] / batchSize);
   const maxEpochs = clientConfig.max_epochs || 1;
-  const maxUpdates = clientConfig.max_updates || maxEpochs * batches;
-
+  const maxUpdates = clientConfig.max_updates || maxEpochs * numBatches;
   // set the lowest cap
-  const updates = Math.min(maxUpdates, maxEpochs * batches);
-
-  for (let update = 0, batch = 0, epoch = 0; update < updates; update++) {
-    const chunkSize = Math.min(batchSize, data.xs.shape[0] - batch * batchSize);
-    const X_batch = data.xs.slice(batch * batchSize, chunkSize);
-    const y_batch = data.labels.slice(batch * batchSize, chunkSize);
-    console.log(
-      `Epoch: ${epoch}, Batch: ${batch}: execute plan with`,
-      model,
-      X_batch,
-      y_batch,
-      clientConfig
+  const numUpdates = Math.min(maxUpdates, maxEpochs * numBatches);
+
+  // Copy original model params.
+  let modelParams = [];
+  for (let param of model.params) {
+    modelParams.push(param.clone());
+  }
+
+  for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
+    const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize);
+    const dataBatch = data.slice(batch * batchSize, chunkSize);
+    const targetBatch = targets.slice(batch * batchSize, chunkSize);
+
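+    // Run the training plan: it returns the loss, the accuracy and the updated model params.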
+    let [loss, acc, ...newModelParams] = await job.plans[
+      'training_plan'
+    ].execute(
+      job.worker,
+      dataBatch,
+      targetBatch,
+      chunkSize,
+      lr,
+      ...modelParams
     );
-    // TODO plan execution
-    // job.plans['training_plan'].execute();
-    if (++batch === batches) {
-      // full epoch
+
+    // Use updated model params in the next cycle.
+    for (let i = 0; i < modelParams.length; i++) {
+      modelParams[i].dispose();
+      modelParams[i] = newModelParams[i];
+    }
+
+    if (typeof callbacks.onBatchEnd === 'function') {
+      callbacks.onBatchEnd({
+        update,
+        batch,
+        epoch,
+        accuracy: (await acc.data())[0],
+        loss: (await loss.data())[0]
+      });
+    }
+
+    batch++;
+    // check if we're out of batches (end of epoch)
+    if (batch === numBatches) {
+      if (typeof callbacks.onEpochEnd === 'function') {
+        callbacks.onEpochEnd({ update, batch, epoch, model });
+      }
       batch = 0;
       epoch++;
-      console.log('Starting new epoch!');
     }
+
+    // free GPU memory
+    acc.dispose();
+    loss.dispose();
+    dataBatch.dispose();
+    targetBatch.dispose();
+  }
+
+  // TODO protocol execution
+  // job.protocols['secure_aggregation'].execute();
+
+  // Calc model diffs
+  const modelDiff = [];
+  for (let i = 0; i < modelParams.length; i++) {
+    modelDiff.push(model.params[i].sub(modelParams[i]));
   }
 
-  if (job.protocols['secure_aggregation']) {
-    // TODO protocol execution
-    await job.report();
-  } else {
-    await job.report();
+  // report
+  await job.report(modelDiff);
+
+  if (typeof callbacks.onDone === 'function') {
+    callbacks.onDone();
   }
 };
 
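+// Creates a syft.js worker for the given grid URL, requests an FL job
+// for the model, and trains it once the job is ready.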
 const startFL = async (url, modelId) => {
-  const gridClient = new GridAPIClient({ url });
-  const worker = await SyftWorker.create({ gridClient });
-  const job = worker.newJob({ modelId });
+  const worker = new Syft({ url, verbose: true });
+  const job = await worker.newJob({ modelId });
   job.start();
-  job.on('ready', trainFLModel);
-  job.on('done', () => {
-    console.log('done with the job!');
+  job.on('ready', async ({ model, clientConfig }) => {
+    // load data
+    console.log('Loading data...');
+    const mnist = new MnistData();
+    await mnist.load();
+    const data = mnist.getTrainData();
+    console.log('Data loaded');
+
+    // train
+    executeFLTrainingJob({
+      model,
+      data: data.xs,
+      targets: data.labels,
+      job,
+      clientConfig,
+      callbacks: {
+        onBatchEnd: async ({ epoch, batch, accuracy, loss }) => {
+          console.log(
+            `Epoch: ${epoch}, Batch: ${batch}, Accuracy: ${accuracy}, Loss: ${loss}`
+          );
+          Plotly.extendTraces('loss_graph', { y: [[loss]] }, [0]);
+          Plotly.extendTraces('acc_graph', { y: [[accuracy]] }, [0]);
+          await tf.nextFrame();
+        },
+        onEpochEnd: ({ epoch }) => {
+          console.log(`Epoch ${epoch} ended!`);
+        },
+        onDone: () => {
+          console.log(`Job is done!`);
+        }
+      }
+    });
   });
 };
 
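+// Initializes the Plotly loss/accuracy charts and reveals the training UI.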
+const setFLUI = () => {
+  Plotly.newPlot(
+    'loss_graph',
+    [
+      {
+        y: [],
+        mode: 'lines',
+        line: { color: '#80CAF6' }
+      }
+    ],
+    { title: 'Train Loss', showlegend: false },
+    { staticPlot: true }
+  );
+
+  Plotly.newPlot(
+    'acc_graph',
+    [
+      {
+        y: [],
+        mode: 'lines',
+        line: { color: '#80CAF6' }
+      }
+    ],
+    { title: 'Train Accuracy', showlegend: false },
+    { staticPlot: true }
+  );
+
+  document.getElementById('fl-training').style.display = 'table';
+};
+
 const startSyft = (url, protocolId) => {
   const workerId = getQueryVariable('worker_id');
   const scopeId = getQueryVariable('scope_id');