@@ -7,133 +7,99 @@ function sizeOfShape(shape) {
7
7
} ) ;
8
8
}
9
9
10
+ async function buildConstantByNpy ( builder , url ) {
11
+ const dataTypeMap = new Map ( [
12
+ [ 'f2' , { type : 'float16' , array : Uint16Array } ] ,
13
+ [ 'f4' , { type : 'float32' , array : Float32Array } ] ,
14
+ [ 'f8' , { type : 'float64' , array : Float64Array } ] ,
15
+ [ 'i1' , { type : 'int8' , array : Int8Array } ] ,
16
+ [ 'i2' , { type : 'int16' , array : Int16Array } ] ,
17
+ [ 'i4' , { type : 'int32' , array : Int32Array } ] ,
18
+ [ 'i8' , { type : 'int64' , array : BigInt64Array } ] ,
19
+ [ 'u1' , { type : 'uint8' , array : Uint8Array } ] ,
20
+ [ 'u2' , { type : 'uint16' , array : Uint16Array } ] ,
21
+ [ 'u4' , { type : 'uint32' , array : Uint32Array } ] ,
22
+ [ 'u8' , { type : 'uint64' , array : BigUint64Array } ] ,
23
+ ] ) ;
24
+ const response = await fetch ( url ) ;
25
+ const buffer = await response . arrayBuffer ( ) ;
26
+ const npArray = new numpy . Array ( new Uint8Array ( buffer ) ) ;
27
+ if ( ! dataTypeMap . has ( npArray . dataType ) ) {
28
+ throw new Error ( `Data type ${ npArray . dataType } is not supported.` ) ;
29
+ }
30
+ const dimensions = npArray . shape ;
31
+ const type = dataTypeMap . get ( npArray . dataType ) . type ;
32
+ const TypedArrayConstructor = dataTypeMap . get ( npArray . dataType ) . array ;
33
+ const typedArray = new TypedArrayConstructor ( sizeOfShape ( dimensions ) ) ;
34
+ const dataView = new DataView ( npArray . data . buffer ) ;
35
+ const littleEndian = npArray . byteOrder === '<' ;
36
+ for ( let i = 0 ; i < sizeOfShape ( dimensions ) ; ++ i ) {
37
+ typedArray [ i ] = dataView [ `get` + type [ 0 ] . toUpperCase ( ) + type . substr ( 1 ) ] (
38
+ i * TypedArrayConstructor . BYTES_PER_ELEMENT , littleEndian ) ;
39
+ }
40
+ return builder . constant ( { type, dimensions} , typedArray ) ;
41
+ }
42
+
43
+ /* eslint max-len: ["error", { "code": 130 }] */
44
+
45
+ // Noise Suppression Net 2 (NSNet2) Baseline Model for Deep Noise Suppression Challenge (DNS) 2021.
10
46
export class NSNet2 {
11
47
constructor ( ) {
12
- this . baseUrl_ = './' ;
13
- this . model_ = null ;
14
- this . compilation_ = null ;
48
+ this . model = null ;
49
+ this . compiledModel = null ;
50
+ this . frameSize = 161 ;
15
51
this . hiddenSize = 400 ;
16
52
}
17
53
18
- async buildConstantByNpy ( fileName ) {
19
- const dataTypeMap = new Map ( [
20
- [ 'f2' , { type : 'float16' , array : Uint16Array } ] ,
21
- [ 'f4' , { type : 'float32' , array : Float32Array } ] ,
22
- [ 'f8' , { type : 'float64' , array : Float64Array } ] ,
23
- [ 'i1' , { type : 'int8' , array : Int8Array } ] ,
24
- [ 'i2' , { type : 'int16' , array : Int16Array } ] ,
25
- [ 'i4' , { type : 'int32' , array : Int32Array } ] ,
26
- [ 'i8' , { type : 'int64' , array : BigInt64Array } ] ,
27
- [ 'u1' , { type : 'uint8' , array : Uint8Array } ] ,
28
- [ 'u2' , { type : 'uint16' , array : Uint16Array } ] ,
29
- [ 'u4' , { type : 'uint32' , array : Uint32Array } ] ,
30
- [ 'u8' , { type : 'uint64' , array : BigUint64Array } ] ,
31
- ] ) ;
32
- const response = await fetch ( this . baseUrl_ + fileName ) ;
33
- const buffer = await response . arrayBuffer ( ) ;
34
- const npArray = new numpy . Array ( new Uint8Array ( buffer ) ) ;
35
- if ( ! dataTypeMap . has ( npArray . dataType ) ) {
36
- throw new Error ( `Data type ${ npArray . dataType } is not supported.` ) ;
37
- }
38
- const dimensions = npArray . shape ;
39
- const type = dataTypeMap . get ( npArray . dataType ) . type ;
40
- const TypedArrayConstructor = dataTypeMap . get ( npArray . dataType ) . array ;
41
- const typedArray = new TypedArrayConstructor ( sizeOfShape ( dimensions ) ) ;
42
- const dataView = new DataView ( npArray . data . buffer ) ;
43
- const littleEndian = npArray . byteOrder === '<' ;
44
- for ( let i = 0 ; i < sizeOfShape ( dimensions ) ; ++ i ) {
45
- typedArray [ i ] = dataView [ `get` + type [ 0 ] . toUpperCase ( ) + type . substr ( 1 ) ] (
46
- i * TypedArrayConstructor . BYTES_PER_ELEMENT , littleEndian ) ;
47
- }
48
- return this . builder . constant ( { type, dimensions} , typedArray ) ;
49
- }
50
-
51
- async load ( url , batchSize , frames ) {
52
- this . baseUrl_ = url ;
54
+ async load ( baseUrl , batchSize , frames ) {
53
55
const nn = navigator . ml . getNeuralNetworkContext ( ) ;
54
56
const builder = nn . createModelBuilder ( ) ;
55
- this . builder = builder ;
56
-
57
- // Create constants
58
- const weight172 = await this . buildConstantByNpy ( '172.npy' ) ;
59
- const biasFcIn0 = await this . buildConstantByNpy ( 'fc_in_0_bias.npy' ) ;
60
- const weight192 = await this . buildConstantByNpy ( '192.npy' ) ;
61
- const recurrentWeight193 = await this . buildConstantByNpy ( '193.npy' ) ;
62
- const data194 = await this . buildConstantByNpy ( '194.npy' ) ;
63
- const weight212 = await this . buildConstantByNpy ( '212.npy' ) ;
64
- const recurrentWeight213 = await this . buildConstantByNpy ( '213.npy' ) ;
65
- const data214 = await this . buildConstantByNpy ( '214.npy' ) ;
66
- const weight215 = await this . buildConstantByNpy ( '215.npy' ) ;
67
- const biasFcOut0 = await this . buildConstantByNpy ( 'fc_out_0_bias.npy' ) ;
68
- const weight216 = await this . buildConstantByNpy ( '216.npy' ) ;
69
- const biasFcOut2 = await this . buildConstantByNpy ( 'fc_out_2_bias.npy' ) ;
70
- const weight217 = await this . buildConstantByNpy ( '217.npy' ) ;
71
- const biasFcOut4 = await this . buildConstantByNpy ( 'fc_out_4_bias.npy' ) ;
72
-
73
- // Build up the network
74
- const hiddenSize = this . hiddenSize ;
75
- const inputShape = [ batchSize , frames , 161 ] ;
76
- const input = builder . input (
77
- 'input' , { type : 'float32' , dimensions : inputShape } ) ;
78
- const matmul18 = builder . matmul ( input , weight172 ) ;
79
- const add19 = builder . add ( matmul18 , biasFcIn0 ) ;
80
- const relu20 = builder . relu ( add19 ) ;
57
+ // Create constants by loading pre-trained data from .npy files.
58
+ const weight172 = await buildConstantByNpy ( builder , baseUrl + '172.npy' ) ;
59
+ const biasFcIn0 = await buildConstantByNpy ( builder , baseUrl + 'fc_in_0_bias.npy' ) ;
60
+ const weight192 = await buildConstantByNpy ( builder , baseUrl + '192.npy' ) ;
61
+ const recurrentWeight193 = await buildConstantByNpy ( builder , baseUrl + '193.npy' ) ;
62
+ const bias194 = await buildConstantByNpy ( builder , baseUrl + '194_0.npy' ) ;
63
+ const recurrentBias194 = await buildConstantByNpy ( builder , baseUrl + '194_1.npy' ) ;
64
+ const weight212 = await buildConstantByNpy ( builder , baseUrl + '212.npy' ) ;
65
+ const recurrentWeight213 = await buildConstantByNpy ( builder , baseUrl + '213.npy' ) ;
66
+ const bias214 = await buildConstantByNpy ( builder , baseUrl + '214_0.npy' ) ;
67
+ const recurrentBias214 = await buildConstantByNpy ( builder , baseUrl + '214_1.npy' ) ;
68
+ const weight215 = await buildConstantByNpy ( builder , baseUrl + '215.npy' ) ;
69
+ const biasFcOut0 = await buildConstantByNpy ( builder , baseUrl + 'fc_out_0_bias.npy' ) ;
70
+ const weight216 = await buildConstantByNpy ( builder , baseUrl + '216.npy' ) ;
71
+ const biasFcOut2 = await buildConstantByNpy ( builder , baseUrl + 'fc_out_2_bias.npy' ) ;
72
+ const weight217 = await buildConstantByNpy ( builder , baseUrl + '217.npy' ) ;
73
+ const biasFcOut4 = await buildConstantByNpy ( builder , baseUrl + 'fc_out_4_bias.npy' ) ;
74
+ // Build up the network.
75
+ const input = builder . input ( 'input' , { type : 'float32' , dimensions : [ batchSize , frames , this . frameSize ] } ) ;
76
+ const relu20 = builder . relu ( builder . add ( builder . matmul ( input , weight172 ) , biasFcIn0 ) ) ;
81
77
const transpose31 = builder . transpose ( relu20 , { permutation : [ 1 , 0 , 2 ] } ) ;
82
- const bias194 = builder . slice ( data194 , [ 0 ] , [ 3 * hiddenSize ] , { axes : [ 1 ] } ) ;
83
- const recurrentBias194 = builder . slice (
84
- data194 , [ 3 * hiddenSize ] , [ - 1 ] , { axes : [ 1 ] } ) ;
85
- const initialHiddenState92 = builder . input (
86
- 'initialHiddenState92' ,
87
- { type : 'float32' , dimensions : [ 1 , batchSize , hiddenSize ] } ) ;
88
- const [ gru94 , gru93 ] = builder . gru (
89
- transpose31 , weight192 , recurrentWeight193 , frames , hiddenSize ,
90
- {
91
- bias : bias194 , recurrentBias : recurrentBias194 ,
92
- initialHiddenState : initialHiddenState92 , returnSequence : true ,
93
- } ) ;
78
+ const initialState92 = builder . input ( 'initialState92' , { type : 'float32' , dimensions : [ 1 , batchSize , this . hiddenSize ] } ) ;
79
+ const [ gru94 , gru93 ] = builder . gru ( transpose31 , weight192 , recurrentWeight193 , frames , this . hiddenSize ,
80
+ { bias : bias194 , recurrentBias : recurrentBias194 , initialHiddenState : initialState92 , returnSequence : true } ) ;
94
81
const squeeze95 = builder . squeeze ( gru93 , { axes : [ 1 ] } ) ;
95
- const bias214 = builder . slice ( data214 , [ 0 ] , [ 3 * hiddenSize ] , { axes : [ 1 ] } ) ;
96
- const recurrentBias214 = builder . slice (
97
- data214 , [ 3 * hiddenSize ] , [ - 1 ] , { axes : [ 1 ] } ) ;
98
- const initialHiddenState155 = builder . input (
99
- 'initialHiddenState155' ,
100
- { type : 'float32' , dimensions : [ 1 , batchSize , hiddenSize ] } ) ;
101
- const [ gru157 , gru156 ] = builder . gru (
102
- squeeze95 , weight212 , recurrentWeight213 , frames , hiddenSize ,
103
- {
104
- bias : bias214 , recurrentBias : recurrentBias214 ,
105
- initialHiddenState : initialHiddenState155 , returnSequence : true ,
106
- } ) ;
82
+ const initialState155 = builder . input ( 'initialState155' , { type : 'float32' , dimensions : [ 1 , batchSize , this . hiddenSize ] } ) ;
83
+ const [ gru157 , gru156 ] = builder . gru ( squeeze95 , weight212 , recurrentWeight213 , frames , this . hiddenSize ,
84
+ { bias : bias214 , recurrentBias : recurrentBias214 , initialHiddenState : initialState155 , returnSequence : true } ) ;
107
85
const squeeze158 = builder . squeeze ( gru156 , { axes : [ 1 ] } ) ;
108
- const transpose159 = builder . transpose (
109
- squeeze158 , { permutation : [ 1 , 0 , 2 ] } ) ;
110
- const matmul161 = builder . matmul ( transpose159 , weight215 ) ;
111
- const add162 = builder . add ( matmul161 , biasFcOut0 ) ;
112
- const relu163 = builder . relu ( add162 ) ;
113
- const matmul165 = builder . matmul ( relu163 , weight216 ) ;
114
- const add166 = builder . add ( matmul165 , biasFcOut2 ) ;
115
- const relu167 = builder . relu ( add166 ) ;
116
- const matmul169 = builder . matmul ( relu167 , weight217 ) ;
117
- const add170 = builder . add ( matmul169 , biasFcOut4 ) ;
118
- const output = builder . sigmoid ( add170 ) ;
119
- this . model_ = builder . createModel ( { output, gru94, gru157} ) ;
86
+ const transpose159 = builder . transpose ( squeeze158 , { permutation : [ 1 , 0 , 2 ] } ) ;
87
+ const relu163 = builder . relu ( builder . add ( builder . matmul ( transpose159 , weight215 ) , biasFcOut0 ) ) ;
88
+ const relu167 = builder . relu ( builder . add ( builder . matmul ( relu163 , weight216 ) , biasFcOut2 ) ) ;
89
+ const output = builder . sigmoid ( builder . add ( builder . matmul ( relu167 , weight217 ) , biasFcOut4 ) ) ;
90
+ this . model = builder . createModel ( { output, gru94, gru157} ) ;
120
91
}
121
92
122
93
async compile ( options ) {
123
- this . compilation_ = await this . model_ . compile ( options ) ;
94
+ this . compiledModel = await this . model . compile ( options ) ;
124
95
}
125
96
126
- async compute (
127
- inputBuffer , initialHiddenState92Buffer , initialHiddenState155Buffer ) {
97
+ async compute ( inputBuffer , initialState92Buffer , initialState155Buffer ) {
128
98
const inputs = {
129
99
input : { buffer : inputBuffer } ,
130
- initialHiddenState92 : { buffer : initialHiddenState92Buffer } ,
131
- initialHiddenState155 : { buffer : initialHiddenState155Buffer } ,
100
+ initialState92 : { buffer : initialState92Buffer } ,
101
+ initialState155 : { buffer : initialState155Buffer } ,
132
102
} ;
133
- return await this . compilation_ . compute ( inputs ) ;
134
- }
135
-
136
- dispose ( ) {
137
- this . compilation_ . dispose ( ) ;
103
+ return await this . compiledModel . compute ( inputs ) ;
138
104
}
139
105
}
0 commit comments