load_file: load tensors ordered by their offsets #571
Conversation
I do not yet understand the torch-vs-others numbers: with the old version, sdxl loads faster with np than with pt, but the result is reversed for revAnimated.
(force-pushed from 539b255 to 6f24f48)
@yousong do you mind sharing the setup in more detail?
Sequential reads should only matter for old drives (HDDs); afaik they shouldn't impact SSDs at all. I'm not against offset-based reading, but if it has a real impact, then maybe I should revisit the saving methods to enforce some kind of order.
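To make the save-side idea concrete, here is a minimal stdlib-only sketch (not the library's actual save path; `save_ordered` and its `name -> (dtype, shape, bytes)` input format are hypothetical) of what "enforcing some kind of order" at save time could mean: lay out tensor data in key-iteration order so that header key order and on-disk offset order coincide, and a naive key-order reader is already sequential.

```python
import json
import struct

def save_ordered(fname, tensors):
    # Hypothetical helper: `tensors` maps name -> (dtype_str, shape, raw_bytes).
    # Data offsets are assigned in key-iteration order, so iterating keys
    # in header order later reads the file front to back.
    header, offset = {}, 0
    for name, (dtype, shape, data) in tensors.items():
        header[name] = {"dtype": dtype, "shape": shape,
                        "data_offsets": [offset, offset + len(data)]}
        offset += len(data)
    blob = json.dumps(header).encode("utf-8")
    with open(fname, "wb") as f:
        f.write(struct.pack("<Q", len(blob)))  # 8-byte little-endian header length
        f.write(blob)
        for _, (_, _, data) in tensors.items():
            f.write(data)
```

With such an invariant in place, offset-ordered and key-ordered loading would be the same thing.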
For a simpler implementation, maybe we just add
OS and Python version are not very relevant here. The test stats posted above are from the current tip of this repo. We store safetensors files on an S3-like object store and mount it with FUSE daemons like ossfs, so in the end it's all HTTP request/response with long latency. ossfs is a fork of s3fs. It can do direct_read, which prefetches data into memory directly (vs. a local tmp store). Prefetching is needed to maximize (network) throughput, and sequential reads are a requirement for prefetching to be effective and useful in this case.
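For reference, the offset order being discussed can be recovered from the file header itself. Here is a minimal stdlib-only sketch (the `keys_by_offset` helper is hypothetical, but the layout it assumes is the documented safetensors format: an 8-byte little-endian header length, then a JSON header whose entries carry `"data_offsets": [begin, end]`):

```python
import json
import struct

def keys_by_offset(fname):
    # Read the safetensors header and return tensor names sorted by
    # their begin offset in the data section.
    with open(fname, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional metadata entry has no offsets
    return sorted(header, key=lambda k: header[k]["data_offsets"][0])
```

Iterating tensors in this order makes each read start where the previous one ended, which is exactly the access pattern a prefetcher can exploit.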
Link to the test files
Simple test code:

```python
import time
import sys

import safetensors
import safetensors.torch

def safe_open(fname, framework='np'):
    f = safetensors.safe_open(fname, framework=framework)
    keys = f.keys()
    #keys_ordered = sorted(keys[:])
    tensors = {}
    for k in keys:
        #print(k)
        tensors[k] = f.get_tensor(k)
    return tensors

def torch_open(fname):
    tensors = safetensors.torch.load_file(fname)
    return tensors

#print(safetensors.__file__)
fname = sys.argv[1]
framework = 'torch'
if len(sys.argv) > 2:
    framework = sys.argv[2]
stime = time.perf_counter()
if framework == 'pt':
    torch_open(fname)
else:
    safe_open(fname)
etime = time.perf_counter()
print('{:.2f}'.format(etime - stime))
```
I prepared a branch with the
(force-pushed from 6f24f48 to 84a6bd1)
LGTM. Thanks |
(force-pushed from 84a6bd1 to 03fb4c5)
(cherry picked from commit 467a605)
Hi @Narsil, may I ask when you plan to release a new version including this patch? In our tests, this sequential-read optimization significantly reduces model loading time.

What does this PR do?

This change introduces a new method `offset_keys()` that returns tensor names ordered by their offsets in the file. `load_file()` is then changed to use the new method, so data is loaded sequentially regardless of how tensors are named.

We put model files on S3-like object storage and use a prefetch strategy to maximize throughput. The change can greatly shorten load time for files like `revAnimated_v122.safetensors`, and has a negligible effect on others (e.g. `sd_xl_base_1.0.safetensors`).

Load time in seconds.
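The offset-ordered load loop that `offset_keys()` enables can be sketched in stdlib-only terms (this is an illustrative reimplementation, not the PR's actual code; it returns raw bytes per tensor rather than framework arrays, and assumes the documented safetensors layout):

```python
import json
import struct

def load_file_ordered(fname):
    # Sketch of offset-ordered loading: parse the header, then read each
    # tensor's byte range in ascending begin-offset order, so every read
    # continues where the previous one ended.
    with open(fname, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))
        header.pop("__metadata__", None)
        base = 8 + n  # data section starts right after the JSON header
        tensors = {}
        for k in sorted(header, key=lambda k: header[k]["data_offsets"][0]):
            begin, end = header[k]["data_offsets"]
            f.seek(base + begin)
            tensors[k] = f.read(end - begin)
    return tensors
```

The real `load_file()` does the same traversal through the library's native reader; the point is only that the iteration order follows on-disk position, not key names.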