
Move fqn mapping logic to StateDictAdapter #1557


Merged: 3 commits merged into main on Aug 12, 2025

Conversation

wesleytruong (Contributor)

This moves the logic that parses model.safetensors.index.json and generates the fqn_to_index_mapping into StateDictAdapter, since that logic should be shared by all classes that inherit from StateDictAdapter.
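
For context, a minimal sketch of what the shared logic might look like after the move, assuming the standard Hugging Face index format (`{"weight_map": {fqn: shard_filename, ...}}`); the constructor signature, method names, and helper details here are illustrative, not the exact torchtitan API:

```python
import json
import os
import re
from abc import ABC, abstractmethod


class BaseStateDictAdapter(ABC):
    """Abstract interface for converting between torchtitan and HF state dicts."""

    @abstractmethod
    def to_hf(self, state_dict: dict) -> dict: ...

    @abstractmethod
    def from_hf(self, hf_state_dict: dict) -> dict: ...


class StateDictAdapter(BaseStateDictAdapter):
    """Shared base that parses model.safetensors.index.json once for all subclasses."""

    def __init__(self, model_args, hf_assets_path: str | None):
        self.fqn_to_index_mapping: dict[str, int] | None = None
        if hf_assets_path:
            index_path = os.path.join(hf_assets_path, "model.safetensors.index.json")
            with open(index_path) as f:
                weight_map = json.load(f)["weight_map"]
            # e.g. "model-00003-of-00005.safetensors" -> shard index 3
            self.fqn_to_index_mapping = {
                fqn: int(re.search(r"-(\d+)-of-", filename).group(1))
                for fqn, filename in weight_map.items()
            }
```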

@meta-cla bot added the CLA Signed label on Aug 12, 2025
…e this logic should be shared by all StateDictAdapters
@wesleytruong force-pushed the move_sd_adapter_logic branch from 8f5ee2c to 3212a1f on August 12, 2025 16:57
Review thread on this hunk:

```diff
 from .model import BaseModelArgs


-class StateDictAdapter(ABC):
+class BaseStateDictAdapter(ABC):
```
Contributor
Maybe we don't need this BaseStateDictAdapter -- can you think of a case where people would want to inherit from BaseStateDictAdapter but not StateDictAdapter? What do you think?

wesleytruong (Contributor, Author)

I'm not sure whether it would be better to handle the state dict adapter this way, but there are some multi-modal repositories that may have multiple safetensors.index.json files, such as https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main. You could handle such cases by using an sd_adapter on each of the individual sub-models, but perhaps someone would prefer having them all under the same class. In the case of Flux, the extra models are just autoencoders that are not being trained, but someone may want to use torchtitan in a way that trains multiple models at the same time.

Contributor

In the case of FLUX, how would it work?

  1. Does our approach still work? It sounds to me that we are not ready to read https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main/transformer
  2. We don't support loading multiple models from multiple folders anyway, so I'm not sure overgeneralizing helps. In the future, if we support that, we would probably need multiple StateDictAdapters.

All that said, I'm OK with this change, but we probably need to change https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L24 to use the base one.
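
For illustration, the train_spec.py change suggested here would presumably just loosen the annotation to the base class; the module path and field name below are assumptions, not quoted from the file:

```python
from dataclasses import dataclass

# Import path is an assumption for illustration.
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter


@dataclass
class TrainSpec:
    # ... other fields elided ...
    # was: state_dict_adapter: type[StateDictAdapter] | None = None
    state_dict_adapter: type[BaseStateDictAdapter] | None = None
```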

wesleytruong (Contributor, Author)

I think our approach will still currently work for all models using tokenizer_path, since its logic for downloading tokenizer files is identical to the previous download_tokenizer script. For loading the model in Flux, we can pass the hf_assets_path of the full repo to FluxStateDictAdapter, and FluxStateDictAdapter can then pass hf_assets_path + "transformer" to the parent class. I agree, though, that I can iterate on the download_hf_assets options to narrow the search to certain subfolders, or to pattern match on full path names instead of base names.
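
A rough sketch of the subfolder handling described here, assuming the parent class takes hf_assets_path in its constructor (the import path and signature are illustrative):

```python
import os

# Import path is an assumption for illustration.
from torchtitan.protocols.state_dict_adapter import StateDictAdapter


class FluxStateDictAdapter(StateDictAdapter):
    """Sketch: point the shared index parsing at FLUX's transformer/ subfolder."""

    def __init__(self, model_args, hf_assets_path: str | None):
        # FLUX keeps the transformer weights and their
        # model.safetensors.index.json under a "transformer/" subfolder,
        # so hand the parent class that subdirectory instead of the repo root.
        if hf_assets_path:
            hf_assets_path = os.path.join(hf_assets_path, "transformer")
        super().__init__(model_args, hf_assets_path)
```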

@tianyu-l (Contributor) left a comment

Sounds good.

Please resolve with @wwwjn on the parallel work in #1538.

@wesleytruong merged commit 8bd8c93 into main on Aug 12, 2025
7 checks passed
@tianyu-l deleted the move_sd_adapter_logic branch on August 13, 2025 01:17