
Use run name for logging with WandbLogger #12604


Merged · 9 commits · Jun 21, 2022
CHANGELOG.md (2 additions, 0 deletions)

@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))


+- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))


 ### Deprecated
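Before the diffs below, a quick illustration of what the changelog entry above means in practice. This is a minimal sketch: the run name, project name, and `save_dir` are illustrative, and the exact checkpoint layout also depends on the Trainer and checkpoint-callback configuration.

```python
from pytorch_lightning.loggers import WandbLogger

# Run name provided: logs and checkpoints are grouped under the run name.
logger = WandbLogger(name="sunny-morning-7", project="mnist", save_dir="logs")
# e.g. logs/sunny-morning-7/<version>/checkpoints/

# No run name provided: the project name is used as the fallback.
logger = WandbLogger(project="mnist", save_dir="logs")
# e.g. logs/mnist/<version>/checkpoints/
```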
src/pytorch_lightning/loggers/wandb.py (3 additions, 3 deletions)

@@ -293,7 +293,7 @@ def __init__(
         # set wandb init arguments
         anonymous_lut = {True: "allow", False: None}
         self._wandb_init = dict(
-            name=name,
+            name=name or project,
             project=project,
             id=version or id,
             dir=save_dir,
@@ -316,7 +316,7 @@ def __getstate__(self):
         if self._experiment is not None:
             state["_id"] = getattr(self._experiment, "id", None)
             state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
-            state["_name"] = self._experiment.project_name()
+            state["_name"] = self._experiment.name

         # cannot be pickled
         state["_experiment"] = None
@@ -449,7 +449,7 @@ def name(self) -> Optional[str]:
         The name of the experiment if the experiment exists else the name given to the constructor.
         """
         # don't create an experiment if we don't have one
-        return self._experiment.project_name() if self._experiment else self._name
+        return self._experiment.name if self._experiment else self._name

     @property
     def version(self) -> Optional[str]:
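Taken together, the three hunks above move the logger's name from the wandb project to the wandb run. A simplified sketch of the resolution order after this change; this is not the actual implementation, which spreads the logic across `__init__`, `__getstate__`, and the `name` property:

```python
from typing import Optional


def resolve_logger_name(experiment, ctor_name: Optional[str], project: str) -> Optional[str]:
    """Sketch of how WandbLogger.name resolves after this PR."""
    if experiment is not None:
        # A live wandb run now reports its own run name, not the project name.
        return experiment.name
    # Before any run exists, fall back to the constructor name, then the project.
    return ctor_name or project
```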
tests/tests_pytorch/loggers/test_all.py (2 additions, 2 deletions)

@@ -118,7 +118,7 @@ def log_metrics(self, metrics, step):
     if logger_class == WandbLogger:
         # required mocks for Trainer
         logger.experiment.id = "foo"
-        logger.experiment.project_name.return_value = "bar"
+        logger.experiment.name = "bar"

     if logger_class == CometLogger:
         logger.experiment.id = "foo"
@@ -299,7 +299,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):


 @pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES_WO_NEPTUNE_WANDB)
-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_hanging_spawn=True)
 def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
     """Test that loggers get replaced by dummy loggers on global rank > 0."""
     _patch_comet_atexit(monkeypatch)
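One detail worth noting in the test updates here and below: `project_name` was a method, so mocking it required setting `return_value`, while `name` is a plain attribute that is assigned directly. A standalone illustration with `unittest.mock` (the values are arbitrary):

```python
from unittest.mock import MagicMock

experiment = MagicMock()

# Old API: project_name() was a method, so the mock configures its return value.
experiment.project_name.return_value = "bar"
assert experiment.project_name() == "bar"

# New API: name is a plain attribute, so the mock is assigned directly.
experiment.name = "bar"
assert experiment.name == "bar"
```

(`name` is only special in the `MagicMock` constructor; attribute assignment after creation works as shown.)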
tests/tests_pytorch/loggers/test_wandb.py (20 additions, 11 deletions)

@@ -43,6 +43,14 @@ def test_wandb_logger_init(wandb, monkeypatch):
     )
     wandb.init().log.assert_called_once_with({"acc": 1.0})

+    # test wandb.init called with project as name if name not provided
+    wandb.run = None
+    wandb.init.reset_mock()
+    WandbLogger(project="test_project").experiment
+    wandb.init.assert_called_once_with(
+        name="test_project", dir=None, id=None, project="test_project", resume="allow", anonymous=None
+    )
+
     # test wandb.init and setting logger experiment externally
     wandb.run = None
     run = wandb.init()
@@ -83,7 +91,7 @@ def test_wandb_logger_init(wandb, monkeypatch):
     logger.watch("model", "log", 10, False)
     wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)

-    assert logger.name == wandb.init().project_name()
+    assert logger.name == wandb.init().name
     assert logger.version == wandb.init().id

@@ -99,8 +107,9 @@ class Experiment:
         step = 0
         dir = "wandb"

-        def project_name(self):
-            return "the_project_name"
+        @property
+        def name(self):
+            return "the_run_name"

     wandb.run = None
     wandb.init.return_value = Experiment()
@@ -134,18 +143,18 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     logger = WandbLogger(save_dir=str(tmpdir), offline=True)
     # the logger get initialized
     assert logger.version == wandb.init().id
-    assert logger.name == wandb.init().project_name()
+    assert logger.name == wandb.init().name

     # mock return values of experiment
     wandb.run = None
     logger.experiment.id = "1"
-    logger.experiment.project_name.return_value = "project"
+    logger.experiment.name = "run_name"

     for _ in range(2):
         _ = logger.experiment

     assert logger.version == "1"
-    assert logger.name == "project"
+    assert logger.name == "run_name"
     assert str(tmpdir) == logger.save_dir
     assert not os.listdir(tmpdir)

@@ -155,7 +164,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     assert trainer.log_dir == logger.save_dir
     trainer.fit(model)

-    assert trainer.checkpoint_callback.dirpath == str(tmpdir / "project" / version / "checkpoints")
+    assert trainer.checkpoint_callback.dirpath == str(tmpdir / "run_name" / version / "checkpoints")
     assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"}
     assert trainer.log_dir == logger.save_dir

@@ -173,7 +182,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     # test log_model=True
     logger = WandbLogger(log_model=True)
     logger.experiment.id = "1"
-    logger.experiment.project_name.return_value = "project"
+    logger.experiment.name = "run_name"
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
     trainer.fit(model)
     wandb.init().log_artifact.assert_called_once()
@@ -183,7 +192,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     wandb.init.reset_mock()
     logger = WandbLogger(log_model="all")
     logger.experiment.id = "1"
-    logger.experiment.project_name.return_value = "project"
+    logger.experiment.name = "run_name"
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
     trainer.fit(model)
     assert wandb.init().log_artifact.call_count == 2
@@ -193,7 +202,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     wandb.init.reset_mock()
     logger = WandbLogger(log_model=False)
     logger.experiment.id = "1"
-    logger.experiment.project_name.return_value = "project"
+    logger.experiment.name = "run_name"
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
     trainer.fit(model)
     assert not wandb.init().log_artifact.called
@@ -204,7 +213,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     wandb.Artifact.reset_mock()
     logger = WandbLogger(log_model=True)
     logger.experiment.id = "1"
-    logger.experiment.project_name.return_value = "project"
+    logger.experiment.name = "run_name"
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
     trainer.fit(model)
     wandb.Artifact.assert_called_once_with(