diff --git a/trellis/trainers/base.py b/trellis/trainers/base.py index 15463a08..82e2d83f 100644 --- a/trellis/trainers/base.py +++ b/trellis/trainers/base.py @@ -230,7 +230,7 @@ def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): # Assign tasks num_samples_per_process = int(np.ceil(num_samples / self.world_size)) - samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) + samples = self.run_snapshot(num_samples_per_process, batch_size=self.batch_size_per_gpu, verbose=verbose) # Preprocess images for key in list(samples.keys()): @@ -448,4 +448,4 @@ def profile(self, wait=2, warmup=3, active=5): for _ in range(wait + warmup + active): self.run_step() prof.step() - \ No newline at end of file +