Skip to content

Commit 6e85f6f

Browse files
committed
discojs: add docs
1 parent 021f499 commit 6e85f6f

File tree

27 files changed

+475
-88
lines changed

27 files changed

+475
-88
lines changed

discojs/discojs-core/src/aggregator/base.ts

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ export enum AggregationStep {
1010
AGGREGATE
1111
}
1212

13+
/**
14+
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
15+
* a result based off their aggregation, whenever some defined condition is met.
16+
*/
1317
export abstract class Base<T> {
1418
/**
1519
* Contains the ids of all active nodes, i.e. members of the aggregation group at
@@ -32,44 +36,67 @@ export abstract class Base<T> {
3236
protected informant?: AsyncInformant<T>
3337
/**
3438
* The result promise which, on resolve, will contain the current aggregation result.
39+
* This promise should be fetched by any object making use of an aggregator, in order
40+
* to await upon aggregation.
3541
*/
3642
protected result: Promise<T>
3743
/**
38-
* The current aggregation round, used for assessing whether a contribution is recent enough
44+
* The current aggregation round, used for assessing whether a node contribution is recent enough
3945
* or not.
4046
*/
4147
protected _round = 0
42-
48+
/**
49+
* The current communication round. A single aggregation round is made of possibly multiple
50+
* communication rounds. This makes the aggregator free to perform intermediate aggregation
51+
* steps based off communication with its nodes. Overall, this allows for more complex
52+
* aggregation schemes requiring an exchange of information between nodes before aggregating.
53+
*/
4354
protected _communicationRound = 0
4455

4556
constructor (
57+
/**
58+
* The task for which the aggregator should be created.
59+
*/
4660
public readonly task: Task,
61+
/**
62+
* The TF.js model whose weights are updated on aggregation.
63+
*/
4764
protected _model?: tf.LayersModel,
65+
/**
66+
* The round cut-off for contributions.
67+
*/
4868
protected readonly roundCutoff = 0,
69+
/**
70+
* The number of communication rounds occuring during any given aggregation round.
71+
*/
4972
public readonly communicationRounds = 1
5073
) {
5174
this.eventEmitter = new EventEmitter()
5275
this.contributions = Map()
5376
this._nodes = Set()
5477

78+
// Make the initial result promise
5579
this.result = this.makeResult()
5680

81+
// On every aggregation, update the object's state to match the current aggregation
82+
// and communication rounds.
5783
this.eventEmitter.on('aggregation', () => {
5884
this.nextRound()
5985
})
6086
}
6187

6288
/**
63-
* Adds a node's contribution to the aggregator for a given round.
64-
* The contribution will be aggregated during the round's aggregation step.
89+
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
90+
* The contribution will be aggregated during the next aggregation step.
6591
* @param nodeId The node's id
6692
* @param contribution The node's contribution
67-
* @param round For which round the contribution was made
93+
* @param round For which aggregation round the contribution was made
94+
* @param communicationRound For which communication round the contribution was made
6895
*/
6996
abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean
7097

7198
/**
72-
* Performs the aggregation step over the received node contributions.
99+
* Performs an aggregation step over the received node contributions.
73100
* Must store the aggregation's result in the aggregator's result promise.
74101
*/
75102
abstract aggregate (): void
@@ -110,6 +137,10 @@ export abstract class Base<T> {
110137
}
111138
}
112139

140+
/**
141+
* Sets the aggregator's TF.js model.
142+
* @param model The new TF.js model
143+
*/
113144
setModel (model: tf.LayersModel): void {
114145
this._model = model
115146
}
@@ -138,6 +169,10 @@ export abstract class Base<T> {
138169
this._nodes = nodeIds
139170
}
140171

172+
/**
173+
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
174+
* if the set of nodes is meant to change or to be actualized.
175+
*/
141176
resetNodes (): void {
142177
this._nodes = Set()
143178
}
@@ -163,7 +198,9 @@ export abstract class Base<T> {
163198
}
164199

165200
/**
166-
* Resets the aggregator's step and prepares it for the next aggregation round.
201+
* Updates the aggregator's state to proceed to the next communication round.
202+
* If all communication rounds were performed, proceeds to the next aggregation round
203+
* and empties the collection of stored contributions.
167204
*/
168205
public nextRound (): void {
169206
if (++this._communicationRound === this.communicationRounds) {
@@ -184,10 +221,9 @@ export abstract class Base<T> {
184221
}
185222

186223
/**
187-
* The aggregation result can be awaited upon in an asynchronous fashion, to allow
188-
* for the receipt of contributions while performing other tasks. This function
189-
* gives access to the current aggregation result's promise, which will eventually
190-
* resolve and contain the result of the very next aggregation step, at the
224+
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
225+
* This function gives access to the current aggregation result's promise, which will
226+
* eventually resolve and contain the result of the very next aggregation step, at the
191227
* time of the function call.
192228
* @returns The promise containing the aggregation result
193229
*/
@@ -196,7 +232,7 @@ export abstract class Base<T> {
196232
}
197233

198234
/**
199-
* Constructs the payload sent to other nodes as contribution.
235+
* Constructs the payloads sent to other nodes as contribution.
200236
* @param base Object from which the payload is computed
201237
*/
202238
abstract makePayloads (base: T): Map<client.NodeID, T>
@@ -218,8 +254,8 @@ export abstract class Base<T> {
218254
}
219255

220256
/**
221-
* The aggregator's current size, defined by its amount of contributions.
222-
* The size is bounded by the amount of all active nodes.
257+
* The aggregator's current size, defined by its number of contributions. The size is bounded by
258+
* the amount of all active nodes times the number of communication rounds.
223259
*/
224260
get size (): number {
225261
return this.contributions
@@ -235,6 +271,9 @@ export abstract class Base<T> {
235271
return this._model
236272
}
237273

274+
/**
275+
* The current commnication round.
276+
*/
238277
get communicationRound (): number {
239278
return this._communicationRound
240279
}

discojs/discojs-core/src/aggregator/get.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import { aggregator, Task } from '..'
22

3+
/**
4+
* Enumeration of the available types of aggregator.
5+
*/
36
export enum AggregatorChoice {
47
MEAN,
58
ROBUST,
69
SECURE,
710
BANDIT
811
}
912

13+
/**
14+
* Provides the aggregator object adequate to the given task.
15+
* @param task The task
16+
* @returns The aggregator
17+
*/
1018
export function getAggregator (task: Task): aggregator.Aggregator {
1119
const error = new Error('not implemented')
1220
switch (task.trainingInformation.aggregator) {

discojs/discojs-core/src/aggregator/mean.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@ import { AggregationStep, Base as Aggregator } from './base'
44
import { Task, WeightsContainer, aggregation, tf, client } from '..'
55

66
/**
7-
* Aggregator that computes the mean of the weights received from the nodes.
7+
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
88
*/
99
export class MeanAggregator extends Aggregator<WeightsContainer> {
10+
/**
11+
* The threshold t to fulfill to trigger an aggregation step. It can either be:
12+
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
13+
* - absolute: t > 1, thus requiring t contributions
14+
*/
1015
public readonly threshold: number
1116

1217
constructor (
@@ -17,17 +22,24 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
1722
) {
1823
super(task, model, roundCutoff, 1)
1924

25+
// Default threshold is 100% of node participation
2026
if (threshold === undefined) {
2127
this.threshold = 1
28+
// Threshold must be positive
2229
} else if (threshold <= 0) {
2330
throw new Error('threshold must be positive')
31+
// Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
2432
} else if (threshold > 1 && Math.round(threshold) !== threshold) {
2533
throw new Error('absolute thresholds must integers')
2634
} else {
2735
this.threshold = threshold
2836
}
2937
}
3038

39+
/**
40+
* Checks whether the contributions buffer is full, according to the set threshold.
41+
* @returns Whether the contributions buffer is full
42+
*/
3143
isFull (): boolean {
3244
if (this.threshold <= 1) {
3345
const contribs = this.contributions.get(this.communicationRound)

discojs/discojs-core/src/aggregator/secure.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ import * as crypto from 'crypto'
66
import { Map, List, Range } from 'immutable'
77

88
/**
9-
* Received contributions are the nodes' partial sums. The payloads are our random additive shares.
9+
* Aggregator implementing secure multi-party computation for decentralized learning.
10+
* An aggregation is made of two communication rounds:
11+
* - first, nodes communicate their random shares to each other;
12+
* - then, they sum their received shares and communicate the result.
13+
* Finally, nodes are able to average the received partial sums to establish the aggregation result.
1014
*/
1115
export class SecureAggregator extends Aggregator<WeightsContainer> {
1216
public static readonly MAX_SEED: number = 2 ** 47
@@ -71,7 +75,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
7175
}
7276

7377
/**
74-
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers
78+
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
7579
*/
7680
public generateAllShares (secret: WeightsContainer): List<WeightsContainer> {
7781
if (this.nodes.size === 0) {
@@ -86,16 +90,12 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
8690
}
8791

8892
/**
89-
* Generates one share in the same shape as the secret that is populated with values randomly chosend from
93+
* Generates one share in the same shape as the secret that is populated with values randomly chosen from
9094
* a uniform distribution between (-maxShareValue, maxShareValue).
9195
*/
9296
public generateRandomShare (secret: WeightsContainer): WeightsContainer {
9397
const seed = crypto.randomInt(SecureAggregator.MAX_SEED)
9498
return secret.map((t) =>
9599
tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', seed))
96100
}
97-
98-
get communicationRound (): number {
99-
return this._communicationRound
100-
}
101101
}

discojs/discojs-core/src/client/base.ts

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,53 @@ import { NodeID } from './types'
66
import { EventConnection } from './event_connection'
77
import { Aggregator } from '../aggregator'
88

9+
/**
10+
* Main, abstract, class representing a Disco client in a network, which handles
11+
* communication with other nodes, be it peers or a server.
12+
*/
913
export abstract class Base {
14+
/**
15+
* Own ID provided by the network's server.
16+
*/
1017
protected _ownId?: NodeID
18+
/**
19+
* The network's server.
20+
*/
1121
protected _server?: EventConnection
22+
/**
23+
* The aggregator's result produced after aggregation.
24+
*/
1225
protected aggregationResult?: Promise<WeightsContainer>
1326

1427
constructor (
28+
/**
29+
* The network server's URL to connect to.
30+
*/
1531
public readonly url: URL,
32+
/**
33+
* The client's corresponding task.
34+
*/
1635
public readonly task: Task,
36+
/**
37+
* The client's aggregator.
38+
*/
1739
public readonly aggregator: Aggregator
1840
) {}
1941

2042
/**
21-
* Handles the connection process from the client to any sort of
22-
* centralized server.
43+
* Handles the connection process from the client to any sort of network server.
2344
*/
2445
async connect (): Promise<void> {}
2546

2647
/**
27-
* Handles the disconnection process of the client from any sort
28-
* of centralized server.
48+
* Handles the disconnection process of the client from any sort of network server.
2949
*/
3050
async disconnect (): Promise<void> {}
3151

52+
/**
53+
* Fetches the latest model available on the network's server, for the adequate task.
54+
* @returns The latest model
55+
*/
3256
async getLatestModel (): Promise<tf.LayersModel> {
3357
const url = new URL('', this.url.href)
3458
if (!url.pathname.endsWith('/')) {
@@ -41,29 +65,43 @@ export abstract class Base {
4165
return await serialization.model.decode(response.data)
4266
}
4367

68+
/**
69+
* Communication callback called once at the beginning of the training instance.
70+
* @param weights The initial model weights
71+
* @param trainingInformant The training informant
72+
*/
4473
async onTrainBeginCommunication (
4574
weights: WeightsContainer,
4675
trainingInformant: TrainingInformant
4776
): Promise<void> {}
4877

4978
/**
50-
* The training manager matches this function with the training loop's
51-
* onTrainEnd callback when training a TFJS model object. See the
52-
* training manager for more details.
79+
* Communication callback called once at the end of the training instance.
80+
* @param weights The final model weights
81+
* @param trainingInformant The training informant
5382
*/
5483
async onTrainEndCommunication (
5584
weights: WeightsContainer,
5685
trainingInformant: TrainingInformant
5786
): Promise<void> {}
5887

88+
/**
89+
* Communication callback called at the beginning of every training round.
90+
* @param weights The most recent local weight updates
91+
* @param round The current training round
92+
* @param trainingInformant The training informant
93+
*/
5994
async onRoundBeginCommunication (
6095
weights: WeightsContainer,
6196
round: number,
6297
trainingInformant: TrainingInformant
6398
): Promise<void> {}
6499

65100
/**
66-
* This function will be called whenever a local round has ended.
101+
* Communication callback called the end of every training round.
102+
* @param weights The most recent local weight updates
103+
* @param round The current training round
104+
* @param trainingInformant The training informant
67105
*/
68106
async onRoundEndCommunication (
69107
weights: WeightsContainer,

0 commit comments

Comments
 (0)