-
Notifications
You must be signed in to change notification settings - Fork 795
Use broadcast_logical_or in inflated_beta #1226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Nice! Did you verify that this works? You could also add a test, that just sets up an estimator with the inflated distribution output, and trains it for one batch on some dummy data in [0, 1], with hybridize both True/False |
@lostella I checked that it works. I'll add a test too and push another commit |
@lostella I had my dev setup wrong when I made my earlier commits, so not all checks were being made. Now that I've corrected this, I'm getting the following:
It would be easy enough to replace these with |
d5db522
to
77e9cbf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@deejlucas thank you! I have some comments on the test, see inline.
The changes NotImplemented
-> NotImplementeError
can stay, as a matter of fact I was in the process of doing the same as part of #1223
The various inflated beta classes used F.logical_or to mask ones and zeros, but this function does not exist for mxnet.symbol. We now use broadcast_logical_or which exists as a method of both mxnet.ndarray and mxnet.symbol
4c4b8b2
to
3dba71f
Compare
@lostella Thank you for the feedback. I've addressed it. Ready for re-review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thanks for spotting this and fixing it!
The various inflated beta classes used F.logical_or to mask ones and zeros, but this function does not exist for mxnet.symbol. We now use broadcast_logical_or which exists as a method of both mxnet.ndarray and mxnet.symbol
The various inflated beta classes used F.logical_or to mask ones and zeros, but this function does not exist for mxnet.symbol. We now use broadcast_logical_or which exists as a method of both mxnet.ndarray and mxnet.symbol
* Use broadcast_logical_or in inflated_beta (#1226) * Fix type error for using quantile_weights and add a proper Pytest (#1231) * Fixed MASE in N-BEATS: removed redundant factor (#1288) * Fixing bug where dropout was not used, also remove unused halt option (#1315) * Fixes for Python 3.8 (#1318) Co-authored-by: Dan Lucas <[email protected]> Co-authored-by: youngsuk0723 <[email protected]> Co-authored-by: Danielle Robinson <[email protected]> Co-authored-by: Riccardo Grazzi <[email protected]> Co-authored-by: David Salinas <[email protected]>
Issue #, if available: addresses #1211
Description of changes: The various inflated beta distribution classes used
F.logical_or
to mask ones and zeros (for which the loss function is undefined), but there is nological_or
method formxnet.symbol
.We now use
F.broadcast_logical_or
which exists as a method of bothmxnet.ndarray
andmxnet.symbol
, either of which can be imported asF
.By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.