Move fqn mapping logic to StateDictAdapter #1557
Conversation
from .model import BaseModelArgs
...
-class StateDictAdapter(ABC):
+class BaseStateDictAdapter(ABC):
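For context, the resulting split looks roughly like this; the method names (to_hf / from_hf) and constructor signature below are illustrative assumptions, not copied from the PR:

# Illustrative sketch only; the actual torchtitan definitions may differ.
from abc import ABC, abstractmethod
from typing import Any


class BaseStateDictAdapter(ABC):
    """Pure interface for converting between torchtitan and HF state dicts."""

    @abstractmethod
    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: ...

    @abstractmethod
    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: ...


class StateDictAdapter(BaseStateDictAdapter):
    """Shared base that owns the fqn_to_index_mapping logic this PR moves up."""

    def __init__(self, model_args, hf_assets_path: str | None):
        # model.safetensors.index.json under hf_assets_path is parsed here to
        # build fqn_to_index_mapping (see the sketch at the end of the thread).
        self.fqn_to_index_mapping: dict[str, int] | None = None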
Maybe we don't need this BaseStateDictAdapter -- could you think of a case where people would like to inherit BaseStateDictAdapter but not StateDictAdapter? What do you think?
I'm not sure if it would be better to handle the state dict adapter this way, but there are some multi-modal repositories that may have multiple safetensors.index.json files, such as https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main. I think you could handle such cases by using an sd_adapter on each of the individual sub-models, but perhaps someone may prefer having them all under the same class. In the case of FLUX, the extra models are just autoencoders that are not being trained, but perhaps someone may want to use torchtitan in a way that trains multiple models at the same time.
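If someone did want everything under one class, one hedged option is a thin wrapper that composes one adapter per sub-model. Everything below is hypothetical (the wrapper class, the fqn-prefix convention, and the idea that sub-adapters are passed in by name), building on the BaseStateDictAdapter sketch above:

# Hypothetical sketch of the "all under the same class" option; not part of the PR.
class MultiModelStateDictAdapter(BaseStateDictAdapter):
    def __init__(self, sub_adapters: dict[str, BaseStateDictAdapter]):
        # e.g. {"transformer": <transformer adapter>, "vae": <autoencoder adapter>}
        self.sub_adapters = sub_adapters

    def to_hf(self, state_dict):
        # Assumes fqns are prefixed with the sub-model name, e.g. "transformer.layers.0...".
        out = {}
        for name, adapter in self.sub_adapters.items():
            sub = {k[len(name) + 1:]: v for k, v in state_dict.items() if k.startswith(name + ".")}
            out.update({f"{name}.{k}": v for k, v in adapter.to_hf(sub).items()})
        return out

    def from_hf(self, hf_state_dict):
        # Symmetric to to_hf; details omitted in this sketch.
        ...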
In the case of FLUX, how would it work?
- Does our approach still work? It sounds to me that we are not ready to read https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main/transformer
- We don't support loading multiple models from multiple folders anyway, so I'm not sure overgeneralizing helps. In the future, if we support that, we would probably need multiple StateDictAdapters?
All that said, I'm OK with this change, but we probably need to change https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L24 to use the base one.
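For reference, that change would be roughly the following one-line diff; the field name and annotation are how I read the linked line, so treat this as an assumption rather than an exact patch:

 # torchtitan/protocols/train_spec.py (sketch)
-    state_dict_adapter: type[StateDictAdapter] | None = None
+    state_dict_adapter: type[BaseStateDictAdapter] | None = None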
I think our approach will still work currently for all models using tokenizer_path, since its logic for downloading tokenizer files is identical to the previous download_tokenizer script. For loading the model in FLUX, we can pass the hf_assets_path of the full repo to FluxStateDictAdapter, and FluxStateDictAdapter can then pass hf_assets_path + "transformer" to the parent class. I agree, though, that I can iterate on the download_hf_assets options to narrow the search to certain subfolders, or to pattern match on full path names instead of base names.
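A minimal sketch of that handoff, assuming the constructor takes (model_args, hf_assets_path) like the base class sketched above; exact signatures and path handling are assumptions:

import os

# Sketch: FluxStateDictAdapter points the shared index-parsing logic at the
# "transformer" subfolder of the full FLUX repo checkout.
class FluxStateDictAdapter(StateDictAdapter):
    def __init__(self, model_args, hf_assets_path: str | None):
        if hf_assets_path is not None:
            hf_assets_path = os.path.join(hf_assets_path, "transformer")
        super().__init__(model_args, hf_assets_path)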
This moves the logic that parses model.safetensors.index.json and generates the fqn_to_index_mapping to StateDictAdapter, since this logic should be shared by all classes that inherit from StateDictAdapter.
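Concretely, the shared piece is roughly the following; this is a hedged sketch of the parsing step, not the exact code from the PR, and it assumes the standard HF index format where "weight_map" maps each fqn to a shard filename:

import json
import os

# Sketch: parse model.safetensors.index.json and build the fqn -> shard-index map.
def build_fqn_to_index_mapping(hf_assets_path: str) -> dict[str, int] | None:
    index_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        return None  # single-file checkpoints ship without an index
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    # Filenames look like "model-00002-of-00005.safetensors"; keep the shard number.
    return {fqn: int(fname.split("-")[1]) for fqn, fname in weight_map.items()}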