Skip to content

Commit 6413331

Browse files
committed
add gpu call method for TimeDistributed wrapper layer
1 parent 86d071c commit 6413331

File tree

9 files changed

+406
-111
lines changed

9 files changed

+406
-111
lines changed

notebooks/layers/wrappers/TimeDistributed.ipynb

Lines changed: 20 additions & 70 deletions
Large diffs are not rendered by default.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
"postcss-loader": "^2.0.8",
7272
"raw-loader": "^0.5.1",
7373
"vue": "^2.5.3",
74-
"vue-loader": "^13.4.0",
74+
"vue-loader": "^13.5.0",
7575
"vue-mdl": "^1.1.1",
7676
"vue-router": "^3.0.1",
7777
"vue-template-compiler": "^2.5.3",

src/layers/wrappers/Bidirectional.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ export default class Bidirectional extends Layer {
3636
this.forwardLayer = new recurrentLayers[layer.class_name](forwardLayerAttrs)
3737
this.backwardLayer = new recurrentLayers[layer.class_name](backwardLayerAttrs)
3838

39+
// prevent GPU -> CPU data transfer by specifying non-empty outbound nodes array on internal layers
40+
this.forwardLayer.outbound = [null]
41+
this.backwardLayer.outbound = [null]
42+
3943
this.mergeMode = merge_mode
4044
this.returnSequences = layer.config.return_sequences
4145

@@ -148,6 +152,7 @@ export default class Bidirectional extends Layer {
148152
inputs: [{ texture: x.glTexture, type: '2d', name: 'source' }]
149153
})
150154

155+
// run internal component layers
151156
this.forwardLayer._callGPU(x)
152157
this.backwardLayer._callGPU(this.inputCopy)
153158
const forwardOutput = this.forwardLayer.output
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#version 300 es
2+
precision highp float;
3+
4+
in vec2 outTex;
5+
uniform sampler2D outputCopy;
6+
uniform sampler2D sliceOutput;
7+
uniform int t;
8+
uniform int timesteps;
9+
out vec4 outColor;
10+
11+
void main() {
12+
ivec2 size = textureSize(sliceOutput, 0);
13+
int out_x = int(float(size[0]) * outTex.x);
14+
int out_y = int(float(timesteps) * outTex.y);
15+
16+
if (t == out_y) {
17+
outColor = vec4(texelFetch(sliceOutput, ivec2(out_x, 0), 0).r);
18+
} else {
19+
outColor = texelFetch(outputCopy, ivec2(out_x, out_y), 0);
20+
}
21+
}
Lines changed: 289 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import Layer from '../../Layer'
22
import Tensor from '../../Tensor'
3+
import { webgl2 } from '../../WebGL2'
34
import ops from 'ndarray-ops'
5+
import * as layers from '../'
46

57
/**
68
* TimeDistributed wrapper layer class
@@ -17,8 +19,24 @@ export default class TimeDistributed extends Layer {
1719

1820
const { layer } = attrs
1921

20-
if (!layer) this.throwError('wrapped layer is undefined.')
21-
this.layer = layer
22+
if (!layer) {
23+
this.throwError('wrapped layer is undefined.')
24+
}
25+
26+
const wrappedLayerAttrs = Object.assign({}, layer.config, { gpu: attrs.gpu })
27+
this.wrappedLayer = new layers[layer.class_name](wrappedLayerAttrs)
28+
29+
// prevent GPU -> CPU data transfer by specifying non-empty outbound nodes array on internal layer
30+
this.wrappedLayer.outbound = [null]
31+
32+
// GPU setup
33+
if (this.gpu) {
34+
this.copyTextureProgram = webgl2.compileProgram(require('../../copyTexture.glsl'))
35+
this.mapInputProgram = webgl2.compileProgram(require('../../mapInput.glsl'))
36+
this.selectSliceProgram = webgl2.compileProgram(require('./TimeDistributed.selectSlice.glsl'))
37+
this.copySliceOutputProgram = webgl2.compileProgram(require('./TimeDistributed.copySliceOutput.glsl'))
38+
this.mapSliceOutputProgram = webgl2.compileProgram(require('./TimeDistributed.mapSliceOutput.glsl'))
39+
}
2240
}
2341

2442
/**
@@ -28,32 +46,282 @@ export default class TimeDistributed extends Layer {
2846
* @param {Tensor[]} weightsArr - array of weights which are instances of Tensor
2947
*/
3048
setWeights(weightsArr) {
31-
this.layer.setWeights(weightsArr)
49+
this.wrappedLayer.setWeights(weightsArr)
3250
}
3351

3452
/**
35-
* Method for layer computational logic
53+
* Layer computational logic
3654
*
3755
* @param {Tensor} x
3856
* @returns {Tensor}
3957
*/
4058
call(x) {
41-
const xStepShape = [...x.tensor.shape.slice(1)]
42-
let xStep = new Tensor([], xStepShape)
43-
ops.assign(xStep.tensor, x.tensor.pick(0, ...xStepShape.map(s => null)))
44-
let yStep = this.layer.call(xStep)
45-
const yStepShape = yStep.tensor.shape.slice()
46-
let y = new Tensor([], [x.tensor.shape[0], ...yStepShape])
47-
ops.assign(y.tensor.pick(0, ...yStepShape.map(s => null)), yStep.tensor)
48-
49-
for (let i = 1, steps = x.tensor.shape[0]; i < steps; i++) {
50-
let xStep = new Tensor([], xStepShape)
51-
ops.assign(xStep.tensor, x.tensor.pick(i, ...xStepShape.map(s => null)))
52-
yStep = this.layer.call(xStep)
53-
ops.assign(y.tensor.pick(i, ...yStepShape.map(s => null)), yStep.tensor)
54-
}
55-
56-
x.tensor = y.tensor
57-
return x
59+
if (this.gpu) {
60+
this._callGPU(x)
61+
} else {
62+
this._callCPU(x)
63+
}
64+
return this.output
65+
}
66+
67+
/**
68+
* CPU call
69+
*
70+
* @param {Tensor} x
71+
*/
72+
_callCPU(x) {
73+
const stepShape = [...x.tensor.shape.slice(1)]
74+
const step = new Tensor([], stepShape)
75+
ops.assign(step.tensor, x.tensor.pick(0, ...Array(stepShape.length).fill(null)))
76+
let stepOutput = this.wrappedLayer.call(step)
77+
const stepOutputShape = stepOutput.tensor.shape.slice()
78+
this.output = new Tensor([], [x.tensor.shape[0], ...stepOutputShape])
79+
ops.assign(this.output.tensor.pick(0, ...Array(stepOutputShape.length).fill(null)), stepOutput.tensor)
80+
for (let i = 1, timesteps = x.tensor.shape[0]; i < timesteps; i++) {
81+
ops.assign(step.tensor, x.tensor.pick(i, ...Array(stepShape.length).fill(null)))
82+
stepOutput = this.wrappedLayer.call(step)
83+
ops.assign(this.output.tensor.pick(i, ...Array(stepOutputShape.length).fill(null)), stepOutput.tensor)
84+
}
85+
}
86+
87+
/**
88+
* Creates row/col index mappings to map input texture to time-distributed slices
89+
*
90+
* @param {Object} indicesForReshaped
91+
*/
92+
_createIndexMap(indicesForReshaped) {
93+
if (this.rowIndexMaps && this.colIndexMaps) {
94+
return
95+
}
96+
97+
const indicesRow = new Tensor(indicesForReshaped.row.data, indicesForReshaped.row.shape, { type: Int32Array })
98+
const indicesCol = new Tensor(indicesForReshaped.col.data, indicesForReshaped.col.shape, { type: Int32Array })
99+
100+
this.rowIndexMaps = []
101+
this.colIndexMaps = []
102+
103+
const timesteps = this.inputShape[0]
104+
const sliceShape = this.inputShape.slice(1)
105+
for (let t = 0; t < timesteps; t++) {
106+
const sliceIndicesRow = new Tensor([], sliceShape, { type: Int32Array })
107+
const sliceIndicesCol = new Tensor([], sliceShape, { type: Int32Array })
108+
ops.assign(sliceIndicesRow.tensor, indicesRow.tensor.pick(t, ...Array(sliceShape.length).fill(null)))
109+
ops.assign(sliceIndicesCol.tensor, indicesCol.tensor.pick(t, ...Array(sliceShape.length).fill(null)))
110+
sliceIndicesRow.reshapeTo2DSquare()
111+
sliceIndicesCol.reshapeTo2DSquare()
112+
sliceIndicesRow.createGLTexture('2d', 'int')
113+
sliceIndicesCol.createGLTexture('2d', 'int')
114+
this.rowIndexMaps.push(sliceIndicesRow)
115+
this.colIndexMaps.push(sliceIndicesCol)
116+
}
117+
}
118+
119+
/**
120+
* Creates row/col index mappings to map time-distributed slices to output texture
121+
*
122+
* @param {Object} indicesForReshaped
123+
*/
124+
_createOutputIndexMap(indicesForReshaped) {
125+
if (this.outputRowIndexMaps && this.outputColIndexMaps) {
126+
return
127+
}
128+
129+
const outputSliceIndicesRow = new Tensor(indicesForReshaped.row.data, indicesForReshaped.row.shape, {
130+
type: Int32Array
131+
})
132+
const outputSliceIndicesCol = new Tensor(indicesForReshaped.col.data, indicesForReshaped.col.shape, {
133+
type: Int32Array
134+
})
135+
136+
this.outputRowIndexMaps = []
137+
this.outputColIndexMaps = []
138+
139+
const timesteps = this.outputShape[0]
140+
const sliceShape = this.outputShape.slice(1)
141+
for (let t = 0; t < timesteps; t++) {
142+
const outputIndicesRow = new Tensor([], this.outputShape, { type: Int32Array })
143+
const outputIndicesCol = new Tensor([], this.outputShape, { type: Int32Array })
144+
ops.assigns(outputIndicesRow.tensor, -1)
145+
ops.assigns(outputIndicesCol.tensor, -1)
146+
ops.assign(outputIndicesRow.tensor.pick(t, ...Array(sliceShape.length).fill(null)), outputSliceIndicesRow.tensor)
147+
ops.assign(outputIndicesCol.tensor.pick(t, ...Array(sliceShape.length).fill(null)), outputSliceIndicesCol.tensor)
148+
outputIndicesRow.reshapeTo2DSquare()
149+
outputIndicesCol.reshapeTo2DSquare()
150+
outputIndicesRow.createGLTexture('2d', 'int')
151+
outputIndicesCol.createGLTexture('2d', 'int')
152+
this.outputRowIndexMaps.push(outputIndicesRow)
153+
this.outputColIndexMaps.push(outputIndicesCol)
154+
}
155+
}
156+
157+
/**
158+
* GPU call
159+
*
160+
* @param {Tensor} x
161+
*/
162+
_callGPU(x) {
163+
if (x.is2DReshaped) {
164+
this.inputShape = x.originalShape
165+
} else {
166+
this.inputShape = x.tensor.shape
167+
}
168+
169+
if (!x.glTexture) {
170+
if (x.tensor.shape.length <= 2) {
171+
x.createGLTexture()
172+
} else if (x.tensor.shape.length > 2 && !x.is2DReshaped) {
173+
x.reshapeTo2DSquare()
174+
x.createGLTexture()
175+
}
176+
}
177+
178+
if (this.inputShape.length > 2) {
179+
this._createIndexMap(x.indicesForReshaped)
180+
}
181+
182+
const timesteps = this.inputShape[0]
183+
const sliceShape = this.inputShape.slice(1)
184+
185+
if (!this.slice) {
186+
this.slice = new Tensor([], sliceShape)
187+
if (sliceShape.length <= 2) {
188+
this.slice.createGLTexture()
189+
} else {
190+
this.slice.reshapeTo2DSquare()
191+
this.slice.createGLTexture()
192+
}
193+
}
194+
195+
if (this.inputShape.length <= 2) {
196+
webgl2.runProgram({
197+
program: this.selectSliceProgram,
198+
output: this.slice,
199+
inputs: [{ texture: x.glTexture, type: '2d', name: 'x' }],
200+
uniforms: [{ value: 0, type: 'int', name: 't' }]
201+
})
202+
} else {
203+
webgl2.runProgram({
204+
program: this.mapInputProgram,
205+
output: this.slice,
206+
inputs: [
207+
{ texture: x.glTexture, type: '2d', name: 'x' },
208+
{ texture: this.rowIndexMaps[0].glTexture, type: '2d', name: 'rowIndexMap' },
209+
{ texture: this.colIndexMaps[0].glTexture, type: '2d', name: 'colIndexMap' }
210+
]
211+
})
212+
}
213+
214+
this.wrappedLayer._callGPU(this.slice)
215+
this.sliceOutput = this.wrappedLayer.output
216+
217+
if (!this.output) {
218+
if (this.inputShape.length <= 2) {
219+
this.outputShape = [timesteps, this.sliceOutput.glTextureShape[1]]
220+
this.output = new Tensor([], this.outputShape)
221+
this.outputCopy = new Tensor([], this.outputShape)
222+
this.output.createGLTexture()
223+
this.outputCopy.createGLTexture()
224+
} else {
225+
this.outputShape = [timesteps, ...this.sliceOutput.originalShape]
226+
this.output = new Tensor([], this.outputShape)
227+
this.outputCopy = new Tensor([], this.outputShape)
228+
this.output.reshapeTo2DSquare()
229+
this.outputCopy.reshapeTo2DSquare()
230+
this.output.createGLTexture()
231+
this.outputCopy.createGLTexture()
232+
233+
this._createOutputIndexMap(this.sliceOutput.indicesForReshaped)
234+
}
235+
}
236+
237+
webgl2.runProgram({
238+
program: this.copyTextureProgram,
239+
output: this.outputCopy,
240+
inputs: [{ texture: this.output.glTexture, type: '2d', name: 'source' }]
241+
})
242+
243+
if (this.inputShape.length <= 2) {
244+
webgl2.runProgram({
245+
program: this.copySliceOutputProgram,
246+
output: this.output,
247+
inputs: [
248+
{ texture: this.outputCopy.glTexture, type: '2d', name: 'outputCopy' },
249+
{ texture: this.sliceOutput.glTexture, type: '2d', name: 'sliceOutput' }
250+
],
251+
uniforms: [{ value: 0, type: 'int', name: 't' }, { value: timesteps, type: 'int', name: 'timesteps' }]
252+
})
253+
} else {
254+
webgl2.runProgram({
255+
program: this.mapSliceOutputProgram,
256+
output: this.output,
257+
inputs: [
258+
{ texture: this.outputCopy.glTexture, type: '2d', name: 'outputCopy' },
259+
{ texture: this.sliceOutput.glTexture, type: '2d', name: 'sliceOutput' },
260+
{ texture: this.outputRowIndexMaps[0].glTexture, type: '2d', name: 'rowIndexMap' },
261+
{ texture: this.outputColIndexMaps[0].glTexture, type: '2d', name: 'colIndexMap' }
262+
]
263+
})
264+
}
265+
266+
for (let i = 1; i < timesteps; i++) {
267+
if (this.inputShape.length <= 2) {
268+
webgl2.runProgram({
269+
program: this.selectSliceProgram,
270+
output: this.slice,
271+
inputs: [{ texture: x.glTexture, type: '2d', name: 'x' }],
272+
uniforms: [{ value: i, type: 'int', name: 't' }]
273+
})
274+
} else {
275+
webgl2.runProgram({
276+
program: this.mapInputProgram,
277+
output: this.slice,
278+
inputs: [
279+
{ texture: x.glTexture, type: '2d', name: 'x' },
280+
{ texture: this.rowIndexMaps[i].glTexture, type: '2d', name: 'rowIndexMap' },
281+
{ texture: this.colIndexMaps[i].glTexture, type: '2d', name: 'colIndexMap' }
282+
]
283+
})
284+
}
285+
286+
this.wrappedLayer._callGPU(this.slice)
287+
this.sliceOutput = this.wrappedLayer.output
288+
289+
webgl2.runProgram({
290+
program: this.copyTextureProgram,
291+
output: this.outputCopy,
292+
inputs: [{ texture: this.output.glTexture, type: '2d', name: 'source' }]
293+
})
294+
295+
if (this.inputShape.length <= 2) {
296+
webgl2.runProgram({
297+
program: this.copySliceOutputProgram,
298+
output: this.output,
299+
inputs: [
300+
{ texture: this.outputCopy.glTexture, type: '2d', name: 'outputCopy' },
301+
{ texture: this.sliceOutput.glTexture, type: '2d', name: 'sliceOutput' }
302+
],
303+
uniforms: [{ value: i, type: 'int', name: 't' }, { value: timesteps, type: 'int', name: 'timesteps' }]
304+
})
305+
} else {
306+
webgl2.runProgram({
307+
program: this.mapSliceOutputProgram,
308+
output: this.output,
309+
inputs: [
310+
{ texture: this.outputCopy.glTexture, type: '2d', name: 'outputCopy' },
311+
{ texture: this.sliceOutput.glTexture, type: '2d', name: 'sliceOutput' },
312+
{ texture: this.outputRowIndexMaps[i].glTexture, type: '2d', name: 'rowIndexMap' },
313+
{ texture: this.outputColIndexMaps[i].glTexture, type: '2d', name: 'colIndexMap' }
314+
]
315+
})
316+
}
317+
}
318+
319+
// GPU -> CPU data transfer
320+
if (this.outbound.length === 0) {
321+
this.output.transferFromGLTexture()
322+
if (this.output.is2DReshaped) {
323+
this.output.reshapeFrom2DSquare()
324+
}
325+
}
58326
}
59327
}

0 commit comments

Comments
 (0)