diff --git a/.github/workflows/lint-test-build.yml b/.github/workflows/lint-test-build.yml index fbca70860..5228ed62c 100644 --- a/.github/workflows/lint-test-build.yml +++ b/.github/workflows/lint-test-build.yml @@ -8,16 +8,18 @@ env: node_version: 16 jobs: - download-training-data: + download-datasets: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data - - run: ./get_training_data.sh - working-directory: ./ + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} + - run: ./datasets/populate lint-lib-core: needs: [build-lib-core, build-lib-node] @@ -213,19 +215,17 @@ jobs: test-lib-core: needs: - [ - build-lib-core, - build-lib-node, - build-server-docker, - download-training-data, - ] + [build-lib-core, build-lib-node, build-server-docker, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -235,14 +235,17 @@ jobs: - run: ./with_server npm --workspace=./discojs/discojs-core test test-lib-node: - needs: [build-lib-core, build-server-docker, download-training-data] + needs: [build-lib-core, build-server-docker, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -252,14 +255,17 @@ jobs: - run: ./with_server npm --workspace=./discojs/discojs-node test test-lib-web: - needs: [build-lib-core, build-server-docker, download-training-data] + needs: [build-lib-core, 
build-server-docker, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -269,14 +275,17 @@ jobs: - run: ./with_server npm --workspace=./discojs/discojs-web test test-server: - needs: [build-lib-core, build-lib-node, download-training-data] + needs: [build-lib-core, build-lib-node, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -286,14 +295,17 @@ jobs: - run: npm --workspace=./server test test-web-client: - needs: [build-lib-core, build-lib-web, download-training-data] + needs: [build-lib-core, build-lib-web, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -310,15 +322,17 @@ jobs: config: baseUrl=http://localhost:8081/#/ test-cli: - needs: - [build-lib-core, build-lib-node, build-server, download-training-data] + needs: [build-lib-core, build-lib-node, build-server, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} @@ -328,15 
+342,17 @@ jobs: - run: npm --workspace=./cli start -- -t cifar10 -u 1 -e 1 test-docs-examples: - needs: - [build-lib-core, build-lib-node, build-server, download-training-data] + needs: [build-lib-core, build-lib-node, build-server, download-datasets] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + lfs: true + submodules: true - uses: actions/cache@v3 with: - path: example_training_data - key: training_data + path: datasets + key: datasets-${{ hashFiles('datasets/**') }} - uses: actions/setup-node@v3 with: node-version: ${{ env.node_version }} diff --git a/.gitignore b/.gitignore index af99da910..1dd81c891 100644 --- a/.gitignore +++ b/.gitignore @@ -1,150 +1,11 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class +# dependencies +/node_modules/ -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ +# built dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
-# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - - -UI/public/.DS_Store -UI/.DS_Store - -*.DS_Store - -node_modules/ - -# model files on server -weights.bin -model.json - -# example training data -example_training_data/ -example_training_data.tar.gz - -# IDE files +# system specifics files .metals/ .idea/ .vscode/ +*.DS_Store diff --git a/DEV.md b/DEV.md index 657031a69..cf02d2b11 100644 --- a/DEV.md +++ b/DEV.md @@ -101,7 +101,7 @@ npm -ws run build **6.** Download and extract the sample training datasets. These datasets are used in the automated tests. 
``` -./get_training_data.sh +./datasets/populate ``` **7.** Launch DISCO diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 56e606c39..5ffc5c22f 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -8,16 +8,8 @@ import { saveLog } from './utils' import { getTaskData } from './data' import { args } from './args' -const NUMBER_OF_USERS = args.numberOfUsers -const TASK = args.task - -const infoText = `\nStarted federated training of ${TASK.id}` -console.log(infoText) - -console.log({ args }) - async function runUser (task: Task, url: URL, data: data.DataSplit): Promise { - const client = new clients.federated.FederatedClient(url, task, new aggregators.MeanAggregator(TASK)) + const client = new clients.federated.FederatedClient(url, task, new aggregators.MeanAggregator()) // force the federated scheme const scheme = TrainingSchemes.FEDERATED @@ -28,17 +20,20 @@ async function runUser (task: Task, url: URL, data: data.DataSplit): Promise { +async function main (task: Task, numberOfUsers: number): Promise { + console.log(`Started federated training of ${task.id}`) + console.log({ args }) + const [server, url] = await startServer() - const data = await getTaskData(TASK) + const data = await getTaskData(task) const logs = await Promise.all( - Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, url, data)).toArray() + Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray() ) if (args.save) { - const fileName = `${TASK.id}_${NUMBER_OF_USERS}users.csv` + const fileName = `${task.id}_${numberOfUsers}users.csv` saveLog(logs, fileName) } console.log('Shutting down the server...') @@ -48,4 +43,4 @@ async function main (): Promise { }) } -main().catch(console.error) +main(args.task, args.numberOfUsers).catch(console.error) diff --git a/cli/src/data.ts b/cli/src/data.ts index a2b2b2799..58013f4da 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -12,7 +12,7 @@ function filesFromFolder (dir: string, folder: string, fractionToKeep: number): } 
async function simplefaceData (task: Task): Promise { - const dir = '../example_training_data/simple_face/' + const dir = '../datasets/simple_face/' const youngFolders = ['child'] const oldFolders = ['adult'] @@ -39,7 +39,7 @@ async function simplefaceData (task: Task): Promise { } async function cifar10Data (cifar10: Task): Promise { - const dir = '../example_training_data/CIFAR10/' + const dir = '../datasets/CIFAR10/' const files = (await fs_promises.readdir(dir)).map((file) => path.join(dir, file)) const labels = Range(0, 24).map((label) => (label % 10).toString()).toArray() @@ -47,7 +47,7 @@ async function cifar10Data (cifar10: Task): Promise { } async function titanicData (titanic: Task): Promise { - const dir = '../example_training_data/titanic_train.csv' + const dir = '../datasets/titanic_train.csv' const data = await (new NodeTabularLoader(titanic, ',').loadAll( ['file://'.concat(dir)], diff --git a/datasets/.gitignore b/datasets/.gitignore new file mode 100644 index 000000000..98a59ef09 --- /dev/null +++ b/datasets/.gitignore @@ -0,0 +1,16 @@ +# example_training_data.tar.gz +/2_QAID_1.masked.reshaped.squared.224.png +/9-mnist-example.png +/CIFAR10/ +/cifar10-agents +/cifar10-example.png +/cifar10-labels.csv +/simple_face +/simple_face-example.png +/titanic_test.csv +/titanic_train.csv +/titanic_wrong_number_columns.csv +/titanic_wrong_passengerID.csv + +# wikitext +/wikitext/ diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 000000000..e6d2b724d --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,5 @@ +# Raw datasets + +This directory contains a selection of raw datasets. + +Run `./populate` to get all. 
diff --git a/datasets/populate b/datasets/populate new file mode 100755 index 000000000..49523e2cf --- /dev/null +++ b/datasets/populate @@ -0,0 +1,15 @@ +#!/bin/sh -eu + +# TODO replace by git submodules + +cd "$(dirname "$0")" + +# base +curl 'http://deai-313515.appspot.com.storage.googleapis.com/example_training_data.tar.gz' | + tar --extract --strip-components=1 + +# wikitext +mkdir -p wikitext +cd wikitext +curl 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz' | + tar --extract --gzip --strip-components=1 diff --git a/discojs/discojs-core/src/aggregator/base.ts b/discojs/discojs-core/src/aggregator/base.ts index db5b2e93b..2e9208e91 100644 --- a/discojs/discojs-core/src/aggregator/base.ts +++ b/discojs/discojs-core/src/aggregator/base.ts @@ -1,6 +1,6 @@ import { Map, Set } from 'immutable' -import type { client, Model, Task, AsyncInformant } from '..' +import type { client, Model, AsyncInformant } from '..' import { EventEmitter } from '../utils/event_emitter' @@ -54,10 +54,6 @@ export abstract class Base { protected _communicationRound = 0 constructor ( - /** - * The task for which the aggregator should be created. - */ - public readonly task: Task, /** * The Model whose weights are updated on aggregation. 
*/ diff --git a/discojs/discojs-core/src/aggregator/get.ts b/discojs/discojs-core/src/aggregator/get.ts index a191ce828..739ac18e7 100644 --- a/discojs/discojs-core/src/aggregator/get.ts +++ b/discojs/discojs-core/src/aggregator/get.ts @@ -19,7 +19,7 @@ export function getAggregator (task: Task): aggregator.Aggregator { const error = new Error('not implemented') switch (task.trainingInformation.aggregator) { case AggregatorChoice.MEAN: - return new aggregator.MeanAggregator(task) + return new aggregator.MeanAggregator() case AggregatorChoice.ROBUST: throw error case AggregatorChoice.BANDIT: @@ -28,8 +28,8 @@ export function getAggregator (task: Task): aggregator.Aggregator { if (task.trainingInformation.scheme !== 'decentralized') { throw new Error('secure aggregation is currently supported for decentralized only') } - return new aggregator.SecureAggregator(task) + return new aggregator.SecureAggregator() default: - return new aggregator.MeanAggregator(task) + return new aggregator.MeanAggregator() } } diff --git a/discojs/discojs-core/src/aggregator/index.ts b/discojs/discojs-core/src/aggregator/index.ts index b522f26fa..3b2035afe 100644 --- a/discojs/discojs-core/src/aggregator/index.ts +++ b/discojs/discojs-core/src/aggregator/index.ts @@ -1,5 +1,5 @@ -import { type WeightsContainer } from '../weights' -import { type Base } from './base' +import type { WeightsContainer } from '../weights' +import type { Base } from './base' export { Base as AggregatorBase, AggregationStep } from './base' export { MeanAggregator } from './mean' diff --git a/discojs/discojs-core/src/aggregator/mean.spec.ts b/discojs/discojs-core/src/aggregator/mean.spec.ts index ffd00afb5..31905b233 100644 --- a/discojs/discojs-core/src/aggregator/mean.spec.ts +++ b/discojs/discojs-core/src/aggregator/mean.spec.ts @@ -1,11 +1,10 @@ import { assert, expect } from 'chai' import type { Map } from 'immutable' -import type { client, Model, Task } from '..' +import type { client, Model } from '..' 
import { aggregator, defaultTasks } from '..' import { AggregationStep } from './base' -const task = defaultTasks.titanic.getTask() const model = defaultTasks.titanic.getModel() const id = 'a' const weights = [1, 2, 3] @@ -14,12 +13,11 @@ const bufferCapacity = weights.length export class MockMeanAggregator extends aggregator.AggregatorBase { constructor ( - task: Task, model: Model, private readonly threshold: number, roundCutoff = 0 ) { - super(task, model, roundCutoff, 1) + super(model, roundCutoff, 1) } isFull (): boolean { @@ -56,36 +54,36 @@ export class MockMeanAggregator extends aggregator.AggregatorBase { describe('mean aggregator tests', () => { it('adding weight update with old time stamp returns false', async () => { const t0 = -1 - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity) + const aggregator = new MockMeanAggregator(await model, bufferCapacity) assert.isFalse(aggregator.add(id, weights[0], t0)) }) it('adding weight update with recent time stamp returns true', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity) + const aggregator = new MockMeanAggregator(await model, bufferCapacity) const t0 = Date.now() assert.isTrue(aggregator.add(id, weights[0], t0)) }) it('aggregator returns false if it is not full', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity) + const aggregator = new MockMeanAggregator(await model, bufferCapacity) assert.isFalse(aggregator.isFull()) }) it('aggregator with standard cutoff = 0', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity) + const aggregator = new MockMeanAggregator(await model, bufferCapacity) assert.isTrue(aggregator.isWithinRoundCutoff(0)) assert.isFalse(aggregator.isWithinRoundCutoff(-1)) }) it('aggregator with different cutoff = 1', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity, 1) + const aggregator = new 
MockMeanAggregator(await model, bufferCapacity, 1) assert.isTrue(aggregator.isWithinRoundCutoff(0)) assert.isTrue(aggregator.isWithinRoundCutoff(-1)) assert.isFalse(aggregator.isWithinRoundCutoff(-2)) }) it('adding enough updates to buffer launches aggregator and updates weights', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity) + const aggregator = new MockMeanAggregator(await model, bufferCapacity) const mockAggregatedWeights = 2 const result = aggregator.receiveResult() @@ -98,7 +96,7 @@ describe('mean aggregator tests', () => { }) it('testing two full cycles (adding x2 buffer capacity)', async () => { - const aggregator = new MockMeanAggregator(task, await model, bufferCapacity, 0) + const aggregator = new MockMeanAggregator(await model, bufferCapacity, 0) let mockAggregatedWeights = 2 let result = aggregator.receiveResult() diff --git a/discojs/discojs-core/src/aggregator/mean.ts b/discojs/discojs-core/src/aggregator/mean.ts index 5d94baede..797759259 100644 --- a/discojs/discojs-core/src/aggregator/mean.ts +++ b/discojs/discojs-core/src/aggregator/mean.ts @@ -1,7 +1,7 @@ import type { Map } from 'immutable' import { AggregationStep, Base as Aggregator } from './base' -import type { Model, Task, WeightsContainer, client } from '..' +import type { Model, WeightsContainer, client } from '..' import { aggregation } from '..' 
/** @@ -16,12 +16,11 @@ export class MeanAggregator extends Aggregator { public readonly threshold: number constructor ( - task: Task, model?: Model, roundCutoff = 0, threshold = 1 ) { - super(task, model, roundCutoff, 1) + super(model, roundCutoff, 1) // Default threshold is 100% of node participation if (threshold === undefined) { diff --git a/discojs/discojs-core/src/aggregator/robust.ts b/discojs/discojs-core/src/aggregator/robust.ts index a5e4b26c4..a68e6f436 100644 --- a/discojs/discojs-core/src/aggregator/robust.ts +++ b/discojs/discojs-core/src/aggregator/robust.ts @@ -1,7 +1,7 @@ import { Base as Aggregator } from './base' -import { type client, type WeightsContainer } from '..' +import type { client, Model, WeightsContainer } from '..' -import { type Map } from 'immutable' +import type { Map } from 'immutable' export type Momentum = WeightsContainer @@ -11,6 +11,15 @@ export class RobustAggregator extends Aggregator { // TODO @s314y: move to task definition private readonly beta = 1 + constructor ( + private readonly tauPercentile: number, + model?: Model, + roundCutoff?: number, + communicationRounds?: number + ) { + super(model, roundCutoff, communicationRounds) + } + add (nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean { if (this.isWithinRoundCutoff(round)) { const stale = this.contributions.get(communicationRound) @@ -27,14 +36,7 @@ export class RobustAggregator extends Aggregator { } aggregate (): void { - if (this.task.trainingInformation.tauPercentile === undefined) { - throw new Error('task doesn\'t provide tau percentile') - } - // this.emit(aggregation.avgClippingWeights( - // this.contributions.values(), - // undefined as unknown as WeightsContainer, - // this.task.trainingInformation.tauPercentile - // )) + throw new Error('not implemented') } makePayloads (weights: WeightsContainer): Map { diff --git a/discojs/discojs-core/src/aggregator/secure.spec.ts 
b/discojs/discojs-core/src/aggregator/secure.spec.ts index 0f5f33797..0437b3794 100644 --- a/discojs/discojs-core/src/aggregator/secure.spec.ts +++ b/discojs/discojs-core/src/aggregator/secure.spec.ts @@ -1,11 +1,10 @@ import { List, Set, Range } from 'immutable' import { assert } from 'chai' -import { aggregator as aggregators, aggregation, WeightsContainer, defaultTasks } from '@epfml/discojs-core' +import { aggregator as aggregators, aggregation, WeightsContainer } from '@epfml/discojs-core' describe('secret shares test', function () { const epsilon = 1e-4 - const task = defaultTasks.cifar10.getTask() const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]) const secrets = List.of( @@ -17,7 +16,7 @@ describe('secret shares test', function () { function buildShares (): List> { const nodes = Set(secrets.keys()).map(String) return secrets.map((secret) => { - const aggregator = new aggregators.SecureAggregator(task) + const aggregator = new aggregators.SecureAggregator() aggregator.setNodes(nodes) return aggregator.generateAllShares(secret) }) diff --git a/discojs/discojs-core/src/aggregator/secure.ts b/discojs/discojs-core/src/aggregator/secure.ts index 5b8558580..7eb286a52 100644 --- a/discojs/discojs-core/src/aggregator/secure.ts +++ b/discojs/discojs-core/src/aggregator/secure.ts @@ -3,7 +3,7 @@ import { Map, List, Range } from 'immutable' import tf from '@tensorflow/tfjs' import { AggregationStep, Base as Aggregator } from './base' -import type { Model, Task, WeightsContainer, client } from '..' +import type { Model, WeightsContainer, client } from '..' import { aggregation } from '..' /** @@ -16,15 +16,11 @@ import { aggregation } from '..' 
export class SecureAggregator extends Aggregator { public static readonly MAX_SEED: number = 2 ** 47 - private readonly maxShareValue: number - constructor ( - task: Task, - model?: Model + model?: Model, + private readonly maxShareValue = 100 ) { - super(task, model, 0, 2) - - this.maxShareValue = this.task.trainingInformation.maxShareValue ?? 100 + super(model, 0, 2) } aggregate (): void { diff --git a/discojs/discojs-core/src/async_informant.ts b/discojs/discojs-core/src/async_informant.ts index c8902960f..530fc7d60 100644 --- a/discojs/discojs-core/src/async_informant.ts +++ b/discojs/discojs-core/src/async_informant.ts @@ -1,4 +1,4 @@ -import { type AggregatorBase } from './aggregator' +import type { AggregatorBase } from './aggregator' export class AsyncInformant { private _round = 0 @@ -11,8 +11,6 @@ export class AsyncInformant { ) {} update (): void { - console.debug('before:') - this.printAllInfos() if (this.round === 0 || this.round < this.aggregator.round) { this._round = this.aggregator.round this._currentNumberOfParticipants = this.aggregator.size @@ -21,8 +19,6 @@ export class AsyncInformant { } else { this._round = this.aggregator.round } - console.debug('after:') - this.printAllInfos() } // Getter functions @@ -52,13 +48,4 @@ export class AsyncInformant { averageNumberOfParticipants: this.averageNumberOfParticipants } } - - // Debug - public printAllInfos (): void { - console.debug('task:', this.aggregator.task.id) - console.debug('round:', this.round) - console.debug('participants:', this.currentNumberOfParticipants) - console.debug('total:', this.totalNumberOfParticipants) - console.debug('average:', this.averageNumberOfParticipants) - } } diff --git a/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts b/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts index 84a50ec24..3a87ea8b3 100644 --- a/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts +++ b/discojs/discojs-core/src/dataset/data/tabular_data.spec.ts @@ -51,7 +51,7 
@@ describe('tabular data checks', () => { it('throw an error on incorrectly formatted data', async () => { try { - await TabularData.init(tf.data.csv('file://../../example_training_data/cifar10-labels.csv', csvConfig), titanicMock, 3) + await TabularData.init(tf.data.csv('file://../../datasets/cifar10-labels.csv', csvConfig), titanicMock, 3) } catch (e) { expect(e).to.be.an.instanceOf(Error) return @@ -61,6 +61,6 @@ describe('tabular data checks', () => { }) it('do nothing on correctly formatted data', async () => { - await TabularData.init(tf.data.csv('file://../../example_training_data/titanic_train.csv', csvConfig), titanicMock, 3) + await TabularData.init(tf.data.csv('file://../../datasets/titanic_train.csv', csvConfig), titanicMock, 3) }) }) diff --git a/discojs/discojs-core/src/dataset/data_loader/data_loader.ts b/discojs/discojs-core/src/dataset/data_loader/data_loader.ts index 18f9bcc87..a4377d1e8 100644 --- a/discojs/discojs-core/src/dataset/data_loader/data_loader.ts +++ b/discojs/discojs-core/src/dataset/data_loader/data_loader.ts @@ -1,15 +1,8 @@ -import { type Task } from '../..' -import { type Dataset } from '../dataset' -import { type Data, type DataSplit } from '../data' +import type { DataSplit, Dataset } from '..' 
export interface DataConfig { features?: string[], labels?: string[], shuffle?: boolean, validationSplit?: number, inference?: boolean } export abstract class DataLoader { - constructor (protected task: Task) {} - - abstract createData (dataset: Dataset, size?: number): Promise - abstract load (source: Source, config: DataConfig): Promise - abstract loadAll (sources: Source[], config: DataConfig): Promise } diff --git a/discojs/discojs-core/src/dataset/data_loader/image_loader.ts b/discojs/discojs-core/src/dataset/data_loader/image_loader.ts index aa3384c1c..a42f5b0d4 100644 --- a/discojs/discojs-core/src/dataset/data_loader/image_loader.ts +++ b/discojs/discojs-core/src/dataset/data_loader/image_loader.ts @@ -1,8 +1,9 @@ import { Range } from 'immutable' import tf from '@tensorflow/tfjs' -import type { Dataset } from '../dataset' -import type { Data, DataSplit } from '../data' +import type { Task } from '../..' + +import type { Data, Dataset, DataSplit } from '..' import { ImageData } from '../data' import type { DataConfig } from '../data_loader' import { DataLoader } from '../data_loader' @@ -17,8 +18,10 @@ import { DataLoader } from '../data_loader' export abstract class ImageLoader extends DataLoader { abstract readImageFrom (source: Source): Promise - async createData (dataset: Dataset, size?: number): Promise { - return await ImageData.init(dataset, this.task, size) + constructor ( + private readonly task: Task + ) { + super() } async load (image: Source, config?: DataConfig): Promise { @@ -55,7 +58,7 @@ export abstract class ImageLoader extends DataLoader { // @ts-expect-error: For some reasons typescript refuses async generator but tensorflow do work with them const dataset: tf.data.Dataset = tf.data.generator(dataGenerator) - return await this.createData(dataset, indices.length) + return await ImageData.init(dataset, this.task, indices.length) } async loadAll (images: Source[], config?: DataConfig): Promise { diff --git 
a/discojs/discojs-core/src/dataset/data_loader/index.ts b/discojs/discojs-core/src/dataset/data_loader/index.ts index c5a886e44..39c0aa34d 100644 --- a/discojs/discojs-core/src/dataset/data_loader/index.ts +++ b/discojs/discojs-core/src/dataset/data_loader/index.ts @@ -1,4 +1,6 @@ -export { type DataConfig, DataLoader } from './data_loader' +export type { DataConfig } from './data_loader' +export { DataLoader } from './data_loader' + export { ImageLoader } from './image_loader' export { TabularLoader } from './tabular_loader' export { TextLoader } from './text_loader' diff --git a/discojs/discojs-core/src/dataset/data_loader/tabular_loader.ts b/discojs/discojs-core/src/dataset/data_loader/tabular_loader.ts index d3d738e54..41e912746 100644 --- a/discojs/discojs-core/src/dataset/data_loader/tabular_loader.ts +++ b/discojs/discojs-core/src/dataset/data_loader/tabular_loader.ts @@ -1,9 +1,11 @@ import { List, Map, Set } from 'immutable' -import { type Task } from '../..' -import { type Dataset } from '../dataset' -import { TabularData, type Data, type DataSplit } from '../data' -import { DataLoader, type DataConfig } from '../data_loader' +import type { Task } from '../..' + +import type { Dataset, DataSplit } from '..' +import { TabularData } from '..' +import type { DataConfig } from '../data_loader' +import { DataLoader } from '../data_loader' // Window size from which the dataset shuffling will sample const BUFFER_SIZE = 1000 @@ -14,8 +16,11 @@ const BUFFER_SIZE = 1000 * character-separated features and label(s). Such files typically have the .csv extension. 
*/ export abstract class TabularLoader extends DataLoader { - constructor (task: Task, public readonly delimiter = ',') { - super(task) + constructor ( + private readonly task: Task, + public readonly delimiter = ',' + ) { + super() } /** @@ -26,12 +31,6 @@ export abstract class TabularLoader extends DataLoader { */ abstract loadDatasetFrom (source: Source, csvConfig: Record): Promise - async createData (dataset: Dataset): Promise { - // dataset.size does not work for csv datasets - // https://github.com/tensorflow/tfjs/issues/5845 - return await TabularData.init(dataset, this.task) - } - /** * Expects delimiter-separated tabular data made of N columns. The data may be * potentially split among several sources. Every source should contain N-1 @@ -90,7 +89,7 @@ export abstract class TabularLoader extends DataLoader { await this.load(source, { ...config, shuffle: false }))) let dataset = List(datasets).reduce((acc: Dataset, dataset) => acc.concatenate(dataset)) dataset = config?.shuffle === true ? dataset.shuffle(BUFFER_SIZE) : dataset - const data = await this.createData(dataset) + const data = await TabularData.init(dataset, this.task) // TODO: Implement validation split for tabular data (tricky due to streaming) return { train: data diff --git a/discojs/discojs-core/src/dataset/data_loader/text_loader.ts b/discojs/discojs-core/src/dataset/data_loader/text_loader.ts index 3c6ccc727..d1bbd633e 100644 --- a/discojs/discojs-core/src/dataset/data_loader/text_loader.ts +++ b/discojs/discojs-core/src/dataset/data_loader/text_loader.ts @@ -1,16 +1,34 @@ -import { TabularLoader } from './tabular_loader' -import { type Dataset } from '../dataset' -import { TextData, type Data } from '../data' +import type { Task } from '../..' + +import type { DataSplit, Dataset } from '..' +import { TextData } from '..' + +import { DataLoader } from '.' 
/** * Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely, - * @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and each consist of - * a sentence-like sample associated to an optional label. + * @epfml/discojs-web and @epfml/discojs-node. */ -export abstract class TextLoader extends TabularLoader { - abstract loadDatasetFrom (source: Source, config: Record): Promise +export abstract class TextLoader extends DataLoader { + constructor ( + private readonly task: Task + ) { + super() + } + + abstract loadDatasetFrom (source: S): Promise + + async load (source: S): Promise { + return await this.loadDatasetFrom(source) + } + + async loadAll (sources: S[]): Promise { + const concatenated = + (await Promise.all(sources.map(async (src) => await this.load(src)))) + .reduce((acc, dataset) => acc.concatenate(dataset)) - async createData (dataset: Dataset): Promise { - return await TextData.init(dataset, this.task) + return { + train: await TextData.init(concatenated, this.task) + } } } diff --git a/discojs/discojs-core/src/default_tasks/index.ts b/discojs/discojs-core/src/default_tasks/index.ts index c9f113752..15730a99e 100644 --- a/discojs/discojs-core/src/default_tasks/index.ts +++ b/discojs/discojs-core/src/default_tasks/index.ts @@ -1,7 +1,8 @@ export { cifar10 } from './cifar10' +export { geotags } from './geotags' export { lusCovid } from './lus_covid' export { mnist } from './mnist' -export { titanic } from './titanic' export { simpleFace } from './simple_face' -export { geotags } from './geotags' export { skinMnist } from './skin_mnist' +export { titanic } from './titanic' +export { wikitext } from './wikitext' diff --git a/discojs/discojs-core/src/default_tasks/wikitext.ts b/discojs/discojs-core/src/default_tasks/wikitext.ts new file mode 100644 index 000000000..691a7d177 --- /dev/null +++ b/discojs/discojs-core/src/default_tasks/wikitext.ts @@ -0,0 
+1,40 @@ +import type { Model, Task, TaskProvider } from '..' +import { TrainingSchemes, models } from '..' + +export const wikitext: TaskProvider = { + getTask (): Task { + return { + id: 'wikitext-103', + displayInformation: { + taskTitle: 'Language modelling on wikitext', + summary: { + preview: 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.', + overview: 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.' + }, + limitations: 'The dataset may contain noise, inconsistencies, and unstructured content due to its raw nature, potentially posing challenges for certain NLP tasks.', + tradeoffs: 'The raw format may lack structured annotations and may require additional preprocessing for specific applications.', + dataFormatInformation: 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.', + dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."' + }, + trainingInformation: { + dataType: 'text', + modelID: 'wikitext-103-raw-model', + validationSplit: 0.2, // TODO: is this used somewhere? 
because train, eval and test are already split in dataset + epochs: 10, + // constructing a batch is taken care automatically in the dataset to make things faster + // so we fake a batch size of 1 + batchSize: 1, + scheme: TrainingSchemes.FEDERATED, + noiseScale: undefined, + decentralizedSecure: true, + minimumReadyPeers: 3, + maxShareValue: 100, + roundDuration: 10 + } + } + }, + + async getModel (): Promise { + return new models.GPT() + } +} diff --git a/discojs/discojs-core/src/informant/training_informant/base.ts b/discojs/discojs-core/src/informant/training_informant/base.ts index fa0cc9e25..b4e2a578d 100644 --- a/discojs/discojs-core/src/informant/training_informant/base.ts +++ b/discojs/discojs-core/src/informant/training_informant/base.ts @@ -11,6 +11,8 @@ export abstract class Base { protected readonly trainingGraphInformant = new GraphInformant() protected readonly validationGraphInformant = new GraphInformant() + private _losses = List() + // statistics protected currentRound = 0 protected currentNumberOfParticipants = 0 @@ -71,6 +73,20 @@ export abstract class Base { return this.validationGraphInformant.accuracy() } + set loss (loss: number | undefined) { + if (loss === undefined) throw new Error('loss is undefined') + this._losses = this._losses.push(loss) + } + + get loss (): number | undefined { + return this._losses.last() + } + + /** return loss of each round */ + get losses (): List { + return this._losses + } + trainingAccuracyData (): List { return this.trainingGraphInformant.data() } diff --git a/discojs/discojs-core/src/models/gpt/LICENSE.md b/discojs/discojs-core/src/models/gpt/LICENSE.md new file mode 100644 index 000000000..893968696 --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/LICENSE.md @@ -0,0 +1,23 @@ +MIT License + +Copyright (c) 2023 Nathan Maire +Copyright (c) 2023 lukemovement +Copyright (c) 2023 Anton Zemlyansky + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and 
associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/discojs/discojs-core/src/models/gpt/config.ts b/discojs/discojs-core/src/models/gpt/config.ts new file mode 100644 index 000000000..b8e879dfb --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/config.ts @@ -0,0 +1,77 @@ +type ModelType = + | 'gpt2' + | 'gpt2-medium' + | 'gpt2-large' + | 'gpt2-xl' + | 'gpt-mini' + | 'gpt-micro' + | 'gpt-nano' + +interface ModelSize { + nLayer?: number + nHead?: number + nEmbd?: number +} + +export interface GPTConfig { + lr: number + batchSize: number + blockSize: number + vocabSize: number + evaluate?: boolean + maxEvalBatches?: number + evaluateEvery?: number + epochs?: number + maxIter?: number + weightDecay?: number + verbose?: 0 | 1 + bias?: boolean + debug?: boolean + dropout?: number + residDrop?: number + embdDrop?: number + tokEmb?: boolean + lmHead?: boolean + modelType: ModelType +} + +export const DEFAULT_CONFIG: Required = { + lr: 0.001, + weightDecay: 0, + batchSize: 2, + epochs: 9999, + maxIter: 10_000, + verbose: 0, + modelType: 'gpt-nano', + evaluate: true, + maxEvalBatches: 
12, + evaluateEvery: 100, + blockSize: 128, + vocabSize: 50258, + bias: true, + debug: false, + dropout: 0.2, + residDrop: 0.2, + embdDrop: 0.2, + tokEmb: true, + lmHead: true +} + +export function getModelSizes (modelType: ModelType): Required { + switch (modelType) { + case 'gpt2': + return { nLayer: 12, nHead: 12, nEmbd: 768 } + case 'gpt2-medium': + return { nLayer: 24, nHead: 16, nEmbd: 1024 } + case 'gpt2-large': + return { nLayer: 36, nHead: 20, nEmbd: 1280 } + case 'gpt2-xl': + return { nLayer: 48, nHead: 25, nEmbd: 1600 } + case 'gpt-mini': + return { nLayer: 6, nHead: 6, nEmbd: 192 } + case 'gpt-micro': + return { nLayer: 4, nHead: 4, nEmbd: 128 } + case 'gpt-nano': + return { nLayer: 3, nHead: 3, nEmbd: 48 } + } +} diff --git a/discojs/discojs-core/src/models/gpt/evaluate.ts b/discojs/discojs-core/src/models/gpt/evaluate.ts new file mode 100644 index 000000000..cd07653b7 --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/evaluate.ts @@ -0,0 +1,54 @@ +import tf from '@tensorflow/tfjs' + +export default async function evaluate ( + model: tf.LayersModel, + dataset: tf.data.Dataset<{ xs: tf.Tensor, ys: tf.Tensor }> +): Promise> { + let datasetSize = 0 + let totalLoss = 0 + const acc: [number, number] = [0, 0] + + await dataset.map(({ xs, ys }) => { + const logits = model.apply(xs) + if (Array.isArray(logits)) { + throw new Error('model outputed many tensor') + } + if (logits instanceof tf.SymbolicTensor) { + throw new Error('model outputed symbolic tensor') + } + xs.dispose() + + return { logits, ys } + }).mapAsync(async ({ logits, ys }) => { + const loss = (await tf.losses.softmaxCrossEntropy(ys, logits).array()) + if (typeof loss !== 'number') { + throw new Error('got multiple loss') + } + + const accTensor = tf.metrics.categoricalAccuracy(ys, logits) + const accSize = accTensor.shape.reduce((l, r) => l * r, 1) + const accSum = accTensor.sum() + const accSummed = await accSum.array() + if (typeof accSummed !== 'number') { + throw new Error('got 
multiple accuracy sum') + } + + tf.dispose([ys, logits, accTensor, accSum]) + + return { loss, accSummed, accSize } + }).forEachAsync(({ loss, accSummed, accSize }) => { + datasetSize += 1 + totalLoss += loss + acc[0] += accSummed + acc[1] += accSize + }) + + const loss = totalLoss / datasetSize + + return { + val_loss: loss, + val_perplexity: Math.exp(loss), + acc: acc[0] / acc[1], + val_acc: acc[0] / acc[1] + } +} diff --git a/discojs/discojs-core/src/models/gpt/index.ts b/discojs/discojs-core/src/models/gpt/index.ts new file mode 100644 index 000000000..39552fc8c --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/index.ts @@ -0,0 +1,144 @@ +/** + * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement + **/ + +import tf from '@tensorflow/tfjs' + +import { WeightsContainer } from '../..' +import type { Dataset } from '../../dataset' +import { Sink } from '../../utils/event_emitter' + +import type { EpochLogs, Prediction, Sample } from '../model' +import { Model } from '../model' + +import { GPTLMHeadModel } from './model' + +// TODO too big config +interface Config { + modelType: 'gpt-nano' + epochs: number // TODO mv to Task + maxIter: number + batchSize: number + blockSize: number + lr: number + vocabSize: number + maxEvalBatches: number +} + +export class GPT extends Model { + private readonly model: GPTLMHeadModel + + private static readonly batchSize = 4 + private static readonly blockSize = 128 + private static readonly vocabSize = 50258 + + constructor () { + super() + + // TODO sensible defaults? 
+ const config: Config = { + modelType: 'gpt-nano', + epochs: 1, + maxIter: 2, + batchSize: GPT.batchSize, + blockSize: GPT.blockSize, + lr: 0.001, + vocabSize: GPT.vocabSize, + maxEvalBatches: 1 + } + + this.model = new GPTLMHeadModel(config) + } + + override get weights (): WeightsContainer { + return new WeightsContainer(this.model.weights.map((w) => w.read())) + } + + override set weights (ws: WeightsContainer) { + this.model.setWeights(ws.weights) + } + + // takes a stream of two bytes followed by a token ID + private convertCharDataset (dataset: Dataset): tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> { + const batchSize = 4 + const sampleSize = GPT.blockSize + 1 + const chunkSize = sampleSize * batchSize * 2 + + function toUInt16 (low: number, high: number): number { + low &= 0xff + high &= 0xff + return (high << 8) | low + } + + // TODO add support for small last batch + return dataset.batch(chunkSize, false).mapAsync(async (chunk) => { + if (!(chunk instanceof tf.Tensor)) { + throw new Error('chunk is not a Tensor') + } + if (chunk.shape.length !== 2 || chunk.shape[1] !== 1) { + throw new Error('dataset is not a only char') + } + + const buffer = await chunk.buffer() + + const xs = tf.buffer([batchSize, GPT.blockSize], 'int32') + const ys = tf.buffer([batchSize, GPT.blockSize, GPT.vocabSize], 'int32') + + for (let i = 0; i < batchSize; i++) { + for (let j = 0; j < sampleSize; j++) { + const idx = (i * sampleSize + j) * 2 + const low = buffer.get(idx) + const high = buffer.get(idx + 1) + const token = toUInt16(low, high) + if (j < sampleSize - 1) xs.set(token, i, j) + if (j > 0) ys.set(1, i, j - 1, token) + } + } + + return { xs: xs.toTensor(), ys: ys.toTensor() } + }) + } + + override async * train ( + trainingData: Dataset, + validationData?: Dataset, + epochs = 1, + tracker = new Sink() + ): AsyncGenerator { + for (let i = 0; i < epochs; i++) { + let logs: tf.Logs | undefined + + await this.model.fitDataset( + 
this.convertCharDataset(trainingData), { + epochs: 1, + validationData: validationData !== undefined ? this.convertCharDataset(validationData) : validationData, + callbacks: { + onEpochEnd: (_, cur) => { logs = cur }, + onBatchBegin: () => { tracker.emit('batchBegin', undefined) }, + onBatchEnd: () => { tracker.emit('batchEnd', undefined) } + } + }) + + yield logs + } + } + + override async predict (input: Sample): Promise { + const ret = this.model.predict(input) + if (Array.isArray(ret)) { + throw new Error('prediction yield many Tensors but should have only returned one') + } + + return ret + } + + static deserialize (weights: WeightsContainer): Model { + const model = new GPT() + model.weights = weights + return model + } + + serialize (): WeightsContainer { + return this.weights + } +} diff --git a/discojs/discojs-core/src/models/gpt/model.ts b/discojs/discojs-core/src/models/gpt/model.ts new file mode 100644 index 000000000..f881c50c6 --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/model.ts @@ -0,0 +1,542 @@ +import tf, { LayersModel, layers, serialization } from '@tensorflow/tfjs' + +import type { GPTConfig } from './config' +import { getModelSizes, DEFAULT_CONFIG } from './config' +import { train } from './train' +import type { TrainingCallbacks } from './types' + +class Range extends layers.Layer { + static readonly className = 'Range' + + computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + return inputShape + } + + call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + return tf.tidy(() => { + if (Array.isArray(input)) { + // TODO support multitensor + input = input[0] + } + this.invokeCallHook(input, kwargs) + const T = input.shape[1] + if (T === undefined) throw new Error('unexpected shape') + return tf.reshape(tf.range(0, T, 1, 'int32'), [1, T]) + }) + } +} +serialization.registerClass(Range) + +class LogLayer extends layers.Layer { + static readonly className = 'LogLayer' + + 
computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + return inputShape + } + + call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + return input + }) + } +} +serialization.registerClass(LogLayer) + +class CausalSelfAttentionBase extends layers.Layer { + static readonly className = 'CausalSelfAttentionBase' + + private readonly blockSize: number + private readonly nHead: number + private readonly nEmbd: number + private readonly dropout: number + private readonly mask: tf.Tensor + + constructor ( + private readonly config: ConstructorParameters[0] & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number> + ) { + super(config) + + this.blockSize = config.blockSize + this.nEmbd = config.nEmbd + this.nHead = config.nHead + this.dropout = config.dropout + + this.mask = tf.linalg.bandPart(tf.ones([this.blockSize, this.blockSize]), -1, 0) + } + + computeOutputShape (): tf.Shape | tf.Shape[] { + // TODO doesn't take input shape in account + return [null, this.blockSize, this.nEmbd] + } + + getConfig (): serialization.ConfigDict { + const config = super.getConfig() + return Object.assign({}, config, this.config) + } + + call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + return tf.tidy(() => { + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + + let [q, k, v] = input.split(3, -1) as [tf.Tensor, tf.Tensor, tf.Tensor] + const [B, T, C] = k.shape + const splitHeads = (x: tf.Tensor): tf.Tensor4D => + x.reshape([B, T, this.nHead, C / this.nHead]).transpose([0, 2, 1, 3]) + q = splitHeads(q) + k = splitHeads(k) + v = splitHeads(v) + + let att = tf.mul( + tf.matMul(q, k, false, true), + tf.div( + 1, + tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32')) + ) + ) + att = tf.add(att, tf.mul(tf.sub(1, this.mask), -1e9)) + att = 
tf.softmax(att, -1) + att = kwargs.training === true ? tf.dropout(att, this.dropout) : att + + let y = tf.matMul(att, v) + y = tf.transpose(y, [0, 2, 1, 3]) + y = tf.reshape(y, [B, T, C]) + + return y + }) + } +} +serialization.registerClass(CausalSelfAttentionBase) + +type CausalSelfAttentionConfig = + ConstructorParameters[0] + & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number> + & { bias: boolean } + +class CausalSelfAttention extends layers.Layer { + static readonly className = 'CausalSelfAttention' + + private readonly nHead: number + private readonly nEmbd: number + private readonly dropout: number + private readonly bias: boolean + private readonly mask: tf.Tensor2D + + cAttnKernel?: tf.LayerVariable + cAttnBias?: tf.LayerVariable + cProjKernel?: tf.LayerVariable + cProjBias?: tf.LayerVariable + + constructor (private readonly config: CausalSelfAttentionConfig) { + super(config) + + this.nEmbd = config.nEmbd + this.nHead = config.nHead + this.dropout = config.dropout + this.bias = config.bias + + this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0) + } + + build (): void { + this.cAttnKernel = this.addWeight( + 'c_attn/kernel', + [this.nEmbd, 3 * this.nEmbd], + 'float32', + tf.initializers.glorotNormal({}) + ) + this.cAttnBias = this.addWeight( + 'c_attn/bias', + [3 * this.nEmbd], + 'float32', + tf.initializers.zeros() + ) + this.cProjKernel = this.addWeight( + 'c_proj/kernel', + [this.nEmbd, this.nEmbd], + 'float32', + tf.initializers.glorotNormal({}) + ) + this.cProjBias = this.addWeight( + 'c_proj/bias', + [this.nEmbd], + 'float32', + tf.initializers.zeros() + ) + } + + computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + return inputShape + } + + getConfig (): serialization.ConfigDict { + const config = super.getConfig() + return Object.assign({}, config, this.config) + } + + call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + return tf.tidy(() => { + 
if (this.cAttnKernel === undefined || + this.cAttnBias === undefined || + this.cProjKernel === undefined || + this.cProjBias === undefined + ) { throw new Error('not built') } + + if (Array.isArray(input)) { + input = input[0] + } + this.invokeCallHook(input, kwargs) + + const dense = (x: tf.Tensor, kernel: tf.LayerVariable, bias: tf.LayerVariable): tf.Tensor => { + const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]) + const m = x.matMul(k) + if (this.bias) { + return tf.add(m, bias.read()) + } else { + return m + } + } + + const cAttn = dense(input, this.cAttnKernel, this.cAttnBias) + + let [q, k, v] = tf.split(cAttn, 3, -1) as [tf.Tensor, tf.Tensor, tf.Tensor] + const [B, T, C] = k.shape + + const splitHeads = (x: tf.Tensor): tf.Tensor => + tf.transpose( + tf.reshape(x, [B, T, this.nHead, C / this.nHead]), + [0, 2, 1, 3] + ) + + q = splitHeads(q) + k = splitHeads(k) + v = splitHeads(v) + + let att = tf.mul( + tf.matMul(q, k, false, true), + tf.div( + 1, + tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32')) + ) + ) + + const mask = this.mask.slice([0, 0], [T, T]) + att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9)) + att = tf.softmax(att, -1) + att = kwargs.training === true ? tf.dropout(att, this.dropout) : att + + let y = tf.matMul(att, v) + y = tf.transpose(y, [0, 2, 1, 3]) + y = tf.reshape(y, [B, T, C]) + y = dense(y, this.cProjKernel, this.cProjBias) + y = kwargs.training === true ? 
tf.dropout(y, this.dropout) : y + + return y + }) + } +} +serialization.registerClass(CausalSelfAttention) + +class GELU extends layers.Layer { + static readonly className = 'GELU' + + constructor () { + super({}) + } + + computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] { + return inputShape + } + + call (input: tf.Tensor | tf.Tensor[], kwargs: Record): tf.Tensor | tf.Tensor[] { + return tf.tidy(() => { + if (Array.isArray(input)) { + // TODO support multitensor + input = input[0] + } + this.invokeCallHook(input, kwargs) + const cdf = tf.mul( + 0.5, + tf.add( + 1, + tf.tanh( + tf.mul( + tf.sqrt(tf.div(2, Math.PI)), + tf.add(input, tf.mul(0.044715, tf.pow(input, 3))) + ) + ) + ) + ) + return tf.mul(input, cdf) + }) + } +} +serialization.registerClass(GELU) + +function MLP (conf: any): LayersModel { + const config = Object.assign({ name: 'mlp' }, conf) + const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) + let x + x = tf.layers + .dense({ + name: config.name + '/c_fc', + units: 4 * config.nEmbd, + inputDim: config.nEmbd, + inputShape: [config.blockSize, config.nEmbd] + }) + .apply(inputs) + x = new GELU().apply(x) + x = tf.layers + .dense({ + name: config.name + '/c_proj', + units: config.nEmbd, + inputDim: 4 * config.nEmbd, + inputShape: [config.blockSize, 4 * config.nEmbd] + }) + .apply(x) + x = tf.layers + .dropout({ + name: config.name + '/drop', + rate: config.residDrop + }) + .apply(x) + return tf.model({ inputs, outputs: x as any }) +} + +function Block (conf: CausalSelfAttentionConfig & { debug: boolean }): LayersModel { + const config = Object.assign({ name: 'h' }, conf) + const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] }) + let x1, x2 + x1 = tf.layers + .layerNormalization({ name: config.name + '/ln_1', epsilon: 1e-5 }) + .apply(inputs) + if (config.debug) { + x1 = new LogLayer({ name: config.name + '/ln_1_log' }).apply(x1) + } + x1 = new CausalSelfAttention( + Object.assign({}, config, { 
name: config.name + '/attn' }) + ).apply(x1) + x1 = tf.layers.add().apply([inputs, x1 as any]) + x2 = tf.layers + .layerNormalization({ name: config.name + '/ln_2', epsilon: 1e-5 }) + .apply(x1) + x2 = MLP(Object.assign({}, config, { name: config.name + '/mlp' })).apply( + x2 + ) + x2 = tf.layers.add().apply([x1 as any, x2 as any]) + return tf.model({ name: config.name, inputs, outputs: x2 as any }) +} + +function GPT (conf: GPTConfig): LayersModel { + const configDefaults = { + name: 'transformer', + ...DEFAULT_CONFIG + } + + const modelSizes = getModelSizes(conf.modelType) + const config = Object.assign({}, configDefaults, conf, modelSizes) + + console.log('IN MODEL CONFIG', config) + + const inputs = tf.input({ shape: [null] }) + + const tokEmb = config.tokEmb + ? tf.layers + .embedding({ + name: config.name + '/wte', + inputDim: config.vocabSize, + outputDim: config.nEmbd, + embeddingsInitializer: 'zeros', + embeddingsRegularizer: undefined, + activityRegularizer: undefined + }) + .apply(inputs) + : inputs + + const range = new Range({}).apply(inputs) + let posEmb = tf.layers + .embedding({ + name: config.name + '/wpe', + inputDim: config.blockSize, + outputDim: config.nEmbd, + embeddingsInitializer: 'zeros' + }) + .apply(range) + if (config.debug) { + posEmb = new LogLayer({ name: 'posEmb' }).apply(posEmb) + } + + let x + x = tf.layers.add().apply([tokEmb as any, posEmb as any]) + x = tf.layers + .dropout({ + name: 'drop', + rate: config.embdDrop + }) + .apply(x) + if (config.debug) { + x = new LogLayer({ name: 'dropadd' }).apply(x) + } + + for (let i = 0; i < config.nLayer; i++) { + x = Block( + Object.assign({}, config, { name: config.name + '/h/' + i }) + ).apply(x) + } + x = tf.layers + .layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 }) + .apply(x) + if (config.debug) { + x = new LogLayer({ name: 'fin/ln' }).apply(x) + } + + if (config.lmHead) { + x = tf.layers + .dense({ + name: 'lm_head', + units: config.vocabSize, + inputDim: 
config.nEmbd, + inputShape: [config.blockSize, config.nEmbd], + useBias: false + }) + .apply(x) + } + return tf.model({ inputs, outputs: x as any }) +} + +interface GenerateConfig { + maxNewTokens: number + temperature: number + doSample: boolean +} + +const defaultGenerateConfig: GenerateConfig = { + maxNewTokens: 20, + temperature: 1.0, + doSample: false +} + +function prepareIdx (idx: tf.TensorLike): tf.Tensor2D { + return tf.tidy(() => { + let ret: tf.Tensor + if (idx instanceof tf.Tensor) { + ret = idx.clone() + } else { + ret = tf.tensor(idx) + } + if (ret.dtype !== 'int32') { + ret = ret.toInt() + } + switch (ret.shape.length) { + case 1: + return ret.expandDims(0) + case 2: + return ret as tf.Tensor2D + default: + throw new Error('unexpected shape') + } + }) +} + +/** + * tfjs does not export LazyIterator and Dataset... + */ +declare abstract class LazyIterator { + abstract next (): Promise> +} + +declare abstract class Dataset { + abstract iterator (): Promise> + size: number +} + +class GPTModel extends LayersModel { + constructor (protected readonly config: GPTConfig) { + const gpt = GPT(config) + const { inputs, outputs, name } = gpt + super({ inputs, outputs, name }) + Object.assign(this, gpt) + } + + async fitDataset ( + dataset: Dataset, + args: tf.ModelFitDatasetArgs + ): Promise { + console.log('=== GPTModel custom train function ===') + const config = { ...this.config, ...args } + + await train( + this, + dataset as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>, + config, + args.callbacks as TrainingCallbacks, + args.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> + ) + + return new tf.History() + } +} + +interface GenerateOutput { idxNext: tf.Tensor2D, timePerToken: number } + +class GPTLMHeadModel extends GPTModel { + async generate (idxRaw: tf.TensorLike, conf: GenerateConfig, act?: (_: GenerateOutput) => Promise): Promise { + const config = Object.assign({}, defaultGenerateConfig, conf) + let idx = 
prepareIdx(idxRaw) + for (let step = 0; step < config.maxNewTokens; step++) { + const { idxNext, timePerToken } = this.generateOnce(this, idx, config) + const idxNew = idx.concat(idxNext, 1) + tf.dispose(idx) + idx = idxNew + const idxNextArr = await (idxNext as any).array() + tf.dispose(idxNext) + if (act !== undefined) { + await act({ idxNext: idxNextArr, timePerToken }) + } + } + const idxArr = await idx.array() + tf.dispose(idx) + return idxArr + } + + private generateOnce (model: tf.LayersModel, idx: tf.Tensor2D, config: GenerateConfig): GenerateOutput { + let timePerToken = performance.now() + + const idxNext = tf.tidy(() => { + const blockSize: any = model.inputs[0].shape[1] + const idxCond = + idx.shape[1] <= blockSize + ? idx + : idx.slice([0, -blockSize], [-1, -1]) + const outputed = model.predict(idxCond) + if (Array.isArray(outputed)) throw new Error('model outputed multiple values') + if (outputed.shape.length !== 3) throw new Error('model outputed weird shape') + const logits = outputed as tf.Tensor3D + + timePerToken = performance.now() - timePerToken + const logitsScaled = logits + .slice([0, idx.shape[1] - 1, 0]) + .reshape([logits.shape[0], logits.shape[2]]) + .div(tf.scalar(config.temperature)) + const probs = logitsScaled.softmax(-1) + if (config.doSample) { + return tf.multinomial(probs, 1) as tf.Tensor2D + } else { + return probs.argMax(-1).expandDims(1) + } + }) + + return { + idxNext, + timePerToken + } + } +} + +export { GPT, GPTModel, GPTLMHeadModel } diff --git a/discojs/discojs-core/src/models/gpt/optimizers.ts b/discojs/discojs-core/src/models/gpt/optimizers.ts new file mode 100644 index 000000000..7c8ee03ea --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/optimizers.ts @@ -0,0 +1,118 @@ +import tf, { AdamOptimizer } from '@tensorflow/tfjs' + +function l2Loss (tensor: tf.Tensor): tf.Tensor { + return tf.div(tf.sum(tf.square(tensor)), 2) +} + +function globalNorm (tensors: tf.Tensor[]): tf.Tensor { + const halfSquaredNorms: 
tf.Tensor[] = [] + tensors.forEach((tensor: tf.Tensor) => { + halfSquaredNorms.push(l2Loss(tensor)) + }) + const halfSquaredNorm: tf.Tensor = tf.sum(tf.stack(halfSquaredNorms)) + const norm: tf.Tensor = tf.sqrt( + tf.mul(halfSquaredNorm, tf.scalar(2.0, halfSquaredNorm.dtype)) + ) + return norm +} + +function clipByGlobalNorm ( + tensors: tf.Tensor[], + clipNorm: number, + useNorm?: tf.Tensor +): tf.Tensor[] { + return tf.tidy(() => { + useNorm = useNorm ?? globalNorm(tensors) + const scale: tf.Tensor = tf.mul( + clipNorm, + tf.minimum( + tf.div(tf.scalar(1.0), useNorm), + tf.div(tf.scalar(1.0, useNorm.dtype), clipNorm) + ) + ) + const tensorsClipped: tf.Tensor[] = [] + tensors.forEach((tensor: tf.Tensor) => { + tensorsClipped.push(tf.clone(tf.mul(tensor, scale))) + }) + return tensorsClipped + }) +} + +function clipByGlobalNormObj ( + tensorsObj: Record, + clipNorm: number, + useNorm?: tf.Tensor +): Record { + const varNames: string[] = Object.keys(tensorsObj) + const tensorsArr: tf.Tensor[] = varNames.map((n: string) => tensorsObj[n]) + const tensorsArrClipped: tf.Tensor[] = clipByGlobalNorm( + tensorsArr, + clipNorm, + useNorm + ) + const tensorsObjClipped: Record = {} + tensorsArrClipped.forEach((t: tf.Tensor, ti: number) => { + tensorsObjClipped[varNames[ti]] = t + }) + return tensorsObjClipped +} + +class AdamW extends AdamOptimizer { + weightDecayRate: number + includeInWeightDecay: string[] + excludeFromWeightDecay: string[] + gradientClipNorm: number + + constructor (params: { + learningRate?: number + beta1?: number + beta2?: number + epsilon?: number + weightDecayRate?: number + includeInWeightDecay?: string[] + excludeFromWeightDecay?: string[] + gradientClipNorm?: number + }) { + console.log('Using custom AdamW optimizer') + const defaultParams = { + learningRate: 0.1, + beta1: 0.9, + beta2: 0.999, + epsilon: 1e-7, + weightDecayRate: 0, + includeInWeightDecay: [], + excludeFromWeightDecay: [], + gradientClipNorm: 1.0 + } + const p = Object.assign({}, 
defaultParams, params) + super(p.learningRate, p.beta1, p.beta2, p.epsilon) + this.weightDecayRate = p.weightDecayRate + this.includeInWeightDecay = p.includeInWeightDecay + this.excludeFromWeightDecay = p.excludeFromWeightDecay + this.gradientClipNorm = p.gradientClipNorm + } + + applyGradients (variableGradients: Record | Array<{ name: string, tensor: tf.Tensor }>): void { + const varNames: string[] = Array.isArray(variableGradients) + ? variableGradients.map((v) => v.name) + : Object.keys(variableGradients) + + varNames.forEach((name: string) => { + if (this.includeInWeightDecay.includes(name)) { + const value = tf.engine().registeredVariables[name] + const newValue: tf.Tensor = tf.sub( + value, + tf.mul( + this.learningRate, + tf.mul(value, this.weightDecayRate) + ) + ) + value.assign(newValue) + } + }) + + super.applyGradients(variableGradients) + } +} + +export { AdamW, clipByGlobalNorm, clipByGlobalNormObj } diff --git a/discojs/discojs-core/src/models/gpt/train.ts b/discojs/discojs-core/src/models/gpt/train.ts new file mode 100644 index 000000000..3d1c92649 --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/train.ts @@ -0,0 +1,115 @@ +import tf from '@tensorflow/tfjs' + +import { AdamW, clipByGlobalNormObj } from './optimizers' +import type { GPTConfig } from './config' +import { DEFAULT_CONFIG } from './config' +import evaluate from './evaluate' +import type { TrainingCallbacks } from './types' + +function resolveConfig (config: GPTConfig): Required { + return { + ...DEFAULT_CONFIG, + ...config + } +} + +function getCustomAdam (model: tf.LayersModel, c: Required): tf.Optimizer { + const includeInWeightDecay: string[] = [] + const excludeFromWeightDecay: string[] = [] + + // TODO unsafe cast + const namedWeights = (model as unknown as any).getNamedWeights() as Array<{ name: string, tensor: tf.Tensor }> + + namedWeights.forEach((v) => { + if ( + v.name.includes('bias') || + v.name.includes('normalization') || + v.name.includes('emb') + ) { + 
excludeFromWeightDecay.push(v.name) + } else { + includeInWeightDecay.push(v.name) + } + }) + + return new AdamW({ + learningRate: c.lr, + weightDecayRate: c.weightDecay, + includeInWeightDecay, + excludeFromWeightDecay + }) +} + +export async function train ( + model: tf.LayersModel, + ds: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>, + config: GPTConfig, + callbacks: TrainingCallbacks, + evalDs?: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> +): Promise { + const c = resolveConfig(config) + + const opt = c.weightDecay !== 0 ? getCustomAdam(model, c) : tf.train.adam(c.lr) + + await callbacks.onTrainBegin?.() + + console.warn('=== Starting training ===') + + for (let epoch = 1; epoch <= c.epochs; epoch++) { + await callbacks.onEpochBegin?.(epoch) + + await tf.data.zip<[number, { xs: tf.Tensor2D, ys: tf.Tensor3D }]>([ + tf.data.generator(function * () { + for (let i = 1; i <= c.maxIter; i++) { yield i } + }), + ds + ]).mapAsync(async ([iteration, { xs, ys }]) => { + await callbacks.onBatchBegin?.(iteration) + return { iteration, xs, ys } + }).map(({ iteration, xs, ys }) => tf.tidy(() => { + const { grads, value: loss } = opt.computeGradients(() => { + const logits = model.apply(xs) + if (Array.isArray(logits)) { + throw new Error('model outputed many tensor') + } + if (logits instanceof tf.SymbolicTensor) { + throw new Error('model outputed symbolic tensor') + } + + const loss = tf.losses.softmaxCrossEntropy(ys, logits) + return loss as tf.Scalar + }) + + tf.dispose([xs, ys]) + + const gradsClipped = clipByGlobalNormObj(grads, 1) + opt.applyGradients(gradsClipped) + + return { iteration, loss } + })).mapAsync(async ({ iteration, loss }) => { + const raw = await loss.array() + tf.dispose(loss) + return [iteration, raw] + }).mapAsync(async ([iteration, loss]) => { + await callbacks.onBatchEnd?.(iteration) + return [iteration, loss] + }).forEachAsync(([iteration, loss]) => { + console.log( + `Epoch: ${epoch}`, + `\tStep: ${iteration} / ${c.maxIter}`, + 
`\tLoss: ${loss.toFixed(3)}`, + `\tMemory: ${(tf.memory().numBytes / 1024 / 1024).toFixed(2)} MB` + ) + }) + + let logs: tf.Logs | undefined + if (evalDs !== undefined) { + logs = await evaluate(model, evalDs) + } + await callbacks.onEpochEnd?.(epoch, logs) + + await new Promise((resolve) => setTimeout(resolve, 1)) + } + + await callbacks.onTrainEnd?.() +} diff --git a/discojs/discojs-core/src/models/gpt/types.ts b/discojs/discojs-core/src/models/gpt/types.ts new file mode 100644 index 000000000..ed40e168d --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/types.ts @@ -0,0 +1,10 @@ +import type tf from '@tensorflow/tfjs' + +export interface TrainingCallbacks { + onEpochBegin?: (epoch: number, logs?: tf.Logs) => Promise + onEpochEnd?: (epoch: number, logs?: tf.Logs) => Promise + onBatchBegin?: (batch: number, logs?: tf.Logs) => Promise + onBatchEnd?: (batch: number, logs?: tf.Logs) => Promise + onTrainBegin?: (logs?: tf.Logs) => Promise + onTrainEnd?: (logs?: tf.Logs) => Promise +} diff --git a/discojs/discojs-core/src/models/index.ts b/discojs/discojs-core/src/models/index.ts index 25b868724..e6bf727ff 100644 --- a/discojs/discojs-core/src/models/index.ts +++ b/discojs/discojs-core/src/models/index.ts @@ -1,2 +1,4 @@ export { Model } from './model' + +export { GPT } from './gpt' export { TFJS } from './tfjs' diff --git a/discojs/discojs-core/src/serialization/model.ts b/discojs/discojs-core/src/serialization/model.ts index fa4dcdd59..d41585490 100644 --- a/discojs/discojs-core/src/serialization/model.ts +++ b/discojs/discojs-core/src/serialization/model.ts @@ -2,7 +2,9 @@ import msgpack from 'msgpack-lite' import type tf from '@tensorflow/tfjs' import type { Model } from '..' -import { models } from '..' +import { models, serialization } from '..' 
+ +const enum Type { TFJS, GPT } export type Encoded = Uint8Array @@ -13,7 +15,12 @@ export function isEncoded (raw: unknown): raw is Encoded { export async function encode (model: Model): Promise<Encoded> { if (model instanceof models.TFJS) { const serialized = await model.serialize() - return msgpack.encode(serialized) + return msgpack.encode([Type.TFJS, serialized]) + } + + if (model instanceof models.GPT) { + const serialized = await serialization.weights.encode(model.serialize()) + return msgpack.encode([Type.GPT, serialized]) } throw new Error('unknown model type') @@ -23,9 +30,32 @@ export async function decode (encoded: unknown): Promise<Model> { if (!isEncoded(encoded)) { throw new Error('invalid encoding') } - const raw = msgpack.decode(encoded) + const raw: unknown = msgpack.decode(encoded) + + if (!Array.isArray(raw) || raw.length !== 2) { + throw new Error('invalid encoding') + } + const [type, rawModel] = raw as [unknown, unknown] - // TODO how to select model type? prepend with model id - // TODO totally unsafe casting - return await models.TFJS.deserialize(raw as tf.io.ModelArtifacts) + if (typeof type !== 'number' || (type !== Type.TFJS && type !== Type.GPT)) { + throw new Error('invalid encoding') + } + switch (type) { + case Type.TFJS: + // TODO totally unsafe casting + return await models.TFJS.deserialize(rawModel as tf.io.ModelArtifacts) + case Type.GPT: { + if (!Array.isArray(rawModel)) { + throw new Error('invalid encoding') + } + const arr: unknown[] = rawModel + if (arr.some((r) => typeof r !== 'number')) { + throw new Error('invalid encoding') + } + const nums = arr as number[] + + const serialized = serialization.weights.decode(nums) + return models.GPT.deserialize(serialized) + } + } } diff --git a/discojs/discojs-core/src/training/disco.ts b/discojs/discojs-core/src/training/disco.ts index c9d33cbb5..227582258 100644 --- a/discojs/discojs-core/src/training/disco.ts +++ b/discojs/discojs-core/src/training/disco.ts @@ -45,7 +45,7 @@ export class Disco { 
options.scheme = TrainingSchemes[task.trainingInformation.scheme as keyof typeof TrainingSchemes] } if (options.aggregator === undefined) { - options.aggregator = new MeanAggregator(task) + options.aggregator = new MeanAggregator() } if (options.client === undefined) { if (options.url === undefined) { diff --git a/discojs/discojs-core/src/training/trainer/trainer.ts b/discojs/discojs-core/src/training/trainer/trainer.ts index eea1d866f..c9b20d784 100644 --- a/discojs/discojs-core/src/training/trainer/trainer.ts +++ b/discojs/discojs-core/src/training/trainer/trainer.ts @@ -80,6 +80,9 @@ export abstract class Trainer { if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) { this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc)) this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc)) + if (logs.val_loss !== undefined) { + this.trainingInformant.loss = logs.val_loss + } } else { this.trainerLogger.error('onEpochEnd: NaN value') } diff --git a/discojs/discojs-core/src/validation/validator.spec.ts b/discojs/discojs-core/src/validation/validator.spec.ts index f13218a03..63520a88e 100644 --- a/discojs/discojs-core/src/validation/validator.spec.ts +++ b/discojs/discojs-core/src/validation/validator.spec.ts @@ -24,7 +24,7 @@ const simplefaceMock: Task = { describe('validator', () => { it('simple_face validator', async () => { - const dir = '../../example_training_data/simple_face/' + const dir = '../../datasets/simple_face/' const files: string[][] = ['child/', 'adult/'] .map((subdir: string) => fs.readdirSync(dir + subdir) .map((file: string) => dir + subdir + file)) @@ -32,7 +32,7 @@ describe('validator', () => { const data = (await new NodeImageLoader(simplefaceMock) .loadAll(files.flat(), { labels })).train - const meanAggregator = new aggregator.MeanAggregator(simplefaceMock) + const meanAggregator = new aggregator.MeanAggregator() const client = new clients.Local(new URL('http://localhost:8080'), simplefaceMock, 
meanAggregator) meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator( @@ -60,13 +60,13 @@ describe('validator', () => { it('titanic validator', async () => { const titanicTask = defaultTasks.titanic.getTask() - const files = ['../../example_training_data/titanic_train.csv'] + const files = ['../../datasets/titanic_train.csv'] const data: data.Data = (await new NodeTabularLoader(titanicTask, ',').loadAll(files, { features: titanicTask.trainingInformation.inputColumns, labels: titanicTask.trainingInformation.outputColumns, shuffle: false })).train - const meanAggregator = new aggregator.MeanAggregator(titanicTask) + const meanAggregator = new aggregator.MeanAggregator() const client = new clients.Local(new URL('http://localhost:8080'), titanicTask, meanAggregator) meanAggregator.setModel(await client.getLatestModel()) const validator = new Validator(titanicTask, new ConsoleLogger(), new EmptyMemory(), undefined, client) diff --git a/discojs/discojs-node/src/data/image_loader.spec.ts b/discojs/discojs-node/src/data/image_loader.spec.ts index cbe1c6de4..4a5369959 100644 --- a/discojs/discojs-node/src/data/image_loader.spec.ts +++ b/discojs/discojs-node/src/data/image_loader.spec.ts @@ -12,7 +12,7 @@ const readFilesFromDir = (dir: string): string[] => fs.readdirSync(dir).map((file: string) => dir + file) const DIRS = { - CIFAR10: '../../example_training_data/CIFAR10/' + CIFAR10: '../../datasets/CIFAR10/' } const cifar10Mock: Task = { @@ -57,7 +57,7 @@ const FILES = Map(DIRS).map((readFilesFromDir)).toObject() describe('image loader', () => { it('loads single sample without label', async () => { - const file = '../../example_training_data/9-mnist-example.png' + const file = '../../datasets/9-mnist-example.png' const singletonDataset = await LOADERS.MNIST.load(file) const imageContent = tfNode.decodeImage(fs.readFileSync(file)) await Promise.all((await singletonDataset.toArrayForTest()).map(async (entry) => { diff --git 
a/discojs/discojs-node/src/data/index.ts b/discojs/discojs-node/src/data/index.ts index 612a1c891..685b18afb 100644 --- a/discojs/discojs-node/src/data/index.ts +++ b/discojs/discojs-node/src/data/index.ts @@ -1,2 +1,3 @@ export { ImageLoader as NodeImageLoader } from './image_loader' export { TabularLoader as NodeTabularLoader } from './tabular_loader' +export { TextLoader as NodeTextLoader } from './text_loader' diff --git a/discojs/discojs-node/src/data/tabular_loader.spec.ts b/discojs/discojs-node/src/data/tabular_loader.spec.ts index 4cd2d241c..c4d45eadb 100644 --- a/discojs/discojs-node/src/data/tabular_loader.spec.ts +++ b/discojs/discojs-node/src/data/tabular_loader.spec.ts @@ -6,7 +6,7 @@ import type { Task } from '@epfml/discojs-core' import { TabularLoader } from './tabular_loader' -const inputFiles = ['../../example_training_data/titanic_train.csv'] +const inputFiles = ['../../datasets/titanic_train.csv'] const titanicMock: Task = { id: 'titanic', diff --git a/discojs/discojs-node/src/data/text_loader.ts b/discojs/discojs-node/src/data/text_loader.ts index 0044e07fa..9698df611 100644 --- a/discojs/discojs-node/src/data/text_loader.ts +++ b/discojs/discojs-node/src/data/text_loader.ts @@ -1,30 +1,14 @@ -// import fs from 'node:fs' +import fs from 'node:fs/promises' +import { data as tfData } from '@tensorflow/tfjs-node' -// import split2 from 'split2' +import { data } from '@epfml/discojs-core' -// import { tf } from '../..' 
-// import { TextLoader } from '../../core/dataset/data_loader/text_loader' -// import { Dataset } from '../../core/dataset' -// import { DataConfig } from '../../core/dataset/data_loader' +export class TextLoader extends data.TextLoader { + async loadDatasetFrom (source: string): Promise { + // TODO sure, good idea to load the whole dataset in memory #irony + const content = await fs.readFile(source) + const file = new tfData.FileDataSource(content) -// export class NodeTextLoader extends TextLoader { -// async loadDatasetFrom (source: string, config?: DataConfig): Promise { -// const prefix = 'file://' -// if (source.slice(0, 7) !== prefix) { -// source = prefix + source -// } -// // create stream being read by generator -// const stream = fs.createReadStream(source, { encoding: 'utf-8' }) -// // eslint-disable-next-line @typescript-eslint/no-this-alias -// const self = this - -// async function * dataGenerator (): AsyncGenerator { -// // TODO @s314cy -// const withLabels = config?.labels !== undefined -// stream.pipe(split2()) -// stream.on('data', (data) => yield self.tokenize(data)) -// } - -// return tf.data.generator(dataGenerator) -// } -// } + return new tfData.TextLineDataset(file) + } +} diff --git a/discojs/discojs-web/src/data/text_loader.ts b/discojs/discojs-web/src/data/text_loader.ts index 702368ca3..4fdd3c5fc 100644 --- a/discojs/discojs-web/src/data/text_loader.ts +++ b/discojs/discojs-web/src/data/text_loader.ts @@ -3,12 +3,8 @@ import tf from '@tensorflow/tfjs' import { data } from '@epfml/discojs-core' export class TextLoader extends data.TextLoader { - async loadDatasetFrom (source: File, config?: Record): Promise { + async loadDatasetFrom (source: File): Promise { const file = new tf.data.FileDataSource(source) - if (config !== undefined) { - return new tf.data.CSVDataset(file, config) - } else { - return new tf.data.TextLineDataset(file) - } + return new tf.data.TextLineDataset(file) } } diff --git a/docs/examples/.gitignore 
b/docs/examples/.gitignore new file mode 100644 index 000000000..1c0e0952a --- /dev/null +++ b/docs/examples/.gitignore @@ -0,0 +1,2 @@ +# saved models +/models/ diff --git a/docs/examples/training.ts b/docs/examples/training.ts index 7ba920ebf..a415a3ef6 100644 --- a/docs/examples/training.ts +++ b/docs/examples/training.ts @@ -69,7 +69,7 @@ async function filesFromFolder (dir: string, folder: string): Promise } async function loadSimpleFaceData (task: Task): Promise { - const dir = '../../example_training_data/simple_face/' + const dir = '../../datasets/simple_face/' const youngFolders = ['child'] const oldFolders = ['adult'] @@ -85,7 +85,7 @@ async function loadSimpleFaceData (task: Task): Promise { } async function loadTitanicData (task: Task): Promise { - const files = ['../../example_training_data/titanic_train.csv'] + const files = ['../../datasets/titanic_train.csv'] const titanicTask = defaultTasks.titanic.getTask() return await new NodeTabularLoader(task, ',').loadAll(files, { features: titanicTask.trainingInformation.inputColumns, diff --git a/get_training_data.sh b/get_training_data.sh deleted file mode 100755 index a6f29c111..000000000 --- a/get_training_data.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/sh -DIR="$( cd "$( dirname "$0" )" ; pwd -P )" -ARCHIVE="example_training_data.tar.gz" - -cd $DIR -curl -L "http://deai-313515.appspot.com.storage.googleapis.com/$ARCHIVE" -o $ARCHIVE -tar -xf $ARCHIVE diff --git a/server/.gitignore b/server/.gitignore index 49cb47e5b..1c0e0952a 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -1,9 +1,2 @@ -# Server-related files -model.json -weights.bin -weights_round*.json - -# TODO be more specific -milestones/ -dist/ -models/ \ No newline at end of file +# saved models +/models/ diff --git a/server/src/router/decentralized/server.ts b/server/src/router/decentralized/server.ts index 86a4fcb9d..c01af943c 100644 --- a/server/src/router/decentralized/server.ts +++ b/server/src/router/decentralized/server.ts @@ 
-1,12 +1,9 @@ import { v4 as randomUUID } from 'uuid' -import type express from 'express' import msgpack from 'msgpack-lite' import type WebSocket from 'ws' -import type { ParamsDictionary } from 'express-serve-static-core' -import type { ParsedQs } from 'qs' import { Map, Set } from 'immutable' -import type { Model, Task, TaskID } from '@epfml/discojs-core' +import type { Task, TaskID } from '@epfml/discojs-core' import { client } from '@epfml/discojs-core' import { Server } from '../server' @@ -27,8 +24,8 @@ export class Decentralized extends Server { protected readonly description = 'Disco Decentralized Server' - protected buildRoute (task: Task): string { - return `/${task.id}` + protected buildRoute (task: TaskID): string { + return `/${task}` } public isValidUrl (url: string | undefined): boolean { @@ -43,20 +40,9 @@ export class Decentralized extends Server { ) } - protected initTask (task: Task, model: Model): void {} - - protected handle ( - task: Task, - ws: WebSocket, - model: Model, - req: express.Request< - ParamsDictionary, - any, - any, - ParsedQs, - Record - > - ): void { + protected initTask (): void {} + + protected handle (task: Task, ws: WebSocket): void { // TODO @s314cy: add to task definition, to be used as threshold in aggregator const minimumReadyPeers = task.trainingInformation?.minimumReadyPeers ?? 
3 diff --git a/server/src/router/federated/server.ts b/server/src/router/federated/server.ts index f29aa4c12..deca543b5 100644 --- a/server/src/router/federated/server.ts +++ b/server/src/router/federated/server.ts @@ -1,4 +1,3 @@ -import type express from 'express' import type WebSocket from 'ws' import { v4 as randomUUID } from 'uuid' import { List, Map } from 'immutable' @@ -71,8 +70,8 @@ export class Federated extends Server { protected readonly description = 'Disco Federated Server' - protected buildRoute (task: Task): string { - return `/${task.id}` + protected buildRoute (task: TaskID): string { + return `/${task}` } public isValidUrl (url: string | undefined): boolean { @@ -94,28 +93,29 @@ export class Federated extends Server { * one resolved and awaits until it resolves. The promise is used in createPromiseForWeights. * @param aggregator The aggregation handler */ - private async storeAggregationResult (aggregator: aggregators.Aggregator): Promise { + private async storeAggregationResult (task: TaskID, aggregator: aggregators.Aggregator): Promise { // Create a promise on the future aggregated weights const result = aggregator.receiveResult() // Store the promise such that it is accessible from other methods - this.results = this.results.set(aggregator.task.id, result) + this.results = this.results.set(task, result) // The promise resolves once the server received enough contributions (through the handle method) // and the aggregator aggregated the weights. 
await result // Update the server round with the aggregator round - this.rounds = this.rounds.set(aggregator.task.id, aggregator.round) + this.rounds = this.rounds.set(task, aggregator.round) // Create a new promise for the next round - void this.storeAggregationResult(aggregator) + // TODO weird usage, should be handled inside of aggregator + void this.storeAggregationResult(task, aggregator) } - protected initTask (task: Task, model: Model): void { - const aggregator = new aggregators.MeanAggregator(task, model) + protected initTask (task: TaskID, model: Model): void { + const aggregator = new aggregators.MeanAggregator(model) - this.aggregators = this.aggregators.set(task.id, aggregator) - this.informants = this.informants.set(task.id, new AsyncInformant(aggregator)) - this.rounds = this.rounds.set(task.id, 0) + this.aggregators = this.aggregators.set(task, aggregator) + this.informants = this.informants.set(task, new AsyncInformant(aggregator)) + this.rounds = this.rounds.set(task, 0) - void this.storeAggregationResult(aggregator) + void this.storeAggregationResult(task, aggregator) } /** @@ -127,17 +127,21 @@ export class Federated extends Server { * @param clientId the clientID of the contribution * @param ws the websocket through which send the aggregated weights */ - private async addContributionAndSendModel (msg: messages.SendPayload, task: Task, - clientId: client.NodeID, ws: WebSocket): Promise { + private async addContributionAndSendModel ( + msg: messages.SendPayload, + task: TaskID, + clientId: client.NodeID, + ws: WebSocket + ): Promise { const { payload, round } = msg - const aggregator = this.aggregators.get(task.id) + const aggregator = this.aggregators.get(task) if (!(Array.isArray(payload) && payload.every((e) => typeof e === 'number'))) { throw new Error('received invalid weights format') } if (aggregator === undefined) { - throw new Error(`received weights for unknown task: ${task.id}`) + throw new Error(`received weights for unknown task: 
${task}`) } // It is important to create a promise for the weights BEFORE adding the contribution @@ -169,12 +173,12 @@ export class Federated extends Server { * @param ws the websocket through which send the aggregated weights */ private async createPromiseForWeights ( - task: Task, + task: TaskID, aggregator: aggregators.Aggregator, ws: WebSocket): Promise { - const promisedResult = this.results.get(task.id) + const promisedResult = this.results.get(task) if (promisedResult === undefined) { - throw new Error(`result promise was not set for task ${task.id}`) + throw new Error(`result promise was not set for task ${task}`) } // Wait for aggregation result to resolve with timeout, giving the network a time window @@ -197,12 +201,7 @@ export class Federated extends Server { .catch(console.error) } - protected handle ( - task: Task, - ws: WebSocket, - model: Model, - req: express.Request - ): void { + protected handle (task: Task, ws: WebSocket, model: Model): void { const taskAggregator = this.aggregators.get(task.id) if (taskAggregator === undefined) { throw new Error('connecting to a non-existing task') @@ -225,7 +224,7 @@ export class Federated extends Server { let aggregator = this.aggregators.get(task.id) if (aggregator === undefined) { - aggregator = new aggregators.MeanAggregator(task) + aggregator = new aggregators.MeanAggregator() this.aggregators = this.aggregators.set(task.id, aggregator) } console.info('client', clientId, 'joined', task.id) @@ -241,7 +240,7 @@ export class Federated extends Server { if (model === undefined) { throw new Error('aggregator model was not set') } - this.addContributionAndSendModel(msg, task, clientId, ws) + this.addContributionAndSendModel(msg, task.id, clientId, ws) .catch(console.error) } else if (msg.type === MessageTypes.ReceiveServerStatistics) { const statistics = this.informants @@ -264,14 +263,14 @@ export class Federated extends Server { throw new Error('aggregator model was not set') } - 
this.createPromiseForWeights(task, aggregator, ws) + this.createPromiseForWeights(task.id, aggregator, ws) .catch(console.error) } else if (msg.type === MessageTypes.SendMetadata) { const { round, key, value } = msg this.logsAppend(task.id, clientId, MessageTypes.SendMetadata, round) - if (this.metadataMap.hasIn([task.id, round, clientId, key])) { + if (this.metadataMap.hasIn([task, round, clientId, key])) { throw new Error('metadata already set') } this.metadataMap = this.metadataMap.setIn( @@ -317,13 +316,13 @@ export class Federated extends Server { /** * Appends a request to the logs. - * @param taskId The task id for which the request was made + * @param task The task id for which the request was made * @param nodeId The node id who made the request * @param type The request type * @param round The round for which the request was made */ private logsAppend ( - taskId: TaskID, + task: TaskID, nodeId: client.NodeID, type: MessageTypes, round: number | undefined = undefined @@ -334,7 +333,7 @@ export class Federated extends Server { this.logs = this.logs.push({ timestamp: new Date(), - task: taskId, + task, round, nodeId, type diff --git a/server/src/router/server.ts b/server/src/router/server.ts index a7eed2e27..a433dfe25 100644 --- a/server/src/router/server.ts +++ b/server/src/router/server.ts @@ -2,7 +2,7 @@ import express from 'express' import type expressWS from 'express-ws' import type WebSocket from 'ws' -import type { Model, Task } from '@epfml/discojs-core' +import type { Model, Task, TaskID } from '@epfml/discojs-core' import type { TasksAndModels } from '../tasks' @@ -35,9 +35,9 @@ export abstract class Server { private onNewTask (task: Task, model: Model): void { this.tasks.push(task.id) - this.initTask(task, model) + this.initTask(task.id, model) - this.ownRouter.ws(this.buildRoute(task), (ws, req) => { + this.ownRouter.ws(this.buildRoute(task.id), (ws, req) => { if (this.isValidUrl(req.url)) { this.handle(task, ws, model, req) } else { @@ -63,9 
+63,9 @@ export abstract class Server { protected abstract readonly description: string - protected abstract buildRoute (task: Task): string + protected abstract buildRoute (task: TaskID): string - protected abstract initTask (task: Task, model: Model): void + protected abstract initTask (task: TaskID, model: Model): void protected abstract handle ( task: Task, diff --git a/server/tests/client/decentralized.spec.ts b/server/tests/client/decentralized.spec.ts index 751c08c38..3e25d558e 100644 --- a/server/tests/client/decentralized.spec.ts +++ b/server/tests/client/decentralized.spec.ts @@ -10,7 +10,7 @@ const TASK = defaultTasks.titanic.getTask() function test ( name: string, Client: new (url: URL, task: Task, aggregator: aggregators.Aggregator) => clients.Client, - Aggregator: new (task: Task) => aggregators.Aggregator + Aggregator: new () => aggregators.Aggregator ): void { describe(`decentralized ${name} client`, function () { this.timeout(30_000) @@ -21,7 +21,7 @@ function test ( afterEach(() => { server?.close() }) it('connect and disconnect from valid task', async () => { - const aggregator = new Aggregator(TASK) + const aggregator = new Aggregator() const client = new Client(url, TASK, aggregator) await client.connect() diff --git a/server/tests/client/federated.spec.ts b/server/tests/client/federated.spec.ts index c7ddbb8ce..a12839ff2 100644 --- a/server/tests/client/federated.spec.ts +++ b/server/tests/client/federated.spec.ts @@ -15,7 +15,7 @@ describe('federated client', function () { afterEach(() => { server?.close() }) it('connect to & disconnect from valid task', async () => { - const client = new clients.federated.FederatedClient(url, TASK, new aggregators.MeanAggregator(TASK)) + const client = new clients.federated.FederatedClient(url, TASK, new aggregators.MeanAggregator()) await client.connect() await client.disconnect() }) @@ -36,7 +36,7 @@ describe('federated client', function () { dataType: 'tabular' } }, - new aggregators.MeanAggregator(TASK) 
+ new aggregators.MeanAggregator() ) try { @@ -49,7 +49,7 @@ describe('federated client', function () { }) it('checks that getAsyncWeightInformantStatistics returns a JSON with the expected statistics', async () => { - const client = new clients.federated.FederatedClient(url, TASK, new aggregators.MeanAggregator(TASK)) + const client = new clients.federated.FederatedClient(url, TASK, new aggregators.MeanAggregator()) await client.connect() const ti = new informant.FederatedInformant(TASK, 0) diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index ba9abed0e..29512b1bf 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -107,7 +107,7 @@ describe('end-to-end decentralized', function () { * The clients have model dimension of 4 model updates to share, which can be seen as their input parameter in makeClient. */ async function reachConsensus ( - Aggregator: new (task: Task) => MockAggregator, + Aggregator: new () => MockAggregator, rounds = 1 ): Promise { // Expect the clients to reach the mean consensus, for both the mean and secure aggregators diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 5bbd4295c..bb6287e2a 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -2,21 +2,21 @@ import fs from 'node:fs/promises' import path from 'node:path' import type { Server } from 'node:http' import { Range } from 'immutable' -import { assert } from 'chai' +import { assert, expect } from 'chai' import type { WeightsContainer } from '@epfml/discojs-core' import { - Disco, TrainingSchemes, client as clients, + Disco, TrainingSchemes, client as clients, data, aggregator as aggregators, informant, defaultTasks } from '@epfml/discojs-core' -import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node' +import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node' import { 
startServer } from '../../src' const SCHEME = TrainingSchemes.FEDERATED describe('end-to-end federated', function () { - this.timeout(120_000) + this.timeout(100_000) let server: Server let url: URL @@ -24,7 +24,7 @@ describe('end-to-end federated', function () { afterEach(() => { server?.close() }) async function cifar10user (): Promise { - const dir = '../example_training_data/CIFAR10/' + const dir = '../datasets/CIFAR10/' const files = (await fs.readdir(dir)).map((file) => path.join(dir, file)) const labels = Range(0, 24).map((label) => (label % 10).toString()).toArray() @@ -32,7 +32,7 @@ describe('end-to-end federated', function () { const data = await new NodeImageLoader(cifar10Task).loadAll(files, { labels, shuffle: false }) - const aggregator = new aggregators.MeanAggregator(cifar10Task) + const aggregator = new aggregators.MeanAggregator() const client = new clients.federated.FederatedClient(url, cifar10Task, aggregator) const disco = new Disco(cifar10Task, { scheme: SCHEME, client }) @@ -46,7 +46,7 @@ describe('end-to-end federated', function () { } async function titanicUser (): Promise { - const files = ['../example_training_data/titanic_train.csv'] + const files = ['../datasets/titanic_train.csv'] const titanicTask = defaultTasks.titanic.getTask() titanicTask.trainingInformation.epochs = 5 @@ -59,7 +59,7 @@ describe('end-to-end federated', function () { } )) - const aggregator = new aggregators.MeanAggregator(titanicTask) + const aggregator = new aggregators.MeanAggregator() const client = new clients.federated.FederatedClient(url, titanicTask, aggregator) const trainingInformant = new informant.FederatedInformant(titanicTask, 10) const disco = new Disco(titanicTask, { scheme: SCHEME, client, aggregator, informant: trainingInformant }) @@ -81,13 +81,42 @@ describe('end-to-end federated', function () { return aggregator.model.weights } + async function wikitextUser (): Promise { + const task = defaultTasks.wikitext.getTask() + const loader = new 
NodeTextLoader(task) + const dataSplit: data.DataSplit = { + train: await data.TextData.init((await loader.load('../datasets/wikitext/wiki.train.tokens')), task), + validation: await data.TextData.init(await loader.load('../datasets/wikitext/wiki.valid.tokens'), task) + } + + const aggregator = new aggregators.MeanAggregator() + const client = new clients.federated.FederatedClient(url, task, aggregator) + const trainingInformant = new informant.FederatedInformant(task, 10) + const disco = new Disco(task, { scheme: SCHEME, client, aggregator, informant: trainingInformant }) + + await disco.fit(dataSplit) + await disco.close() + + expect(trainingInformant.losses.first()).to.be.above(trainingInformant.losses.last()) + } + it('two cifar10 users reach consensus', async () => { + this.timeout(90_000) + const [m1, m2] = await Promise.all([cifar10user(), cifar10user()]) assert.isTrue(m1.equals(m2)) }) it('two titanic users reach consensus', async () => { + this.timeout(30_000) + const [m1, m2] = await Promise.all([titanicUser(), titanicUser()]) assert.isTrue(m1.equals(m2)) }) + + it('trains wikitext', async () => { + this.timeout(120_000) + + await wikitextUser() + }) })