Skip to content

Commit 465d067

Browse files
authored
Merge pull request #26 from huningxin/nsnet2
Optimize the nsnet2 example for explainer usage
2 parents e7cbbe1 + 37b9796 commit 465d067

File tree

8 files changed

+76
-114
lines changed

8 files changed

+76
-114
lines changed

nsnet2/denoiser.js

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,4 @@ export class Denoiser {
159159
this.log(`<b>Done.</b> Processed ${audioFrames} ` +
160160
`frames in <span class='text-primary'>${processTime}</span> ms.`, true);
161161
}
162-
163-
dispose() {
164-
this.nsnet.dispose();
165-
}
166162
}

nsnet2/nsnet2.js

Lines changed: 76 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -7,133 +7,99 @@ function sizeOfShape(shape) {
77
});
88
}
99

10+
async function buildConstantByNpy(builder, url) {
11+
const dataTypeMap = new Map([
12+
['f2', {type: 'float16', array: Uint16Array}],
13+
['f4', {type: 'float32', array: Float32Array}],
14+
['f8', {type: 'float64', array: Float64Array}],
15+
['i1', {type: 'int8', array: Int8Array}],
16+
['i2', {type: 'int16', array: Int16Array}],
17+
['i4', {type: 'int32', array: Int32Array}],
18+
['i8', {type: 'int64', array: BigInt64Array}],
19+
['u1', {type: 'uint8', array: Uint8Array}],
20+
['u2', {type: 'uint16', array: Uint16Array}],
21+
['u4', {type: 'uint32', array: Uint32Array}],
22+
['u8', {type: 'uint64', array: BigUint64Array}],
23+
]);
24+
const response = await fetch(url);
25+
const buffer = await response.arrayBuffer();
26+
const npArray = new numpy.Array(new Uint8Array(buffer));
27+
if (!dataTypeMap.has(npArray.dataType)) {
28+
throw new Error(`Data type ${npArray.dataType} is not supported.`);
29+
}
30+
const dimensions = npArray.shape;
31+
const type = dataTypeMap.get(npArray.dataType).type;
32+
const TypedArrayConstructor = dataTypeMap.get(npArray.dataType).array;
33+
const typedArray = new TypedArrayConstructor(sizeOfShape(dimensions));
34+
const dataView = new DataView(npArray.data.buffer);
35+
const littleEndian = npArray.byteOrder === '<';
36+
for (let i = 0; i < sizeOfShape(dimensions); ++i) {
37+
typedArray[i] = dataView[`get` + type[0].toUpperCase() + type.substr(1)](
38+
i * TypedArrayConstructor.BYTES_PER_ELEMENT, littleEndian);
39+
}
40+
return builder.constant({type, dimensions}, typedArray);
41+
}
42+
43+
/* eslint max-len: ["error", { "code": 130 }] */
44+
45+
// Noise Suppression Net 2 (NSNet2) Baseline Model for Deep Noise Suppression Challenge (DNS) 2021.
1046
export class NSNet2 {
1147
constructor() {
12-
this.baseUrl_ = './';
13-
this.model_ = null;
14-
this.compilation_ = null;
48+
this.model = null;
49+
this.compiledModel = null;
50+
this.frameSize = 161;
1551
this.hiddenSize = 400;
1652
}
1753

18-
async buildConstantByNpy(fileName) {
19-
const dataTypeMap = new Map([
20-
['f2', {type: 'float16', array: Uint16Array}],
21-
['f4', {type: 'float32', array: Float32Array}],
22-
['f8', {type: 'float64', array: Float64Array}],
23-
['i1', {type: 'int8', array: Int8Array}],
24-
['i2', {type: 'int16', array: Int16Array}],
25-
['i4', {type: 'int32', array: Int32Array}],
26-
['i8', {type: 'int64', array: BigInt64Array}],
27-
['u1', {type: 'uint8', array: Uint8Array}],
28-
['u2', {type: 'uint16', array: Uint16Array}],
29-
['u4', {type: 'uint32', array: Uint32Array}],
30-
['u8', {type: 'uint64', array: BigUint64Array}],
31-
]);
32-
const response = await fetch(this.baseUrl_ + fileName);
33-
const buffer = await response.arrayBuffer();
34-
const npArray = new numpy.Array(new Uint8Array(buffer));
35-
if (!dataTypeMap.has(npArray.dataType)) {
36-
throw new Error(`Data type ${npArray.dataType} is not supported.`);
37-
}
38-
const dimensions = npArray.shape;
39-
const type = dataTypeMap.get(npArray.dataType).type;
40-
const TypedArrayConstructor = dataTypeMap.get(npArray.dataType).array;
41-
const typedArray = new TypedArrayConstructor(sizeOfShape(dimensions));
42-
const dataView = new DataView(npArray.data.buffer);
43-
const littleEndian = npArray.byteOrder === '<';
44-
for (let i = 0; i < sizeOfShape(dimensions); ++i) {
45-
typedArray[i] = dataView[`get` + type[0].toUpperCase() + type.substr(1)](
46-
i * TypedArrayConstructor.BYTES_PER_ELEMENT, littleEndian);
47-
}
48-
return this.builder.constant({type, dimensions}, typedArray);
49-
}
50-
51-
async load(url, batchSize, frames) {
52-
this.baseUrl_ = url;
54+
async load(baseUrl, batchSize, frames) {
5355
const nn = navigator.ml.getNeuralNetworkContext();
5456
const builder = nn.createModelBuilder();
55-
this.builder = builder;
56-
57-
// Create constants
58-
const weight172 = await this.buildConstantByNpy('172.npy');
59-
const biasFcIn0 = await this.buildConstantByNpy('fc_in_0_bias.npy');
60-
const weight192 = await this.buildConstantByNpy('192.npy');
61-
const recurrentWeight193 = await this.buildConstantByNpy('193.npy');
62-
const data194 = await this.buildConstantByNpy('194.npy');
63-
const weight212 = await this.buildConstantByNpy('212.npy');
64-
const recurrentWeight213 = await this.buildConstantByNpy('213.npy');
65-
const data214 = await this.buildConstantByNpy('214.npy');
66-
const weight215 = await this.buildConstantByNpy('215.npy');
67-
const biasFcOut0 = await this.buildConstantByNpy('fc_out_0_bias.npy');
68-
const weight216 = await this.buildConstantByNpy('216.npy');
69-
const biasFcOut2 = await this.buildConstantByNpy('fc_out_2_bias.npy');
70-
const weight217 = await this.buildConstantByNpy('217.npy');
71-
const biasFcOut4 = await this.buildConstantByNpy('fc_out_4_bias.npy');
72-
73-
// Build up the network
74-
const hiddenSize = this.hiddenSize;
75-
const inputShape = [batchSize, frames, 161];
76-
const input = builder.input(
77-
'input', {type: 'float32', dimensions: inputShape});
78-
const matmul18 = builder.matmul(input, weight172);
79-
const add19 = builder.add(matmul18, biasFcIn0);
80-
const relu20 = builder.relu(add19);
57+
// Create constants by loading pre-trained data from .npy files.
58+
const weight172 = await buildConstantByNpy(builder, baseUrl + '172.npy');
59+
const biasFcIn0 = await buildConstantByNpy(builder, baseUrl + 'fc_in_0_bias.npy');
60+
const weight192 = await buildConstantByNpy(builder, baseUrl + '192.npy');
61+
const recurrentWeight193 = await buildConstantByNpy(builder, baseUrl + '193.npy');
62+
const bias194 = await buildConstantByNpy(builder, baseUrl + '194_0.npy');
63+
const recurrentBias194 = await buildConstantByNpy(builder, baseUrl + '194_1.npy');
64+
const weight212 = await buildConstantByNpy(builder, baseUrl + '212.npy');
65+
const recurrentWeight213 = await buildConstantByNpy(builder, baseUrl + '213.npy');
66+
const bias214 = await buildConstantByNpy(builder, baseUrl + '214_0.npy');
67+
const recurrentBias214 = await buildConstantByNpy(builder, baseUrl + '214_1.npy');
68+
const weight215 = await buildConstantByNpy(builder, baseUrl + '215.npy');
69+
const biasFcOut0 = await buildConstantByNpy(builder, baseUrl + 'fc_out_0_bias.npy');
70+
const weight216 = await buildConstantByNpy(builder, baseUrl + '216.npy');
71+
const biasFcOut2 = await buildConstantByNpy(builder, baseUrl + 'fc_out_2_bias.npy');
72+
const weight217 = await buildConstantByNpy(builder, baseUrl + '217.npy');
73+
const biasFcOut4 = await buildConstantByNpy(builder, baseUrl + 'fc_out_4_bias.npy');
74+
// Build up the network.
75+
const input = builder.input('input', {type: 'float32', dimensions: [batchSize, frames, this.frameSize]});
76+
const relu20 = builder.relu(builder.add(builder.matmul(input, weight172), biasFcIn0));
8177
const transpose31 = builder.transpose(relu20, {permutation: [1, 0, 2]});
82-
const bias194 = builder.slice(data194, [0], [3 * hiddenSize], {axes: [1]});
83-
const recurrentBias194 = builder.slice(
84-
data194, [3 * hiddenSize], [-1], {axes: [1]});
85-
const initialHiddenState92 = builder.input(
86-
'initialHiddenState92',
87-
{type: 'float32', dimensions: [1, batchSize, hiddenSize]});
88-
const [gru94, gru93] = builder.gru(
89-
transpose31, weight192, recurrentWeight193, frames, hiddenSize,
90-
{
91-
bias: bias194, recurrentBias: recurrentBias194,
92-
initialHiddenState: initialHiddenState92, returnSequence: true,
93-
});
78+
const initialState92 = builder.input('initialState92', {type: 'float32', dimensions: [1, batchSize, this.hiddenSize]});
79+
const [gru94, gru93] = builder.gru(transpose31, weight192, recurrentWeight193, frames, this.hiddenSize,
80+
{bias: bias194, recurrentBias: recurrentBias194, initialHiddenState: initialState92, returnSequence: true});
9481
const squeeze95 = builder.squeeze(gru93, {axes: [1]});
95-
const bias214 = builder.slice(data214, [0], [3 * hiddenSize], {axes: [1]});
96-
const recurrentBias214 = builder.slice(
97-
data214, [3 * hiddenSize], [-1], {axes: [1]});
98-
const initialHiddenState155 = builder.input(
99-
'initialHiddenState155',
100-
{type: 'float32', dimensions: [1, batchSize, hiddenSize]});
101-
const [gru157, gru156] = builder.gru(
102-
squeeze95, weight212, recurrentWeight213, frames, hiddenSize,
103-
{
104-
bias: bias214, recurrentBias: recurrentBias214,
105-
initialHiddenState: initialHiddenState155, returnSequence: true,
106-
});
82+
const initialState155 = builder.input('initialState155', {type: 'float32', dimensions: [1, batchSize, this.hiddenSize]});
83+
const [gru157, gru156] = builder.gru(squeeze95, weight212, recurrentWeight213, frames, this.hiddenSize,
84+
{bias: bias214, recurrentBias: recurrentBias214, initialHiddenState: initialState155, returnSequence: true});
10785
const squeeze158 = builder.squeeze(gru156, {axes: [1]});
108-
const transpose159 = builder.transpose(
109-
squeeze158, {permutation: [1, 0, 2]});
110-
const matmul161 = builder.matmul(transpose159, weight215);
111-
const add162 = builder.add(matmul161, biasFcOut0);
112-
const relu163 = builder.relu(add162);
113-
const matmul165 = builder.matmul(relu163, weight216);
114-
const add166 = builder.add(matmul165, biasFcOut2);
115-
const relu167 = builder.relu(add166);
116-
const matmul169 = builder.matmul(relu167, weight217);
117-
const add170 = builder.add(matmul169, biasFcOut4);
118-
const output = builder.sigmoid(add170);
119-
this.model_ = builder.createModel({output, gru94, gru157});
86+
const transpose159 = builder.transpose(squeeze158, {permutation: [1, 0, 2]});
87+
const relu163 = builder.relu(builder.add(builder.matmul(transpose159, weight215), biasFcOut0));
88+
const relu167 = builder.relu(builder.add(builder.matmul(relu163, weight216), biasFcOut2));
89+
const output = builder.sigmoid(builder.add(builder.matmul(relu167, weight217), biasFcOut4));
90+
this.model = builder.createModel({output, gru94, gru157});
12091
}
12192

12293
async compile(options) {
123-
this.compilation_ = await this.model_.compile(options);
94+
this.compiledModel = await this.model.compile(options);
12495
}
12596

126-
async compute(
127-
inputBuffer, initialHiddenState92Buffer, initialHiddenState155Buffer) {
97+
async compute(inputBuffer, initialState92Buffer, initialState155Buffer) {
12898
const inputs = {
12999
input: {buffer: inputBuffer},
130-
initialHiddenState92: {buffer: initialHiddenState92Buffer},
131-
initialHiddenState155: {buffer: initialHiddenState155Buffer},
100+
initialState92: {buffer: initialState92Buffer},
101+
initialState155: {buffer: initialState155Buffer},
132102
};
133-
return await this.compilation_.compute(inputs);
134-
}
135-
136-
dispose() {
137-
this.compilation_.dispose();
103+
return await this.compiledModel.compute(inputs);
138104
}
139105
}

nsnet2/weights/194.npy

-9.45 KB
Binary file not shown.

nsnet2/weights/194_0.npy

4.81 KB
Binary file not shown.

nsnet2/weights/194_1.npy

4.81 KB
Binary file not shown.

nsnet2/weights/214.npy

-9.45 KB
Binary file not shown.

nsnet2/weights/214_0.npy

4.81 KB
Binary file not shown.

nsnet2/weights/214_1.npy

4.81 KB
Binary file not shown.

0 commit comments

Comments
 (0)