11import Layer from '../../Layer'
22import Tensor from '../../Tensor'
3+ import { webgl2 } from '../../WebGL2'
34import ops from 'ndarray-ops'
5+ import * as layers from '../'
46
57/**
68 * TimeDistributed wrapper layer class
@@ -17,8 +19,24 @@ export default class TimeDistributed extends Layer {
1719
1820 const { layer } = attrs
1921
20- if ( ! layer ) this . throwError ( 'wrapped layer is undefined.' )
21- this . layer = layer
22+ if ( ! layer ) {
23+ this . throwError ( 'wrapped layer is undefined.' )
24+ }
25+
26+ const wrappedLayerAttrs = Object . assign ( { } , layer . config , { gpu : attrs . gpu } )
27+ this . wrappedLayer = new layers [ layer . class_name ] ( wrappedLayerAttrs )
28+
29+ // prevent GPU -> CPU data transfer by specifying non-empty outbound nodes array on internal layer
30+ this . wrappedLayer . outbound = [ null ]
31+
32+ // GPU setup
33+ if ( this . gpu ) {
34+ this . copyTextureProgram = webgl2 . compileProgram ( require ( '../../copyTexture.glsl' ) )
35+ this . mapInputProgram = webgl2 . compileProgram ( require ( '../../mapInput.glsl' ) )
36+ this . selectSliceProgram = webgl2 . compileProgram ( require ( './TimeDistributed.selectSlice.glsl' ) )
37+ this . copySliceOutputProgram = webgl2 . compileProgram ( require ( './TimeDistributed.copySliceOutput.glsl' ) )
38+ this . mapSliceOutputProgram = webgl2 . compileProgram ( require ( './TimeDistributed.mapSliceOutput.glsl' ) )
39+ }
2240 }
2341
2442 /**
@@ -28,32 +46,282 @@ export default class TimeDistributed extends Layer {
2846 * @param {Tensor[] } weightsArr - array of weights which are instances of Tensor
2947 */
3048 setWeights ( weightsArr ) {
31- this . layer . setWeights ( weightsArr )
49+ this . wrappedLayer . setWeights ( weightsArr )
3250 }
3351
3452 /**
35- * Method for layer computational logic
53+ * Layer computational logic
3654 *
3755 * @param {Tensor } x
3856 * @returns {Tensor }
3957 */
4058 call ( x ) {
41- const xStepShape = [ ...x . tensor . shape . slice ( 1 ) ]
42- let xStep = new Tensor ( [ ] , xStepShape )
43- ops . assign ( xStep . tensor , x . tensor . pick ( 0 , ...xStepShape . map ( s => null ) ) )
44- let yStep = this . layer . call ( xStep )
45- const yStepShape = yStep . tensor . shape . slice ( )
46- let y = new Tensor ( [ ] , [ x . tensor . shape [ 0 ] , ...yStepShape ] )
47- ops . assign ( y . tensor . pick ( 0 , ...yStepShape . map ( s => null ) ) , yStep . tensor )
48-
49- for ( let i = 1 , steps = x . tensor . shape [ 0 ] ; i < steps ; i ++ ) {
50- let xStep = new Tensor ( [ ] , xStepShape )
51- ops . assign ( xStep . tensor , x . tensor . pick ( i , ...xStepShape . map ( s => null ) ) )
52- yStep = this . layer . call ( xStep )
53- ops . assign ( y . tensor . pick ( i , ...yStepShape . map ( s => null ) ) , yStep . tensor )
54- }
55-
56- x . tensor = y . tensor
57- return x
59+ if ( this . gpu ) {
60+ this . _callGPU ( x )
61+ } else {
62+ this . _callCPU ( x )
63+ }
64+ return this . output
65+ }
66+
67+ /**
68+ * CPU call
69+ *
70+ * @param {Tensor } x
71+ */
72+ _callCPU ( x ) {
73+ const stepShape = [ ...x . tensor . shape . slice ( 1 ) ]
74+ const step = new Tensor ( [ ] , stepShape )
75+ ops . assign ( step . tensor , x . tensor . pick ( 0 , ...Array ( stepShape . length ) . fill ( null ) ) )
76+ let stepOutput = this . wrappedLayer . call ( step )
77+ const stepOutputShape = stepOutput . tensor . shape . slice ( )
78+ this . output = new Tensor ( [ ] , [ x . tensor . shape [ 0 ] , ...stepOutputShape ] )
79+ ops . assign ( this . output . tensor . pick ( 0 , ...Array ( stepOutputShape . length ) . fill ( null ) ) , stepOutput . tensor )
80+ for ( let i = 1 , timesteps = x . tensor . shape [ 0 ] ; i < timesteps ; i ++ ) {
81+ ops . assign ( step . tensor , x . tensor . pick ( i , ...Array ( stepShape . length ) . fill ( null ) ) )
82+ stepOutput = this . wrappedLayer . call ( step )
83+ ops . assign ( this . output . tensor . pick ( i , ...Array ( stepOutputShape . length ) . fill ( null ) ) , stepOutput . tensor )
84+ }
85+ }
86+
87+ /**
88+ * Creates row/col index mappings to map input texture to time-distributed slices
89+ *
90+ * @param {Object } indicesForReshaped
91+ */
92+ _createIndexMap ( indicesForReshaped ) {
93+ if ( this . rowIndexMaps && this . colIndexMaps ) {
94+ return
95+ }
96+
97+ const indicesRow = new Tensor ( indicesForReshaped . row . data , indicesForReshaped . row . shape , { type : Int32Array } )
98+ const indicesCol = new Tensor ( indicesForReshaped . col . data , indicesForReshaped . col . shape , { type : Int32Array } )
99+
100+ this . rowIndexMaps = [ ]
101+ this . colIndexMaps = [ ]
102+
103+ const timesteps = this . inputShape [ 0 ]
104+ const sliceShape = this . inputShape . slice ( 1 )
105+ for ( let t = 0 ; t < timesteps ; t ++ ) {
106+ const sliceIndicesRow = new Tensor ( [ ] , sliceShape , { type : Int32Array } )
107+ const sliceIndicesCol = new Tensor ( [ ] , sliceShape , { type : Int32Array } )
108+ ops . assign ( sliceIndicesRow . tensor , indicesRow . tensor . pick ( t , ...Array ( sliceShape . length ) . fill ( null ) ) )
109+ ops . assign ( sliceIndicesCol . tensor , indicesCol . tensor . pick ( t , ...Array ( sliceShape . length ) . fill ( null ) ) )
110+ sliceIndicesRow . reshapeTo2DSquare ( )
111+ sliceIndicesCol . reshapeTo2DSquare ( )
112+ sliceIndicesRow . createGLTexture ( '2d' , 'int' )
113+ sliceIndicesCol . createGLTexture ( '2d' , 'int' )
114+ this . rowIndexMaps . push ( sliceIndicesRow )
115+ this . colIndexMaps . push ( sliceIndicesCol )
116+ }
117+ }
118+
119+ /**
120+ * Creates row/col index mappings to map time-distributed slices to output texture
121+ *
122+ * @param {Object } indicesForReshaped
123+ */
124+ _createOutputIndexMap ( indicesForReshaped ) {
125+ if ( this . outputRowIndexMaps && this . outputColIndexMaps ) {
126+ return
127+ }
128+
129+ const outputSliceIndicesRow = new Tensor ( indicesForReshaped . row . data , indicesForReshaped . row . shape , {
130+ type : Int32Array
131+ } )
132+ const outputSliceIndicesCol = new Tensor ( indicesForReshaped . col . data , indicesForReshaped . col . shape , {
133+ type : Int32Array
134+ } )
135+
136+ this . outputRowIndexMaps = [ ]
137+ this . outputColIndexMaps = [ ]
138+
139+ const timesteps = this . outputShape [ 0 ]
140+ const sliceShape = this . outputShape . slice ( 1 )
141+ for ( let t = 0 ; t < timesteps ; t ++ ) {
142+ const outputIndicesRow = new Tensor ( [ ] , this . outputShape , { type : Int32Array } )
143+ const outputIndicesCol = new Tensor ( [ ] , this . outputShape , { type : Int32Array } )
144+ ops . assigns ( outputIndicesRow . tensor , - 1 )
145+ ops . assigns ( outputIndicesCol . tensor , - 1 )
146+ ops . assign ( outputIndicesRow . tensor . pick ( t , ...Array ( sliceShape . length ) . fill ( null ) ) , outputSliceIndicesRow . tensor )
147+ ops . assign ( outputIndicesCol . tensor . pick ( t , ...Array ( sliceShape . length ) . fill ( null ) ) , outputSliceIndicesCol . tensor )
148+ outputIndicesRow . reshapeTo2DSquare ( )
149+ outputIndicesCol . reshapeTo2DSquare ( )
150+ outputIndicesRow . createGLTexture ( '2d' , 'int' )
151+ outputIndicesCol . createGLTexture ( '2d' , 'int' )
152+ this . outputRowIndexMaps . push ( outputIndicesRow )
153+ this . outputColIndexMaps . push ( outputIndicesCol )
154+ }
155+ }
156+
157+ /**
158+ * GPU call
159+ *
160+ * @param {Tensor } x
161+ */
162+ _callGPU ( x ) {
163+ if ( x . is2DReshaped ) {
164+ this . inputShape = x . originalShape
165+ } else {
166+ this . inputShape = x . tensor . shape
167+ }
168+
169+ if ( ! x . glTexture ) {
170+ if ( x . tensor . shape . length <= 2 ) {
171+ x . createGLTexture ( )
172+ } else if ( x . tensor . shape . length > 2 && ! x . is2DReshaped ) {
173+ x . reshapeTo2DSquare ( )
174+ x . createGLTexture ( )
175+ }
176+ }
177+
178+ if ( this . inputShape . length > 2 ) {
179+ this . _createIndexMap ( x . indicesForReshaped )
180+ }
181+
182+ const timesteps = this . inputShape [ 0 ]
183+ const sliceShape = this . inputShape . slice ( 1 )
184+
185+ if ( ! this . slice ) {
186+ this . slice = new Tensor ( [ ] , sliceShape )
187+ if ( sliceShape . length <= 2 ) {
188+ this . slice . createGLTexture ( )
189+ } else {
190+ this . slice . reshapeTo2DSquare ( )
191+ this . slice . createGLTexture ( )
192+ }
193+ }
194+
195+ if ( this . inputShape . length <= 2 ) {
196+ webgl2 . runProgram ( {
197+ program : this . selectSliceProgram ,
198+ output : this . slice ,
199+ inputs : [ { texture : x . glTexture , type : '2d' , name : 'x' } ] ,
200+ uniforms : [ { value : 0 , type : 'int' , name : 't' } ]
201+ } )
202+ } else {
203+ webgl2 . runProgram ( {
204+ program : this . mapInputProgram ,
205+ output : this . slice ,
206+ inputs : [
207+ { texture : x . glTexture , type : '2d' , name : 'x' } ,
208+ { texture : this . rowIndexMaps [ 0 ] . glTexture , type : '2d' , name : 'rowIndexMap' } ,
209+ { texture : this . colIndexMaps [ 0 ] . glTexture , type : '2d' , name : 'colIndexMap' }
210+ ]
211+ } )
212+ }
213+
214+ this . wrappedLayer . _callGPU ( this . slice )
215+ this . sliceOutput = this . wrappedLayer . output
216+
217+ if ( ! this . output ) {
218+ if ( this . inputShape . length <= 2 ) {
219+ this . outputShape = [ timesteps , this . sliceOutput . glTextureShape [ 1 ] ]
220+ this . output = new Tensor ( [ ] , this . outputShape )
221+ this . outputCopy = new Tensor ( [ ] , this . outputShape )
222+ this . output . createGLTexture ( )
223+ this . outputCopy . createGLTexture ( )
224+ } else {
225+ this . outputShape = [ timesteps , ...this . sliceOutput . originalShape ]
226+ this . output = new Tensor ( [ ] , this . outputShape )
227+ this . outputCopy = new Tensor ( [ ] , this . outputShape )
228+ this . output . reshapeTo2DSquare ( )
229+ this . outputCopy . reshapeTo2DSquare ( )
230+ this . output . createGLTexture ( )
231+ this . outputCopy . createGLTexture ( )
232+
233+ this . _createOutputIndexMap ( this . sliceOutput . indicesForReshaped )
234+ }
235+ }
236+
237+ webgl2 . runProgram ( {
238+ program : this . copyTextureProgram ,
239+ output : this . outputCopy ,
240+ inputs : [ { texture : this . output . glTexture , type : '2d' , name : 'source' } ]
241+ } )
242+
243+ if ( this . inputShape . length <= 2 ) {
244+ webgl2 . runProgram ( {
245+ program : this . copySliceOutputProgram ,
246+ output : this . output ,
247+ inputs : [
248+ { texture : this . outputCopy . glTexture , type : '2d' , name : 'outputCopy' } ,
249+ { texture : this . sliceOutput . glTexture , type : '2d' , name : 'sliceOutput' }
250+ ] ,
251+ uniforms : [ { value : 0 , type : 'int' , name : 't' } , { value : timesteps , type : 'int' , name : 'timesteps' } ]
252+ } )
253+ } else {
254+ webgl2 . runProgram ( {
255+ program : this . mapSliceOutputProgram ,
256+ output : this . output ,
257+ inputs : [
258+ { texture : this . outputCopy . glTexture , type : '2d' , name : 'outputCopy' } ,
259+ { texture : this . sliceOutput . glTexture , type : '2d' , name : 'sliceOutput' } ,
260+ { texture : this . outputRowIndexMaps [ 0 ] . glTexture , type : '2d' , name : 'rowIndexMap' } ,
261+ { texture : this . outputColIndexMaps [ 0 ] . glTexture , type : '2d' , name : 'colIndexMap' }
262+ ]
263+ } )
264+ }
265+
266+ for ( let i = 1 ; i < timesteps ; i ++ ) {
267+ if ( this . inputShape . length <= 2 ) {
268+ webgl2 . runProgram ( {
269+ program : this . selectSliceProgram ,
270+ output : this . slice ,
271+ inputs : [ { texture : x . glTexture , type : '2d' , name : 'x' } ] ,
272+ uniforms : [ { value : i , type : 'int' , name : 't' } ]
273+ } )
274+ } else {
275+ webgl2 . runProgram ( {
276+ program : this . mapInputProgram ,
277+ output : this . slice ,
278+ inputs : [
279+ { texture : x . glTexture , type : '2d' , name : 'x' } ,
280+ { texture : this . rowIndexMaps [ i ] . glTexture , type : '2d' , name : 'rowIndexMap' } ,
281+ { texture : this . colIndexMaps [ i ] . glTexture , type : '2d' , name : 'colIndexMap' }
282+ ]
283+ } )
284+ }
285+
286+ this . wrappedLayer . _callGPU ( this . slice )
287+ this . sliceOutput = this . wrappedLayer . output
288+
289+ webgl2 . runProgram ( {
290+ program : this . copyTextureProgram ,
291+ output : this . outputCopy ,
292+ inputs : [ { texture : this . output . glTexture , type : '2d' , name : 'source' } ]
293+ } )
294+
295+ if ( this . inputShape . length <= 2 ) {
296+ webgl2 . runProgram ( {
297+ program : this . copySliceOutputProgram ,
298+ output : this . output ,
299+ inputs : [
300+ { texture : this . outputCopy . glTexture , type : '2d' , name : 'outputCopy' } ,
301+ { texture : this . sliceOutput . glTexture , type : '2d' , name : 'sliceOutput' }
302+ ] ,
303+ uniforms : [ { value : i , type : 'int' , name : 't' } , { value : timesteps , type : 'int' , name : 'timesteps' } ]
304+ } )
305+ } else {
306+ webgl2 . runProgram ( {
307+ program : this . mapSliceOutputProgram ,
308+ output : this . output ,
309+ inputs : [
310+ { texture : this . outputCopy . glTexture , type : '2d' , name : 'outputCopy' } ,
311+ { texture : this . sliceOutput . glTexture , type : '2d' , name : 'sliceOutput' } ,
312+ { texture : this . outputRowIndexMaps [ i ] . glTexture , type : '2d' , name : 'rowIndexMap' } ,
313+ { texture : this . outputColIndexMaps [ i ] . glTexture , type : '2d' , name : 'colIndexMap' }
314+ ]
315+ } )
316+ }
317+ }
318+
319+ // GPU -> CPU data transfer
320+ if ( this . outbound . length === 0 ) {
321+ this . output . transferFromGLTexture ( )
322+ if ( this . output . is2DReshaped ) {
323+ this . output . reshapeFrom2DSquare ( )
324+ }
325+ }
58326 }
59327}
0 commit comments