
Commit 9837910

changes to save and test the models

1 parent: a775004

5 files changed, +148 −129 lines

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
1 addition, 1 deletion

@@ -105,7 +105,7 @@ def init_model_fn(
         {'params': params_rng, 'dropout': dropout_rng},
         jnp.ones(input_shape, jnp.float32))
     initial_params = initial_variables['params']
-    initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_24_june.pth")
+    initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_1_july.pth")
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     return jax_utils.replicate(initial_params), None
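The only functional change here is pointing `use_pytorch_weights_inplace` at the 1 July base-model snapshot instead of the 24 June one. For reference, a minimal sketch of the loading step such a helper performs, assuming it simply reads the PyTorch `state_dict` and converts each tensor to a NumPy array before mapping it onto the flax parameter tree (the real logic lives in `custom_pytorch_jax_converter.py`):

# Minimal sketch (assumption, not the repo's converter): read a PyTorch
# checkpoint on the JAX side and turn its tensors into NumPy arrays.
import torch

def load_pytorch_weights_as_numpy(file_name):
  # map_location='cpu' lets the JAX process read the file without a GPU.
  state_dict = torch.load(file_name, map_location='cpu')
  return {name: tensor.detach().cpu().numpy()
          for name, tensor in state_dict.items()}

# The converter would then overwrite the leaves of the flax params pytree with
# these arrays (transposing where PyTorch and flax use different layouts)
# before jax_utils.replicate is called.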

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
1 addition, 0 deletions

@@ -88,6 +88,7 @@ def init_model_fn(
         dropout_rate=dropout_rate,
         use_layer_norm=self.use_layer_norm,
         embedding_init_multiplier=self.embedding_init_multiplier)
+    torch.save(model.state_dict(), '/results/pytorch_base_model_criteo1tb_1_july.pth')
     self._param_shapes = param_utils.pytorch_param_shapes(model)
     self._param_types = param_utils.pytorch_param_types(self._param_shapes)
     model.to(DEVICE)
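Because `init_model_fn` runs in every process under the distributed PyTorch setup, each rank would execute this `torch.save` and write the same file. A minimal sketch of a single-writer variant, offered as a suggestion under that assumption and not part of this commit:

# Minimal sketch (not part of this commit): only rank 0 writes the base
# checkpoint, and the other ranks wait until the file is on disk.
import torch
import torch.distributed as dist

def save_base_model(model, path):
  rank = dist.get_rank() if dist.is_initialized() else 0
  if rank == 0:
    torch.save(model.state_dict(), path)
  if dist.is_initialized():
    dist.barrier()  # make sure the file exists before other ranks proceed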

custom_pytorch_jax_converter.py
3 additions, 0 deletions

@@ -26,6 +26,7 @@ def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
   # Load PyTorch state_dict
   state_dict = torch.load(file_name)
   print(state_dict.keys())
+
   # Convert PyTorch tensors to NumPy arrays
   numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

@@ -59,6 +60,7 @@ def deep_copy_to_cpu(pytree):
   # Load PyTorch state_dict lazily to CPU
   state_dict = torch.load(file_name, map_location='cpu')
   print(state_dict.keys())
+
   # Convert PyTorch tensors to NumPy arrays
   numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

@@ -128,6 +130,7 @@ def move_to_cpu(tree):
 def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
   """Compares two JAX PyTrees of weights and logs where they differ, safely handling PMAP replication."""
   # Attempt to unreplicate if needed
+
   params1 = maybe_unreplicate(params1)
   params2 = maybe_unreplicate(params2)
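For context, a minimal sketch of the kind of leaf-by-leaf check a helper like `are_weights_equal` performs. This is an assumption based on the docstring above, not the converter's actual implementation:

# Minimal sketch (assumption, not the repo's implementation): compare two
# parameter pytrees leaf by leaf and collect the paths that differ.
import jax
import numpy as np

def compare_pytrees(params1, params2, atol=1e-6, rtol=1e-6):
  leaves1, _ = jax.tree_util.tree_flatten_with_path(params1)
  leaves2, _ = jax.tree_util.tree_flatten_with_path(params2)
  mismatches = []
  for (path, a), (_, b) in zip(leaves1, leaves2):
    if a.shape != b.shape or not np.allclose(a, b, atol=atol, rtol=rtol):
      mismatches.append(jax.tree_util.keystr(path))
  return mismatches  # an empty list means the trees agree within tolerance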

reference_algorithms/schedule_free/jax/submission.py
5 additions, 3 deletions

@@ -171,12 +171,14 @@ def update_params(workload: spec.Workload,
           }, global_step)

   # Log the number of parameters.
-  if global_step % 100 == 0 and workload.metrics_logger is not None:
-    date_ = "2025-06-14"
-    file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
+  if global_step % 100 == 0:
+    date_ = "2025-07-01"
+    file_name = f"/results/schedule_free_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
     params = use_pytorch_weights_cpu_copy(new_params, file_name=file_name, replicate=True)
     are_weights_equal(new_params, params)
     del params
+
+    breakpoint()

   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
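The added `breakpoint()` pauses the run in the debugger at every 100th step, right after the comparison. A minimal non-interactive alternative, assuming `are_weights_equal` returns a boolean in addition to logging mismatches and that both helpers live in `custom_pytorch_jax_converter.py` (the diff only shows the logging behavior):

# Minimal sketch, under the assumptions stated above.
from custom_pytorch_jax_converter import (are_weights_equal,
                                          use_pytorch_weights_cpu_copy)

def check_against_pytorch_checkpoint(jax_params, file_name):
  # Load the PyTorch checkpoint as a CPU copy shaped like the JAX params.
  reloaded = use_pytorch_weights_cpu_copy(
      jax_params, file_name=file_name, replicate=True)
  if not are_weights_equal(jax_params, reloaded):
    raise RuntimeError(f'JAX params differ from PyTorch checkpoint {file_name}')
  del reloaded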

submission_runner.py
138 additions, 125 deletions

@@ -372,6 +372,19 @@ def train_once(
             rng=update_rng,
             **({'train_state': MappingProxyType(train_state)}
                if needs_train_state else {}))
+        if FLAGS.framework=="pytorch" and global_step % 100 == 0:
+          if global_step > 1000:
+            import torch.distributed as dist
+            import sys
+            dist.destroy_process_group()
+            sys.exit(0)
+          # Save the PyTorch weights to a file every 100 steps.
+          date_ = datetime.date.today().strftime('%Y-%m-%d')
+          file_name = os.path.join(
+              log_dir, f'/results/schedule_free_pytorch_weights/{workload_name}_{date_}_after_{global_step}_steps.pth')
+          logging.info(f'Saving PyTorch weights to {file_name}')
+          torch.save(model_params.module.state_dict(), file_name)
+
     except spec.TrainingCompleteError:
       train_state['training_complete'] = True
     global_step += 1
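One detail about the path built above: because the second argument to `os.path.join` is an absolute path, `log_dir` is discarded and the checkpoint always lands under `/results/...` regardless of the log directory. A two-line illustration of that standard-library behavior (the example paths are hypothetical):

import os

# When a later component is absolute, os.path.join drops everything before it.
print(os.path.join('/experiments/run_1', '/results/foo.pth'))  # -> '/results/foo.pth'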
@@ -383,131 +396,131 @@ def train_once(
(Every line removed below is re-added on the new side verbatim with a leading '#', so the untimed-eval, checkpointing, and logging block stays in the file but is disabled.)

     train_state['accumulated_submission_time'] += (
         train_step_end_time - train_state['last_step_end_time'])

-    # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
-
-      # Prepare for evaluation (timed).
-      if prepare_for_eval is not None:
-
-        with profiler.profile('Prepare for eval'):
-          del batch
-          prepare_for_eval_start_time = get_time()
-          optimizer_state, model_params, model_state = prepare_for_eval(
-              workload=workload,
-              current_param_container=model_params,
-              current_params_types=workload.model_params_types,
-              model_state=model_state,
-              hyperparameters=hyperparameters,
-              loss_type=workload.loss_type,
-              optimizer_state=optimizer_state,
-              eval_results=eval_results,
-              global_step=global_step,
-              rng=prep_eval_rng)
-          prepare_for_eval_end_time = get_time()
-
-          # Update sumbission time.
-          train_state['accumulated_submission_time'] += (
-              prepare_for_eval_end_time - prepare_for_eval_start_time)
-
-      # Check if time is remaining,
-      # use 1.5x the runtime budget for the self-tuning ruleset.
-      max_allowed_runtime_sec = (
-          workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
-          else 1.5 * workload.max_allowed_runtime_sec)
-      train_state['is_time_remaining'] = (
-          train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
-
-      # Eval if time is remaining (untimed).
-      if train_state['is_time_remaining']:
-
-        with profiler.profile('Evaluation'):
-          _reset_cuda_mem()
-
-          try:
-            eval_start_time = get_time()
-            latest_eval_result = workload.eval_model(global_eval_batch_size,
-                                                     model_params,
-                                                     model_state,
-                                                     eval_rng,
-                                                     data_dir,
-                                                     imagenet_v2_data_dir,
-                                                     global_step)
-            # Check if targets reached.
-            # Note that this is one of the stopping conditions for the length of
-            # a training run. To score the run we only consider the time
-            # to validation target retrospectively.
-            train_state['validation_goal_reached'] = (
-                workload.has_reached_validation_target(latest_eval_result) or
-                train_state['validation_goal_reached'])
-            train_state['test_goal_reached'] = (
-                workload.has_reached_test_target(latest_eval_result) or
-                train_state['test_goal_reached'])
-            goals_reached = (
-                train_state['validation_goal_reached'] and
-                train_state['test_goal_reached'])
-            # Save last eval time.
-            eval_end_time = get_time()
-            train_state['last_eval_time'] = eval_end_time
-
-            # Accumulate eval time.
-            train_state[
-                'accumulated_eval_time'] += eval_end_time - eval_start_time
-
-            # Add times to eval results for logging.
-            latest_eval_result['score'] = (
-                train_state['accumulated_submission_time'])
-            latest_eval_result[
-                'total_duration'] = eval_end_time - global_start_time
-            latest_eval_result['accumulated_submission_time'] = train_state[
-                'accumulated_submission_time']
-            latest_eval_result['accumulated_eval_time'] = train_state[
-                'accumulated_eval_time']
-            latest_eval_result['accumulated_logging_time'] = train_state[
-                'accumulated_logging_time']
-            time_since_start = latest_eval_result['total_duration']
-            logging.info(f'Time since start: {time_since_start:.2f}s, '
-                         f'\tStep: {global_step}, \t{latest_eval_result}')
-            eval_results.append((global_step, latest_eval_result))
-
-            logging_start_time = get_time()
-
-            if log_dir is not None and RANK == 0:
-              metrics_logger.append_scalar_metrics(
-                  latest_eval_result,
-                  global_step=global_step,
-                  preemption_count=preemption_count,
-                  is_eval=True,
-              )
-              if save_checkpoints:
-                checkpoint_utils.save_checkpoint(
-                    framework=FLAGS.framework,
-                    optimizer_state=optimizer_state,
-                    model_params=model_params,
-                    model_state=model_state,
-                    train_state=train_state,
-                    eval_results=eval_results,
-                    global_step=global_step,
-                    preemption_count=preemption_count,
-                    checkpoint_dir=log_dir,
-                    save_intermediate_checkpoints=FLAGS
-                    .save_intermediate_checkpoints)
-
-            logging_end_time = get_time()
-            train_state['accumulated_logging_time'] += (
-                logging_end_time - logging_start_time)
-
-            _reset_cuda_mem()
-
-          except RuntimeError as e:
-            logging.exception(f'Eval step {global_step} error.\n')
-            if 'out of memory' in str(e):
-              logging.warning(
-                  'Error: GPU out of memory during eval during step '
-                  f'{global_step}, error : {str(e)}.')
-              _reset_cuda_mem()
-
-      train_state['last_step_end_time'] = get_time()

   metrics = {'eval_results': eval_results, 'global_step': global_step}
