diff --git a/src/neural-network.js b/src/neural-network.js
index 4ce3bfa11..c28075d70 100644
--- a/src/neural-network.js
+++ b/src/neural-network.js
@@ -32,6 +32,7 @@ export default class NeuralNetwork {
   static get defaults() {
     return {
+      leakyReluAlpha: 0.01,
       binaryThresh: 0.5,
       hiddenLayers: [3], // array of ints for the sizes of the hidden layers in the network
       activation: 'sigmoid' // Supported activation types ['sigmoid', 'relu', 'leaky-relu', 'tanh']
     };
@@ -249,7 +250,7 @@ export default class NeuralNetwork {
   _runInputLeakyRelu(input) {
     this.outputs[0] = input; // set output state of input layer
-
+    let alpha = this.leakyReluAlpha;
     let output = null;
     for (let layer = 1; layer <= this.outputLayer; layer++) {
       for (let node = 0; node < this.sizes[layer]; node++) {
         let weights = this.weights[layer][node];
@@ -260,7 +261,7 @@ export default class NeuralNetwork {
           sum += weights[k] * input[k];
         }
         //leaky relu: pass positives through, scale negatives by alpha
-        this.outputs[layer][node] = (sum < 0 ? 0 : 0.01 * sum);
+        this.outputs[layer][node] = (sum < 0 ? alpha * sum : sum);
       }
       output = input = this.outputs[layer];
     }
@@ -557,6 +558,7 @@ export default class NeuralNetwork {
    * @param target
    */
   _calculateDeltasLeakyRelu(target) {
+    let alpha = this.leakyReluAlpha;
     for (let layer = this.outputLayer; layer >= 0; layer--) {
       for (let node = 0; node < this.sizes[layer]; node++) {
         let output = this.outputs[layer][node];
@@ -572,7 +574,7 @@ export default class NeuralNetwork {
           }
         }
         this.errors[layer][node] = error;
-        this.deltas[layer][node] = output > 0 ? error : 0.01 * error;
+        this.deltas[layer][node] = output > 0 ? error : alpha * error;
       }
     }
   }
@@ -933,6 +935,7 @@ export default class NeuralNetwork {
    */
   toFunction() {
     const activation = this.activation;
+    const leakyReluAlpha = this.leakyReluAlpha;
     let needsVar = false;
     function nodeHandle(layers, layerNumber, nodeKey) {
       if (layerNumber === 0) {
@@ -962,7 +965,7 @@ export default class NeuralNetwork {
         }
         case 'leaky-relu': {
           needsVar = true;
-          return `((v=${result.join('')})<0?0:0.01*v)`;
+          return `((v=${result.join('')})<0?${leakyReluAlpha}*v:v)`;
         }
         case 'tanh':
           return `Math.tanh(${result.join('')})`;
@@ -988,4 +991,4 @@ export default class NeuralNetwork {
 
     return new Function('input', `${ needsVar ? 'var v;' : '' }return ${result};`);
   }
-}
\ No newline at end of file
+}
diff --git a/test/base/options.js b/test/base/options.js
index 18a62fee7..c4d5ab3ee 100644
--- a/test/base/options.js
+++ b/test/base/options.js
@@ -108,4 +108,10 @@ describe ('neural network constructor values', () => {
     var net = new brain.NeuralNetwork(opts);
     assert.equal(opts.activation, net.activation, `activation => ${net.activation} but should be ${opts.activation}`);
   })
+
+  it('leakyReluAlpha should be settable in the constructor', () => {
+    let opts = { leakyReluAlpha: 0.1337 };
+    var net = new brain.NeuralNetwork(opts);
+    assert.equal(opts.leakyReluAlpha, net.leakyReluAlpha, `leakyReluAlpha => ${net.leakyReluAlpha} but should be ${opts.leakyReluAlpha}`);
+  })
 });
\ No newline at end of file