-
Notifications
You must be signed in to change notification settings - Fork 287
Support paddle save/load/save_file/load_file without coverting to numpy #646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Could you please review this PR? @Narsil Thank you. |
Narsil
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this.
The biggest blocker is that this code is breaking paddle < 3.1.
For bflioat16 we could skip the test if < 3.1 doesn't support them.
As a side node, it seems the benchmarks don't show that big of an improvement which seems suprising.
======================================================================================================================= warnings summary =======================================================================================================================
../../.venv/lib/python3.12/site-packages/paddle/utils/cpp_extension/extension_utils.py:717
/home/nicolas/src/safetensors/.venv/lib/python3.12/site-packages/paddle/utils/cpp_extension/extension_utils.py:717: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
warnings.warn(warning_message)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-------------------------------------------------------------------------------------- benchmark: 2 tests -------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_paddle_paddle_load 371.5229 (1.0) 393.8909 (1.0) 380.8332 (1.0) 8.7418 (1.19) 381.2460 (1.0) 12.2720 (1.07) 2;0 2.6258 (1.0) 5 1
test_paddle_sf_load 474.4823 (1.28) 491.2912 (1.25) 480.5302 (1.26) 7.3160 (1.0) 476.6819 (1.25) 11.4378 (1.0) 1;0 2.0810 (0.79) 5 1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Legend:
Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
OPS: Operations Per Second, computed as 1 / Mean
================================================================================================================ 2 passed, 1 warning in 19.98s
| @@ -1,10 +1,11 @@ | |||
| import os | |||
| from typing import Dict, Optional, Union | |||
| import sys | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This modification is breaking paddle < 3.1.1 which isn't great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ve made several improvements to the code:
- Added version-dispatch logic for different Paddle versions.
- Refactored
loadusing the newfrombufferfunction in Paddle, which is based on MmapStorage for improved performance. The new function is now available in develop, and will be released in 3.2.0.
Could you please review the updated implementation again? Thanks a lot! @Narsil
|
Hi @Narsil ,could you please review my updated implementation? Thanks! |
| output = _np2paddle(flat, device) | ||
| return output | ||
| result = {} | ||
| if paddle.__version__ == "0.0.0" or paddle.__version__ >= "3.1.1": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you using 2 different versions between deserialize and here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"0.0.0" is the develop_version for Paddle, we just hope that developers using the develop Paddle can also use this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you using 2 different versions between deserialize and here ?
Hi @Narsil , we have refined the version of Paddle in the develop branch, now it is no more "0.0.0" and can easily use one single statement to do the version control. As for the distinction of '3.1.1' and '3.2.0', that's because we recently added the frombuffer function after '3.1.1', so we should use a newer version '3.2.0' . Could you please take another look and consider merging this PR into the main branch? We're eager to release an updated version of safetensors that includes these enhancements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have 3.2.0 everywhere? Adding version checks is a maintenance burden, having as few as possible is easier to maintain (and remember).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright~ Now all versions are 3.2.0 now.
Narsil
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks !
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
What does this PR do?
Based on #630 .
This PR refactors the logic of the
save,load,save_file, andload_fileoperations inpaddle.pyto eliminate the need for NumPy conversion. As a result, Paddle can now utilize safetensors to save and load additional data types—such as bfloat16 (bf16)—that were previously unsupported.