Initialise batch norm stats for e2e models #524


Merged: 4 commits merged into master from e2e-bn-stats on Oct 7, 2020

Conversation

AdamHillier (Contributor)

What do these changes do?

The default BatchNormalization moving-variance initialisation is "ones", which makes sense for a full-precision convolution (because the weights are usually initialised with very small absolute value) but doesn't make sense for a binary convolution (where the absolute value of the weights is always 1). This causes instability in the end2end tests because it means that the magnitude of the convolution outputs is unrealistic. In places, we've attempted to correct for this by fiddling with the gamma/beta initialisation, but the real problem is the moving-variance statistics.
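
For illustration only (this is not part of the PR diff, and the initialiser value below is just an assumption based on the fan-in argument above), here is a sketch of the scale mismatch: with ±1 inputs and ±1 weights, each convolution output is a sum of roughly fan-in terms of variance ~1, so its variance is on the order of the fan-in rather than 1, which is what moving_variance_initializer="ones" assumes. The PR fixes this with a few training steps (described next) rather than by changing the initialiser.

import larq
import tensorflow as tf

# Hypothetical example, not from this PR: a binary conv block where the
# BatchNorm moving variance starts near the expected output variance
# instead of the default "ones".
fan_in = 3 * 3 * 256  # kernel height * kernel width * input channels

block = tf.keras.Sequential([
    larq.layers.QuantConv2D(
        256, 3,
        padding="same",
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False,
    ),
    tf.keras.layers.BatchNormalization(
        # With +/-1 inputs and +/-1 weights, the pre-BN activations have
        # variance on the order of fan_in, so "ones" is far too small.
        moving_variance_initializer=tf.keras.initializers.Constant(float(fan_in)),
    ),
])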

This PR adds ten steps of training with a low learning rate for each model in the e2e test (excluding QuickNet, which we load with trained weights). These few steps should result in more sensible values for the batch norm statistics, and regardless I think this makes the tests more relevant than using models with entirely random weights.
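
As a rough sketch of what that warm-up looks like (the helper name, learning rate, and data shapes here are illustrative, not the actual test code):

import numpy as np
import tensorflow as tf

def warm_up_batch_norm_stats(model, num_steps=10, learning_rate=1e-3):
    # Run a few optimiser steps on random data so the BatchNorm moving
    # statistics reflect realistic activation magnitudes before the model
    # is converted and compared end-to-end.
    height, width, channels = model.input_shape[1:]  # assumes an image model
    num_classes = model.output_shape[-1]
    images = np.random.uniform(0, 1, size=(num_steps, height, width, channels)).astype(np.float32)
    labels = np.random.randint(0, num_classes, size=(num_steps,))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss="sparse_categorical_crossentropy",
    )
    # batch_size=1 over num_steps samples gives exactly num_steps training steps.
    model.fit(images, labels, batch_size=1, epochs=1, verbose=0)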

A nice benefit is that this appears to allow us to reduce the absolute tolerance in the np.allclose comparison.

How Has This Been Tested?

CI.

Benchmark Results

N/A.

Related issue number

N/A.



def assert_model_output(model_lce, inputs, outputs):
    interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count(), 4))
    actual_outputs = interpreter.predict(inputs)
-   np.testing.assert_allclose(actual_outputs, outputs, rtol=0.001, atol=0.25)
+   np.testing.assert_allclose(actual_outputs, outputs, rtol=0.05, atol=0.125)
AdamHillier (Contributor, Author):
I don't think it really made sense to have such a small rtol but a huge atol (with rtol/atol of 0.001/0.25, the rtol term only dominates once rtol * |value| > atol, i.e. for values with absolute value > 250), so I've bumped up the rtol as well as reducing the atol.

Tombana (Collaborator):
I agree that 0.25 is huge. However, it's what we observed for the Keras-TFLite difference for large models like QuickNet.
I think it might be better to pass a tolerance (or a default) to each model in model_cls, so that we have normal tolerances everywhere and a higher one exclusively for QuickNet. What do you think?

AdamHillier (Contributor, Author):
Yeah I agree, that's a good idea 👍

        atol = 0.025
    else:
        rtol = 0.001
        atol = 0.001
AdamHillier (Contributor, Author):
@Tombana yeah you're right, turns out we can use way lower tolerances for the non-QuickNet models. And probably should do.
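
For reference, a hypothetical sketch of the per-model tolerance approach agreed above (the ToyModel name, the build_and_convert helper, and the extended assert_model_output signature are assumptions for illustration, not the merged code; the tolerance values mirror the diff excerpts in this thread):

import pytest

@pytest.mark.parametrize(
    "model_cls,rtol,atol",
    [
        (QuickNet, 0.05, 0.025),   # trained weights; larger Keras-TFLite drift
        (ToyModel, 0.001, 0.001),  # random-weight model; tight tolerances
    ],
)
def test_model_output(model_cls, rtol, atol):
    model_lce, inputs, outputs = build_and_convert(model_cls)  # assumed helper
    # Assumes assert_model_output has been extended to accept tolerances.
    assert_model_output(model_lce, inputs, outputs, rtol=rtol, atol=atol)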

AdamHillier marked this pull request as ready for review October 5, 2020 10:04
AdamHillier added the test (everything related to testsuites) label Oct 5, 2020
lgeiger added the internal-improvement (Internal Improvements and Maintenance) label Oct 7, 2020
lgeiger (Member) left a comment:

Nice!

AdamHillier merged commit 24fd8a4 into master Oct 7, 2020
AdamHillier deleted the e2e-bn-stats branch October 7, 2020 21:04