Commit 532fbe8

Merge pull request #166 from cocktailpeanut/wandb_usability
User-friendly wandb support
2 parents 8831701 + ae6e97b commit 532fbe8

2 files changed: +45 −20 lines changed

README.md

Lines changed: 21 additions & 0 deletions
````diff
@@ -72,6 +72,27 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
 
 Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
 
+## Wandb Logging
+
+By default, the training script does not use wandb logging (assuming you did not already log in with `wandb login`).
+
+To turn on wandb logging, you can either:
+
+1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
+2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/site/ and set it as follows:
+
+On Mac & Linux:
+
+```
+export WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+
+On Windows:
+
+```
+set WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+
 ## Inference
 
 The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
````
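For reference, the same environment variable can also be set from Python before training starts, e.g. at the top of a small launch script. This is only a minimal sketch of the README instructions above; `WANDB_API_KEY` is the variable named there, and everything else is illustrative.

```python
# Minimal sketch: supply the wandb API key programmatically instead of via the shell.
# WANDB_API_KEY comes from the README section above; the rest is illustrative.
import os

os.environ["WANDB_API_KEY"] = "<YOUR WANDB API KEY>"  # must be set before wandb reads it

# ...then launch training as usual (e.g. run the training script after this point).
```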

model/trainer.py

Lines changed: 24 additions & 20 deletions
```diff
@@ -50,31 +50,35 @@ def __init__(
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
+        logger = "wandb" if wandb.api.api_key else None
+        print(f"Using logger: {logger}")
+
         self.accelerator = Accelerator(
-            log_with = "wandb",
+            log_with = logger,
             kwargs_handlers = [ddp_kwargs],
             gradient_accumulation_steps = grad_accumulation_steps,
             **accelerate_kwargs
         )
-
-        if exists(wandb_resume_id):
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
-        else:
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
-        self.accelerator.init_trackers(
-            project_name = wandb_project,
-            init_kwargs=init_kwargs,
-            config={"epochs": epochs,
-                    "learning_rate": learning_rate,
-                    "num_warmup_updates": num_warmup_updates,
-                    "batch_size": batch_size,
-                    "batch_size_type": batch_size_type,
-                    "max_samples": max_samples,
-                    "grad_accumulation_steps": grad_accumulation_steps,
-                    "max_grad_norm": max_grad_norm,
-                    "gpus": self.accelerator.num_processes,
-                    "noise_scheduler": noise_scheduler}
-        )
+
+        if logger == "wandb":
+            if exists(wandb_resume_id):
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
+            else:
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
+            self.accelerator.init_trackers(
+                project_name = wandb_project,
+                init_kwargs=init_kwargs,
+                config={"epochs": epochs,
+                        "learning_rate": learning_rate,
+                        "num_warmup_updates": num_warmup_updates,
+                        "batch_size": batch_size,
+                        "batch_size_type": batch_size_type,
+                        "max_samples": max_samples,
+                        "grad_accumulation_steps": grad_accumulation_steps,
+                        "max_grad_norm": max_grad_norm,
+                        "gpus": self.accelerator.num_processes,
+                        "noise_scheduler": noise_scheduler}
+            )
 
         self.model = model
```
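The net effect of this change: wandb tracking is initialized only when an API key is visible to the process; otherwise the `Accelerator` is built with no logger and training runs without wandb. Below is a minimal standalone sketch of the same pattern, assuming `accelerate` and `wandb` are installed; the project/run names and config are placeholders, not the F5-TTS defaults.

```python
# Sketch of the "use wandb only if an API key is available" pattern from the diff above.
# Placeholder project/run names and config; not the F5-TTS defaults.
import wandb
from accelerate import Accelerator

logger = "wandb" if wandb.api.api_key else None  # None -> no tracker is attached
accelerator = Accelerator(log_with=logger)

if logger == "wandb":
    accelerator.init_trackers(
        project_name="my_project",
        init_kwargs={"wandb": {"resume": "allow", "name": "my_run"}},
        config={"learning_rate": 7.5e-5},
    )

accelerator.log({"loss": 0.123}, step=1)  # no-op when no tracker was initialized
accelerator.end_training()
```

With `log_with=None`, `accelerator.log(...)` has no trackers to write to, so the rest of the training loop needs no wandb-specific branches.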
