[trainer] loss = NaN with label_smoothing and full-fp16 eval

It looks like the Trainer's --label_smoothing_factor feature doesn't handle fp16 well. I ran into this with the DeepSpeed ZeRO-3 integration I'm working on right now, since it evaluates in fp16, but it can also be reproduced with the recently added --fp16_full_eval Trainer option.

To reproduce:

export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 python examples/seq2seq/run_seq2seq.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --do_eval --evaluation_strategy=steps  --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro  --val_max_target_length 128 --warmup_steps 500  --max_val_samples 500 --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --fp16_full_eval
***** eval metrics *****
  eval_bleu                 = 24.1257
  eval_gen_len              =  39.554
  eval_loss                 =     nan
  eval_mem_cpu_alloc_delta  =    56MB
  eval_mem_cpu_peaked_delta =     0MB
  eval_mem_gpu_alloc_delta  =   116MB
  eval_mem_gpu_peaked_delta =   374MB
  eval_runtime              = 25.3246
  eval_samples              =     500
  eval_samples_per_second   =  19.744
  init_mem_cpu_alloc_delta  =     2MB
  init_mem_cpu_peaked_delta =     0MB
  init_mem_gpu_alloc_delta  =     0MB
  init_mem_gpu_peaked_delta =     0MB

If someone in the community would like to have a look at solving this puzzle, please refer to the discussion of this Issue.

Basically, we would like to find a way to perform label smoothing under full fp16 that handles NaNs, so that the final loss is not NaN.
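To make one possible direction concrete, here is a minimal sketch. It is not the Trainer's actual implementation. It assumes the non-finite loss comes from fp16 overflow when the negative log-probabilities are summed over the vocabulary, which has not been confirmed here. The function name `label_smoothed_loss` and its `acc_dtype` parameter are hypothetical, and NumPy is used only as a framework-independent stand-in for the real tensors. The idea is to keep the logits in fp16 but accumulate the reductions in fp32.

```python
import numpy as np


def label_smoothed_loss(logits, labels, epsilon=0.1, acc_dtype=np.float32):
    """Label-smoothed cross-entropy (hypothetical helper, not the Trainer's code).

    `logits` has shape (batch, vocab) and may be fp16; `labels` has shape (batch,).
    `acc_dtype` is the dtype used to accumulate the reductions.
    """
    # Numerically stable log-softmax: subtract the row maximum first.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_z = np.log(np.exp(shifted).sum(axis=-1, keepdims=True, dtype=acc_dtype))
    neg_log_probs = -(shifted - log_z.astype(logits.dtype))

    # Negative log-likelihood of the target tokens.
    picked = np.take_along_axis(neg_log_probs, labels[:, None], axis=-1)
    nll = picked.astype(acc_dtype).mean()

    # Smoothing term: mean negative log-probability over the whole vocabulary.
    # This sum over the vocabulary is the step that can exceed the fp16 range.
    smooth = neg_log_probs.sum(axis=-1, dtype=acc_dtype).mean() / logits.shape[-1]

    return (1 - epsilon) * nll + epsilon * smooth


# Assumed toy input: uniform fp16 logits over a t5-small-sized vocabulary.
vocab_size = 32128
logits = np.zeros((2, vocab_size), dtype=np.float16)
labels = np.array([3, 7])

with np.errstate(over="ignore"):  # silence the expected fp16 overflow warning
    loss_fp16 = label_smoothed_loss(logits, labels, acc_dtype=np.float16)
loss_fp32 = label_smoothed_loss(logits, labels, acc_dtype=np.float32)

print("fp16 accumulation finite:", bool(np.isfinite(loss_fp16)))
print("fp32 accumulation finite:", bool(np.isfinite(loss_fp32)))
```

In this toy case each negative log-probability is about 10, so the sum over roughly 32k vocabulary entries is far above the fp16 maximum of 65504. Accumulating in fp32 keeps it finite. Whether this is enough to fix the real eval_loss would still need to be checked against the "golden" value.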

For reference, running the same script without --fp16_full_eval gives the “golden” eval_loss; ideally the value with --fp16_full_eval should be about the same (if that is possible).

Thank you!


