ver 4.35.2 transformers.Trainer breaks CUDA AMP support #27760

@haixpham

System Info

torch 2.1.1
transformers 4.35.2
CUDA 12.1

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In the latest version of transformers (4.35.2), Trainer.__init__() no longer sets a use_cuda_amp flag. As a consequence, Trainer.autocast_smart_context_manager() never enters torch.cuda.amp.autocast(), which leads to a runtime error during the backward pass of the loss.
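
For reference, a rough sketch of what autocast_smart_context_manager() does in 4.35.x (paraphrased, not a verbatim quote of the library source): only the CPU AMP path enters an autocast context, and everything else falls back to a no-op context, so torch.cuda.amp.autocast() is never entered on GPU:

import contextlib
import torch

def autocast_smart_context_manager(self, cache_enabled=True):
    # 4.35.x behavior (sketch): only the CPU AMP path enters autocast;
    # on a CUDA device this returns a no-op nullcontext().
    if self.use_cpu_amp:
        ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
    else:
        ctx_manager = contextlib.nullcontext()
    return ctx_manager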

This is the corresponding part of __init__() in version 4.35.2; it does not enable CUDA autocast for either fp16 or bf16:

# Mixed precision setup
self.use_apex = False
self.use_cpu_amp = False

# Mixed precision setup for SageMaker Model Parallel
if is_sagemaker_mp_enabled():
    # BF16 + model parallelism in SageMaker: currently not supported, raise an error
    if args.bf16:
        raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

    if IS_SAGEMAKER_MP_POST_1_10:
        # When there's mismatch between SMP config and trainer argument, use SMP config as truth
        if args.fp16 != smp.state.cfg.fp16:
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                f"but FP16 provided in trainer argument is {args.fp16}, "
                f"setting to {smp.state.cfg.fp16}"
            )
            args.fp16 = smp.state.cfg.fp16
    else:
        # smp < 1.10 does not support fp16 in trainer.
        if hasattr(smp.state.cfg, "fp16"):
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
            )
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
    if args.device == torch.device("cpu"):
        if args.fp16:
            raise ValueError("Tried to use `fp16` but it is not supported on cpu")
        else:
            args.half_precision_backend = "cpu_amp"
    logger.info(f"Using {args.half_precision_backend} half precision backend")

if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
    # deepspeed and SageMaker Model Parallel manage their own half precision
    if args.half_precision_backend == "cpu_amp":
        self.use_cpu_amp = True
        self.amp_dtype = torch.bfloat16
    elif args.half_precision_backend == "apex":
        if not is_apex_available():
            raise ImportError(
                "Using FP16 with APEX but APEX is not installed, please refer to"
                " https://www.github.com/nvidia/apex."
            )
        self.use_apex = True

Expected behavior

For comparison, the same part in version 4.28.1 (which I know works):

# Mixed precision setup
self.use_apex = False
self.use_cuda_amp = False
self.use_cpu_amp = False

# Mixed precision setup for SageMaker Model Parallel
if is_sagemaker_mp_enabled():
    # BF16 + model parallelism in SageMaker: currently not supported, raise an error
    if args.bf16:
        raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

    if IS_SAGEMAKER_MP_POST_1_10:
        # When there's mismatch between SMP config and trainer argument, use SMP config as truth
        if args.fp16 != smp.state.cfg.fp16:
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                f"but FP16 provided in trainer argument is {args.fp16},"
                f"setting to {smp.state.cfg.fp16}"
            )
            args.fp16 = smp.state.cfg.fp16
    else:
        # smp < 1.10 does not support fp16 in trainer.
        if hasattr(smp.state.cfg, "fp16"):
            logger.warning(
                f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
            )

if args.fp16 or args.bf16:
    if args.half_precision_backend == "auto":
        if args.device == torch.device("cpu"):
            if args.fp16:
                raise ValueError("Tried to use `fp16` but it is not supported on cpu")
            elif _is_native_cpu_amp_available:
                args.half_precision_backend = "cpu_amp"
            else:
                raise ValueError("Tried to use cpu amp but native cpu amp is not available")
        else:
            args.half_precision_backend = "cuda_amp"

    logger.info(f"Using {args.half_precision_backend} half precision backend")

# the following part is no longer needed because of the switch to accelerate, but add it just in case
self.do_grad_scaling = False
if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
    # deepspeed and SageMaker Model Parallel manage their own half precision
    if args.half_precision_backend == "cuda_amp":
        self.use_cuda_amp = True
        self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
        #  bf16 does not need grad scaling
        self.do_grad_scaling = self.amp_dtype == torch.float16
        if self.do_grad_scaling:
            if self.sharded_ddp is not None:
                self.scaler = ShardedGradScaler()
            elif self.fsdp is not None:
                from torch.distributed.fsdp.sharded_grad_scaler import (
                    ShardedGradScaler as FSDPShardedGradScaler,
                )

                self.scaler = FSDPShardedGradScaler()
            elif is_torch_tpu_available():
                from torch_xla.amp import GradScaler

                self.scaler = GradScaler()
            else:
                self.scaler = torch.cuda.amp.GradScaler()
    elif args.half_precision_backend == "cpu_amp":
        self.use_cpu_amp = True
        self.amp_dtype = torch.bfloat16
    else:
        if not is_apex_available():
            raise ImportError(
                "Using FP16 with APEX but APEX is not installed, please refer to"
                " https://www.github.com/nvidia/apex."
            )
        self.use_apex = True

My workaround is to subclass Trainer, re-add the 4.28.1 code above after calling super().__init__(), and override autocast_smart_context_manager() with the implementation from version 4.34.1.
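
Below is a minimal sketch of that workaround. The subclass name is mine, and it assumes a single-device fp16/bf16 run without DeepSpeed, SageMaker MP, or FSDP; gradient scaling is left to accelerate, as noted in the 4.28.1 snippet above.

import contextlib
from typing import Optional

import torch
from transformers import Trainer


class CudaAmpTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-create the pre-4.35 CUDA AMP flags that Trainer.__init__() no longer sets.
        self.use_cuda_amp = False
        if (self.args.fp16 or self.args.bf16) and self.args.device.type == "cuda":
            self.use_cuda_amp = True
            self.amp_dtype = torch.float16 if self.args.fp16 else torch.bfloat16

    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        # Pre-4.35-style behavior: enter torch.cuda.amp.autocast() on GPU,
        # torch.cpu.amp.autocast() on CPU, otherwise a no-op context.
        if self.use_cuda_amp or self.use_cpu_amp:
            return (
                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                if self.use_cpu_amp
                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
            )
        return contextlib.nullcontext()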
