Skip to content

Commit efc7a74

Browse files
committed
fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0
1 parent 9842314 commit efc7a74

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ classifiers = [
1414
"Programming Language :: Python :: 3",
1515
]
1616
dependencies = [
17-
"accelerate>=0.33.0,!=1.7.0",
17+
"accelerate>=0.33.0",
1818
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
1919
"cached_path",
2020
"click",

src/f5_tts/model/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def save_checkpoint(self, update, last=False):
149149
if self.is_main:
150150
checkpoint = dict(
151151
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
152-
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
152+
optimizer_state_dict=self.optimizer.state_dict(),
153153
ema_model_state_dict=self.ema_model.state_dict(),
154154
scheduler_state_dict=self.scheduler.state_dict(),
155155
update=update,

0 commit comments

Comments
 (0)