Skip to content

Commit 6900d93

Browse files
authored
Fixes issue 160 (#162)
* Fixes issue 160 - [x] Verified changes - [x] Tests, coverage, ran, checked * MOBETTA Testing - [x] Verified changes - [x] Tests, coverage, ran, checked * Bumps Version
1 parent 572e421 commit 6900d93

File tree

7 files changed

+579
-312
lines changed

7 files changed

+579
-312
lines changed

browser.js

Lines changed: 436 additions & 257 deletions
Large diffs are not rendered by default.

browser.min.js

Lines changed: 35 additions & 36 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/neural-network.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "brain.js",
33
"description": "Neural network library",
4-
"version": "1.1.1",
4+
"version": "1.1.2",
55
"author": "Heather Arthur <[email protected]>",
66
"repository": {
77
"type": "git",

src/neural-network.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ export default class NeuralNetwork {
331331
_getTrainOptsJSON() {
332332
return Object.keys(NeuralNetwork.trainDefaults)
333333
.reduce((opts, opt) => {
334+
if (opt === 'timeout' && this.trainOpts[opt] === Infinity) return opts;
334335
if (this.trainOpts[opt]) opts[opt] = this.trainOpts[opt];
335336
if (opt === 'log') opts.log = typeof opts.log === 'function';
336337
return opts;

test/base/json.js

Lines changed: 104 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ describe('JSON', () => {
2727
trainingOpts.log = true;
2828

2929
const serialized = originalNet.toJSON();
30-
const serializedNet = new NeuralNetwork().fromJSON(serialized);
30+
const serializedNet = new NeuralNetwork()
31+
.fromJSON(
32+
JSON.parse(
33+
JSON.stringify(serialized)
34+
)
35+
);
3136

3237
const input = {'0' : Math.random(), b: Math.random()};
3338
describe('.toJSON()', () => {
@@ -71,39 +76,39 @@ describe('JSON', () => {
7176

7277
describe('.trainOpts', () => {
7378
it('training options iterations', () => {
74-
assert.equal(trainingOpts.iterations, serialized.trainOpts.iterations, `trainingOpts.are: ${trainingOpts.iterations} serialized should be the same but are: ${serialized.trainOpts.iterations}`);
79+
assert.equal(trainingOpts.iterations, serialized.trainOpts.iterations, `trainingOpts are: ${trainingOpts.iterations} serialized should be the same but are: ${serialized.trainOpts.iterations}`);
7580
});
7681

7782
it('training options errorThresh', () => {
78-
assert.equal(trainingOpts.errorThresh, serialized.trainOpts.errorThresh, `trainingOpts.are: ${trainingOpts.errorThresh} serialized should be the same but are: ${serialized.trainOpts.errorThresh}`);
83+
assert.equal(trainingOpts.errorThresh, serialized.trainOpts.errorThresh, `trainingOpts are: ${trainingOpts.errorThresh} serialized should be the same but are: ${serialized.trainOpts.errorThresh}`);
7984
});
8085

8186
it('training options log', () => {
8287
assert.equal(trainingOpts.log, serialized.trainOpts.log, `log are: ${trainingOpts.log} serialized should be the same but are: ${serialized.trainOpts.log}`);
8388
});
8489

8590
it('training options logPeriod', () => {
86-
assert.equal(trainingOpts.logPeriod, serialized.trainOpts.logPeriod, `trainingOpts.are: ${trainingOpts.logPeriod} serialized should be the same but are: ${serialized.trainOpts.logPeriod}`);
91+
assert.equal(trainingOpts.logPeriod, serialized.trainOpts.logPeriod, `trainingOpts are: ${trainingOpts.logPeriod} serialized should be the same but are: ${serialized.trainOpts.logPeriod}`);
8792
});
8893

8994
it('training options learningRate', () => {
90-
assert.equal(trainingOpts.learningRate, serialized.trainOpts.learningRate, `trainingOpts.are: ${trainingOpts.learningRate} serialized should be the same but are: ${serialized.trainOpts.learningRate}`);
95+
assert.equal(trainingOpts.learningRate, serialized.trainOpts.learningRate, `trainingOpts are: ${trainingOpts.learningRate} serialized should be the same but are: ${serialized.trainOpts.learningRate}`);
9196
});
9297

9398
it('training options momentum', () => {
94-
assert.equal(trainingOpts.momentum, serialized.trainOpts.momentum, `trainingOpts.are: ${trainingOpts.momentum} serialized should be the same but are: ${serialized.trainOpts.momentum}`);
99+
assert.equal(trainingOpts.momentum, serialized.trainOpts.momentum, `trainingOpts are: ${trainingOpts.momentum} serialized should be the same but are: ${serialized.trainOpts.momentum}`);
95100
});
96101

97102
it('training options callback', () => {
98-
assert.equal(trainingOpts.callback, serialized.trainOpts.callback, `trainingOpts.are: ${trainingOpts.callback} serialized should be the same but are: ${serialized.trainOpts.callback}`);
103+
assert.equal(trainingOpts.callback, serialized.trainOpts.callback, `trainingOpts are: ${trainingOpts.callback} serialized should be the same but are: ${serialized.trainOpts.callback}`);
99104
});
100105

101106
it('training options callbackPeriod', () => {
102-
assert.equal(trainingOpts.callbackPeriod, serialized.trainOpts.callbackPeriod, `trainingOpts.are: ${trainingOpts.callbackPeriod} serialized should be the same but are: ${serialized.trainOpts.callbackPeriod}`);
107+
assert.equal(trainingOpts.callbackPeriod, serialized.trainOpts.callbackPeriod, `trainingOpts are: ${trainingOpts.callbackPeriod} serialized should be the same but are: ${serialized.trainOpts.callbackPeriod}`);
103108
});
104109

105110
it('training options timeout', () => {
106-
assert.equal(trainingOpts.timeout, serialized.trainOpts.timeout, `trainingOpts.are: ${trainingOpts.timeout} serialized should be the same but are: ${serialized.trainOpts.timeout}`);
111+
assert.equal(trainingOpts.timeout, serialized.trainOpts.timeout, `trainingOpts are: ${trainingOpts.timeout} serialized should be the same but are: ${serialized.trainOpts.timeout}`);
107112
});
108113
});
109114

@@ -137,11 +142,11 @@ describe('JSON', () => {
137142

138143
describe('.trainOpts', () => {
139144
it('training options iterations', () => {
140-
assert.equal(trainingOpts.iterations, serializedNet.trainOpts.iterations, `trainingOpts.are: ${trainingOpts.iterations} serializedNet should be the same but are: ${serializedNet.trainOpts.iterations}`);
145+
assert.equal(trainingOpts.iterations, serializedNet.trainOpts.iterations, `trainingOpts are: ${trainingOpts.iterations} serializedNet should be the same but are: ${serializedNet.trainOpts.iterations}`);
141146
});
142147

143148
it('training options errorThresh', () => {
144-
assert.equal(trainingOpts.errorThresh, serializedNet.trainOpts.errorThresh, `trainingOpts.are: ${trainingOpts.errorThresh} serializedNet should be the same but are: ${serializedNet.trainOpts.errorThresh}`);
149+
assert.equal(trainingOpts.errorThresh, serializedNet.trainOpts.errorThresh, `trainingOpts are: ${trainingOpts.errorThresh} serializedNet should be the same but are: ${serializedNet.trainOpts.errorThresh}`);
145150
});
146151

147152
it('training options log', () => {
@@ -150,27 +155,27 @@ describe('JSON', () => {
150155
});
151156

152157
it('training options logPeriod', () => {
153-
assert.equal(trainingOpts.logPeriod, serializedNet.trainOpts.logPeriod, `trainingOpts.are: ${trainingOpts.logPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.logPeriod}`);
158+
assert.equal(trainingOpts.logPeriod, serializedNet.trainOpts.logPeriod, `trainingOpts are: ${trainingOpts.logPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.logPeriod}`);
154159
});
155160

156161
it('training options learningRate', () => {
157-
assert.equal(trainingOpts.learningRate, serializedNet.trainOpts.learningRate, `trainingOpts.are: ${trainingOpts.learningRate} serializedNet should be the same but are: ${serializedNet.trainOpts.learningRate}`);
162+
assert.equal(trainingOpts.learningRate, serializedNet.trainOpts.learningRate, `trainingOpts are: ${trainingOpts.learningRate} serializedNet should be the same but are: ${serializedNet.trainOpts.learningRate}`);
158163
});
159164

160165
it('training options momentum', () => {
161-
assert.equal(trainingOpts.momentum, serializedNet.trainOpts.momentum, `trainingOpts.are: ${trainingOpts.momentum} serializedNet should be the same but are: ${serializedNet.trainOpts.momentum}`);
166+
assert.equal(trainingOpts.momentum, serializedNet.trainOpts.momentum, `trainingOpts are: ${trainingOpts.momentum} serializedNet should be the same but are: ${serializedNet.trainOpts.momentum}`);
162167
});
163168

164169
it('training options callback', () => {
165-
assert.equal(trainingOpts.callback, serializedNet.trainOpts.callback, `trainingOpts.are: ${trainingOpts.callback} serializedNet should be the same but are: ${serializedNet.trainOpts.callback}`);
170+
assert.equal(trainingOpts.callback, serializedNet.trainOpts.callback, `trainingOpts are: ${trainingOpts.callback} serializedNet should be the same but are: ${serializedNet.trainOpts.callback}`);
166171
});
167172

168173
it('training options callbackPeriod', () => {
169-
assert.equal(trainingOpts.callbackPeriod, serializedNet.trainOpts.callbackPeriod, `trainingOpts.are: ${trainingOpts.callbackPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.callbackPeriod}`);
174+
assert.equal(trainingOpts.callbackPeriod, serializedNet.trainOpts.callbackPeriod, `trainingOpts are: ${trainingOpts.callbackPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.callbackPeriod}`);
170175
});
171176

172177
it('training options timeout', () => {
173-
assert.equal(trainingOpts.timeout, serializedNet.trainOpts.timeout, `trainingOpts.are: ${trainingOpts.timeout} serializedNet should be the same but are: ${serializedNet.trainOpts.timeout}`);
178+
assert.equal(trainingOpts.timeout, serializedNet.trainOpts.timeout, `trainingOpts are: ${trainingOpts.timeout} serializedNet should be the same but are: ${serializedNet.trainOpts.timeout}`);
174179
});
175180
});
176181
});
@@ -192,3 +197,85 @@ describe('JSON', () => {
192197
})
193198
});
194199
});
200+
201+
202+
describe('default net json', () => {
203+
const originalNet = new NeuralNetwork();
204+
205+
originalNet.train([
206+
{
207+
input: {'0': Math.random(), b: Math.random()},
208+
output: {c: Math.random(), '0': Math.random()}
209+
}, {
210+
input: {'0': Math.random(), b: Math.random()},
211+
output: {c: Math.random(), '0': Math.random()}
212+
}
213+
]);
214+
215+
const serialized = originalNet.toJSON();
216+
const serializedNet = new NeuralNetwork()
217+
.fromJSON(
218+
JSON.parse(
219+
JSON.stringify(serialized)
220+
)
221+
);
222+
223+
const input = {'0' : Math.random(), b: Math.random()};
224+
225+
describe('.trainOpts', () => {
226+
it('training options iterations', () => {
227+
assert.equal(originalNet.trainOpts.iterations, serializedNet.trainOpts.iterations, `originalNet.trainOpts are: ${originalNet.trainOpts.iterations} serializedNet should be the same but are: ${serializedNet.trainOpts.iterations}`);
228+
});
229+
230+
it('training options errorThresh', () => {
231+
assert.equal(originalNet.trainOpts.errorThresh, serializedNet.trainOpts.errorThresh, `originalNet.trainOpts are: ${originalNet.trainOpts.errorThresh} serializedNet should be the same but are: ${serializedNet.trainOpts.errorThresh}`);
232+
});
233+
234+
it('training options log', () => {
235+
// Should have inflated to console.log
236+
assert.equal(originalNet.trainOpts.log, serializedNet.trainOpts.log, `log are: ${originalNet.trainOpts.log} serializedNet should be the same but are: ${serializedNet.trainOpts.log}`);
237+
});
238+
239+
it('training options logPeriod', () => {
240+
assert.equal(originalNet.trainOpts.logPeriod, serializedNet.trainOpts.logPeriod, `originalNet.trainOpts are: ${originalNet.trainOpts.logPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.logPeriod}`);
241+
});
242+
243+
it('training options learningRate', () => {
244+
assert.equal(originalNet.trainOpts.learningRate, serializedNet.trainOpts.learningRate, `originalNet.trainOpts are: ${originalNet.trainOpts.learningRate} serializedNet should be the same but are: ${serializedNet.trainOpts.learningRate}`);
245+
});
246+
247+
it('training options momentum', () => {
248+
assert.equal(originalNet.trainOpts.momentum, serializedNet.trainOpts.momentum, `originalNet.trainOpts are: ${originalNet.trainOpts.momentum} serializedNet should be the same but are: ${serializedNet.trainOpts.momentum}`);
249+
});
250+
251+
it('training options callback', () => {
252+
assert.equal(originalNet.trainOpts.callback, serializedNet.trainOpts.callback, `originalNet.trainOpts are: ${originalNet.trainOpts.callback} serializedNet should be the same but are: ${serializedNet.trainOpts.callback}`);
253+
});
254+
255+
it('training options callbackPeriod', () => {
256+
assert.equal(originalNet.trainOpts.callbackPeriod, serializedNet.trainOpts.callbackPeriod, `originalNet.trainOpts are: ${originalNet.trainOpts.callbackPeriod} serializedNet should be the same but are: ${serializedNet.trainOpts.callbackPeriod}`);
257+
});
258+
259+
it('training options timeout', () => {
260+
console.log(originalNet.trainOpts.timeout);
261+
console.log(serializedNet.trainOpts.timeout);
262+
assert.equal(originalNet.trainOpts.timeout, serializedNet.trainOpts.timeout, `originalNet.trainOpts are: ${originalNet.trainOpts.timeout} serializedNet should be the same but are: ${serializedNet.trainOpts.timeout}`);
263+
});
264+
});
265+
266+
it('can run originalNet, and serializedNet, with same output', () => {
267+
const output1 = originalNet.run(input);
268+
const output2 = serializedNet.run(input);
269+
assert.deepEqual(output2, output1,
270+
'loading json serialized network failed');
271+
});
272+
273+
it('if json.trainOpts is not set, ._updateTrainingOptions() is not called and activation defaults to sigmoid', () => {
274+
const net = new NeuralNetwork();
275+
net._updateTrainingOptions = () => {
276+
throw new Error('_updateTrainingOptions was called');
277+
};
278+
net.fromJSON({ sizes: [], layers: [] });
279+
assert(net.activation === 'sigmoid');
280+
})
281+
})

0 commit comments

Comments
 (0)