@@ -372,6 +372,19 @@ def train_once(
            rng=update_rng,
            **({'train_state': MappingProxyType(train_state)}
               if needs_train_state else {}))
+      if FLAGS.framework == "pytorch" and global_step % 100 == 0:
+        if global_step > 1000:
+          import torch.distributed as dist
+          import sys
+          dist.destroy_process_group()
+          sys.exit(0)
+        # Save the PyTorch weights to a file every 100 steps.
+        date_ = datetime.date.today().strftime('%Y-%m-%d')
+        file_name = os.path.join(
+            log_dir, f'/results/schedule_free_pytorch_weights/{workload_name}_{date_}_after_{global_step}_steps.pth')
+        logging.info(f'Saving PyTorch weights to {file_name}')
+        torch.save(model_params.module.state_dict(), file_name)
+
    except spec.TrainingCompleteError:
      train_state['training_complete'] = True
    global_step += 1
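The block added above writes the unwrapped model weights every 100 steps and shuts the run down at the first step that is both a multiple of 100 and greater than 1000, tearing down the torch.distributed process group before exiting. Two details of the code as written: model_params.module.state_dict() reaches through the DistributedDataParallel wrapper, so the saved keys carry no "module." prefix; and because the second argument to os.path.join starts with '/', the join discards log_dir and the files land under /results/schedule_free_pytorch_weights/. Below is a minimal sketch of reading one of these files back for offline inspection. The nn.Linear stand-in and the file name are hypothetical; in practice the state dict must be loaded into the same workload architecture that produced it.

import torch
import torch.nn as nn

# Stand-in module; a real run would construct the workload's model class so
# that the saved state_dict keys match.
model = nn.Linear(4, 2)

# Hypothetical file name following the pattern used in the diff above.
weights_path = 'example_workload_2024-01-01_after_200_steps.pth'

# Round-trip: save an unwrapped state_dict, then load it back on CPU.
torch.save(model.state_dict(), weights_path)
state_dict = torch.load(weights_path, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()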
@@ -383,131 +396,131 @@ def train_once(
    train_state['accumulated_submission_time'] += (
        train_step_end_time - train_state['last_step_end_time'])

-    # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
-
-      # Prepare for evaluation (timed).
-      if prepare_for_eval is not None:
-
-        with profiler.profile('Prepare for eval'):
-          del batch
-          prepare_for_eval_start_time = get_time()
-          optimizer_state, model_params, model_state = prepare_for_eval(
-              workload=workload,
-              current_param_container=model_params,
-              current_params_types=workload.model_params_types,
-              model_state=model_state,
-              hyperparameters=hyperparameters,
-              loss_type=workload.loss_type,
-              optimizer_state=optimizer_state,
-              eval_results=eval_results,
-              global_step=global_step,
-              rng=prep_eval_rng)
-          prepare_for_eval_end_time = get_time()
-
-        # Update sumbission time.
-        train_state['accumulated_submission_time'] += (
-            prepare_for_eval_end_time - prepare_for_eval_start_time)
-
-      # Check if time is remaining,
-      # use 1.5x the runtime budget for the self-tuning ruleset.
-      max_allowed_runtime_sec = (
-          workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
-          else 1.5 * workload.max_allowed_runtime_sec)
-      train_state['is_time_remaining'] = (
-          train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
-
-      # Eval if time is remaining (untimed).
-      if train_state['is_time_remaining']:
-
-        with profiler.profile('Evaluation'):
-          _reset_cuda_mem()
-
-          try:
-            eval_start_time = get_time()
-            latest_eval_result = workload.eval_model(global_eval_batch_size,
-                                                     model_params,
-                                                     model_state,
-                                                     eval_rng,
-                                                     data_dir,
-                                                     imagenet_v2_data_dir,
-                                                     global_step)
-            # Check if targets reached.
-            # Note that this is one of the stopping conditions for the length of
-            # a training run. To score the run we only consider the time
-            # to validation target retrospectively.
-            train_state['validation_goal_reached'] = (
-                workload.has_reached_validation_target(latest_eval_result) or
-                train_state['validation_goal_reached'])
-            train_state['test_goal_reached'] = (
-                workload.has_reached_test_target(latest_eval_result) or
-                train_state['test_goal_reached'])
-            goals_reached = (
-                train_state['validation_goal_reached'] and
-                train_state['test_goal_reached'])
-            # Save last eval time.
-            eval_end_time = get_time()
-            train_state['last_eval_time'] = eval_end_time
-
-            # Accumulate eval time.
-            train_state[
-                'accumulated_eval_time'] += eval_end_time - eval_start_time
-
-            # Add times to eval results for logging.
-            latest_eval_result['score'] = (
-                train_state['accumulated_submission_time'])
-            latest_eval_result[
-                'total_duration'] = eval_end_time - global_start_time
-            latest_eval_result['accumulated_submission_time'] = train_state[
-                'accumulated_submission_time']
-            latest_eval_result['accumulated_eval_time'] = train_state[
-                'accumulated_eval_time']
-            latest_eval_result['accumulated_logging_time'] = train_state[
-                'accumulated_logging_time']
-            time_since_start = latest_eval_result['total_duration']
-            logging.info(f'Time since start: {time_since_start:.2f}s, '
-                         f'\tStep: {global_step}, \t{latest_eval_result}')
-            eval_results.append((global_step, latest_eval_result))
-
-            logging_start_time = get_time()
-
-            if log_dir is not None and RANK == 0:
-              metrics_logger.append_scalar_metrics(
-                  latest_eval_result,
-                  global_step=global_step,
-                  preemption_count=preemption_count,
-                  is_eval=True,
-              )
-              if save_checkpoints:
-                checkpoint_utils.save_checkpoint(
-                    framework=FLAGS.framework,
-                    optimizer_state=optimizer_state,
-                    model_params=model_params,
-                    model_state=model_state,
-                    train_state=train_state,
-                    eval_results=eval_results,
-                    global_step=global_step,
-                    preemption_count=preemption_count,
-                    checkpoint_dir=log_dir,
-                    save_intermediate_checkpoints=FLAGS
-                    .save_intermediate_checkpoints)
-
-            logging_end_time = get_time()
-            train_state['accumulated_logging_time'] += (
-                logging_end_time - logging_start_time)
-
-            _reset_cuda_mem()
-
-          except RuntimeError as e:
-            logging.exception(f'Eval step {global_step} error.\n')
-            if 'out of memory' in str(e):
-              logging.warning(
-                  'Error: GPU out of memory during eval during step '
-                  f'{global_step}, error : {str(e)}.')
-              _reset_cuda_mem()
-
-    train_state['last_step_end_time'] = get_time()
+    # # Check if submission is eligible for an untimed eval.
+    # if ((train_step_end_time - train_state['last_eval_time']) >=
+    #     workload.eval_period_time_sec or train_state['training_complete']):
+
+    #   # Prepare for evaluation (timed).
+    #   if prepare_for_eval is not None:
+
+    #     with profiler.profile('Prepare for eval'):
+    #       del batch
+    #       prepare_for_eval_start_time = get_time()
+    #       optimizer_state, model_params, model_state = prepare_for_eval(
+    #           workload=workload,
+    #           current_param_container=model_params,
+    #           current_params_types=workload.model_params_types,
+    #           model_state=model_state,
+    #           hyperparameters=hyperparameters,
+    #           loss_type=workload.loss_type,
+    #           optimizer_state=optimizer_state,
+    #           eval_results=eval_results,
+    #           global_step=global_step,
+    #           rng=prep_eval_rng)
+    #       prepare_for_eval_end_time = get_time()
+
+    #     # Update sumbission time.
+    #     train_state['accumulated_submission_time'] += (
+    #         prepare_for_eval_end_time - prepare_for_eval_start_time)
+
+    #   # Check if time is remaining,
+    #   # use 1.5x the runtime budget for the self-tuning ruleset.
+    #   max_allowed_runtime_sec = (
+    #       workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
+    #       else 1.5 * workload.max_allowed_runtime_sec)
+    #   train_state['is_time_remaining'] = (
+    #       train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
+
+    #   # Eval if time is remaining (untimed).
+    #   if train_state['is_time_remaining']:
+
+    #     with profiler.profile('Evaluation'):
+    #       _reset_cuda_mem()
+
+    #       try:
+    #         eval_start_time = get_time()
+    #         latest_eval_result = workload.eval_model(global_eval_batch_size,
+    #                                                  model_params,
+    #                                                  model_state,
+    #                                                  eval_rng,
+    #                                                  data_dir,
+    #                                                  imagenet_v2_data_dir,
+    #                                                  global_step)
+    #         # Check if targets reached.
+    #         # Note that this is one of the stopping conditions for the length of
+    #         # a training run. To score the run we only consider the time
+    #         # to validation target retrospectively.
+    #         train_state['validation_goal_reached'] = (
+    #             workload.has_reached_validation_target(latest_eval_result) or
+    #             train_state['validation_goal_reached'])
+    #         train_state['test_goal_reached'] = (
+    #             workload.has_reached_test_target(latest_eval_result) or
+    #             train_state['test_goal_reached'])
+    #         goals_reached = (
+    #             train_state['validation_goal_reached'] and
+    #             train_state['test_goal_reached'])
+    #         # Save last eval time.
+    #         eval_end_time = get_time()
+    #         train_state['last_eval_time'] = eval_end_time
+
+    #         # Accumulate eval time.
+    #         train_state[
+    #             'accumulated_eval_time'] += eval_end_time - eval_start_time
+
+    #         # Add times to eval results for logging.
+    #         latest_eval_result['score'] = (
+    #             train_state['accumulated_submission_time'])
+    #         latest_eval_result[
+    #             'total_duration'] = eval_end_time - global_start_time
+    #         latest_eval_result['accumulated_submission_time'] = train_state[
+    #             'accumulated_submission_time']
+    #         latest_eval_result['accumulated_eval_time'] = train_state[
+    #             'accumulated_eval_time']
+    #         latest_eval_result['accumulated_logging_time'] = train_state[
+    #             'accumulated_logging_time']
+    #         time_since_start = latest_eval_result['total_duration']
+    #         logging.info(f'Time since start: {time_since_start:.2f}s, '
+    #                      f'\tStep: {global_step}, \t{latest_eval_result}')
+    #         eval_results.append((global_step, latest_eval_result))
+
+    #         logging_start_time = get_time()
+
+    #         if log_dir is not None and RANK == 0:
+    #           metrics_logger.append_scalar_metrics(
+    #               latest_eval_result,
+    #               global_step=global_step,
+    #               preemption_count=preemption_count,
+    #               is_eval=True,
+    #           )
+    #           if save_checkpoints:
+    #             checkpoint_utils.save_checkpoint(
+    #                 framework=FLAGS.framework,
+    #                 optimizer_state=optimizer_state,
+    #                 model_params=model_params,
+    #                 model_state=model_state,
+    #                 train_state=train_state,
+    #                 eval_results=eval_results,
+    #                 global_step=global_step,
+    #                 preemption_count=preemption_count,
+    #                 checkpoint_dir=log_dir,
+    #                 save_intermediate_checkpoints=FLAGS
+    #                 .save_intermediate_checkpoints)
+
+    #         logging_end_time = get_time()
+    #         train_state['accumulated_logging_time'] += (
+    #             logging_end_time - logging_start_time)
+
+    #         _reset_cuda_mem()
+
+    #       except RuntimeError as e:
+    #         logging.exception(f'Eval step {global_step} error.\n')
+    #         if 'out of memory' in str(e):
+    #           logging.warning(
+    #               'Error: GPU out of memory during eval during step '
+    #               f'{global_step}, error : {str(e)}.')
+    #           _reset_cuda_mem()
+
+    # train_state['last_step_end_time'] = get_time()

  metrics = {'eval_results': eval_results, 'global_step': global_step}
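With the untimed-eval block commented out, nothing in the loop refreshes train_state['is_time_remaining'], train_state['last_eval_time'], or train_state['last_step_end_time'], and eval_results stays empty, so as far as this diff shows the run ends only through training_complete or the step-count exit added in the first hunk, and the accumulated_submission_time bookkeeping drifts because last_step_end_time is never updated. If only the wall-clock budget check should be kept while evaluation stays disabled, a minimal sketch, reusing the names from the disabled block and assumed to sit at the same point inside train_once, could look like this:

# Sketch only: keeps the runtime-budget bookkeeping from the disabled block
# without running any evaluation. All names (workload, FLAGS, train_state,
# get_time) are assumed to come from the surrounding train_once scope.
max_allowed_runtime_sec = (
    workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
    else 1.5 * workload.max_allowed_runtime_sec)
train_state['is_time_remaining'] = (
    train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
train_state['last_step_end_time'] = get_time()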