Simplify, correct, and add validation for GRU/LSTM and friends #659
- A couple of places were comparing a rank against an expected dimension (e.g. "rank is not equal to 3 * hiddenSize"). Fix these!
- Rather than validating, for example, that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ».
- While doing the above, I found that several arguments had their data type and rank validated, but only some of the dimensions. Make this consistent across the ops - at least, matching the existing prose.
The fun never stops! I may be wrong and the missing validation is intentional - e.g. for lstmCell, cellState's data type and rank are validated, but not the actual dimensions called out in the prose ("The 2-D input cell state tensor of shape [batchSize, hiddenSize].")
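To make the gap concrete, here is a rough sketch (not the spec's actual steps) of how a rank-only check on cellState differs from checking the full shape named in the prose. The helper names `rank_only_ok` and `full_shape_ok` are hypothetical, and the shapes are made-up illustration values.

```python
from typing import Sequence


def rank_only_ok(shape: Sequence[int], expected_rank: int = 2) -> bool:
    """Hypothetical check that validates only the rank (number of dimensions)."""
    return len(shape) == expected_rank


def full_shape_ok(shape: Sequence[int], batch_size: int, hidden_size: int) -> bool:
    """Hypothetical check against the shape in the prose: [batchSize, hiddenSize]."""
    return list(shape) == [batch_size, hidden_size]


# Illustration values only: a 2-D tensor with the wrong dimensions.
bad_cell_state = [3, 7]
batch_size, hidden_size = 2, 4

# The rank-only check accepts it; the full shape check rejects it.
print(rank_only_ok(bad_cell_state))                            # rank is 2
print(full_shape_ok(bad_cell_state, batch_size, hidden_size))  # dimensions differ
```

If the missing dimension check is intentional, the rank-only behavior is what implementations would be left with.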
Thanks for the close look, @huningxin - all those weight/recurrentWeight/hiddenSize/inputSize blur together after a while.
LGTM
I noticed one more case where this could be applied: instanceNormalization - bundled it into this PR since it was on topic. Done in fcf0479
This is nice and concise.
👍 TY JB.
SHA: aa5fac9 Reason: push, by fdwr Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Some steps in gruCell() were comparing a rank vs. an expected dimension (e.g. "rank is not equal to 3 * hiddenSize"). Fix these!
Rather than validating for example that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ». This also implicitly fixes places that were inspecting shape[x] without validating the rank first. Done for: batchNormalization(), conv2d(), convTranspose2d(), gru(), gruCell(), lstm(), lstmCell().
Some places did validate data type and rank, but only some or none of the dimensions. Make this consistent across the ops - at least, matching the existing prose. Done for gru(), gruCell(), instanceNormalization(), lstm(), lstmCell().
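The changes described above can be sketched in Python as follows. This is only an illustration of the validation pattern, not the spec text: the function names are hypothetical, and the example weight shape [3 * hiddenSize, inputSize] is an assumption chosen to match the "3 * hiddenSize" wording quoted in the first item.

```python
from typing import Sequence


def buggy_rank_check(shape: Sequence[int], hidden_size: int) -> bool:
    """The mistake: compares the rank (a count of dimensions) to a dimension size."""
    return len(shape) == 3 * hidden_size


def piecewise_check(shape: Sequence[int], n: int, m: int) -> bool:
    """The older style: validate rank, then each dimension separately."""
    if len(shape) != 2:
        return False
    return shape[0] == n and shape[1] == m


def shape_check(shape: Sequence[int], expected: Sequence[int]) -> bool:
    """The simplified style: compare the whole shape against the expected list.

    A rank mismatch fails the comparison automatically, so shape[x] is never
    read without the rank having been validated.
    """
    return list(shape) == list(expected)


hidden_size, input_size = 4, 5           # illustration values
weight_shape = [3 * hidden_size, input_size]
expected = [3 * hidden_size, input_size]

print(buggy_rank_check(weight_shape, hidden_size))  # rejects a valid tensor
print(piecewise_check(weight_shape, *expected))     # accepts it
print(shape_check(weight_shape, expected))          # accepts it, in one step
print(shape_check([3 * hidden_size], expected))     # wrong rank is rejected too
```

The single list comparison covers the rank and every dimension at once, which is why it also closes the gaps where only some dimensions were being checked.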