Skip to content

Simplify, correct, and add validation for GRU/LSTM and friends #659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

inexorabletash
Copy link
Contributor

@inexorabletash inexorabletash commented Apr 27, 2024

  • Some steps in gruCell() were comparing a rank vs. an expected dimension (e.g. "rank is not equal to 3 * hiddenSize"). Fix these!

  • Rather than validating for example that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ». This also implicitly fixes places that were inspecting shape[x] without validating the rank first. Done for: batchNormalization(), conv2d(), convTranspose2d(), gru(), gruCell(), lstm(), lstmCell().

  • Some places did validate data type and rank, but only some or none of the dimensions. Make this consistent across the ops - at least, matching the existing prose. Done for gru(), gruCell(), instanceNormalization(), lstm(), lstmCell().


Preview | Diff

- A couple of places were comparing a rank vs. an expected dimension
  (e.g. "rank is not equal to 3 * hiddenSize"). Fix these!

- Rather than validating for example validating that rank = 2,
  shape[0] = N and shape[1] = M, just compare shape against « N, M ».

- While doing the above, several arguments had their data type and
  rank validated, but only some of the dimensions. Make this
  consistent across the ops - at least, matching the existing prose.
@inexorabletash inexorabletash marked this pull request as ready for review April 27, 2024 01:18
@inexorabletash
Copy link
Contributor Author

The fun never stops!

I may be wrong and the missing validation is intentional - e.g. for lstmCell, cellState's data type and rank are validated, but not the actual dimensions called out in the prose ("The 2-D input cell state tensor of shape [batchSize, hiddenSize].")

@inexorabletash
Copy link
Contributor Author

Thanks for the close look, @huningxin - all those weight/recurrentWeight/hiddenSize/inputSize blur together after a while.

Copy link
Contributor

@huningxin huningxin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@inexorabletash
Copy link
Contributor Author

I noticed one more case where this could be applied: instanceNormalization - bundled it into this PR since it was on topic. Done in fcf0479

@fdwr
Copy link
Collaborator

fdwr commented May 2, 2024

Rather than validating for example that rank = 2, shape[0] = N and shape[1] = M, just compare shape against « N, M ».

This is nice and concise.

Copy link
Collaborator

@fdwr fdwr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 TY JB.

@fdwr fdwr merged commit aa5fac9 into webmachinelearning:main May 2, 2024
@inexorabletash inexorabletash deleted the gru-lstm-validation-fixes branch May 2, 2024 04:05
github-actions bot added a commit that referenced this pull request May 2, 2024
SHA: aa5fac9
Reason: push, by fdwr

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants