Batch Normalization fails as kernel constraint for Conv layers when using mixed precision

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below): 2.4.1
  • Python version: 3.7

Describe the current behavior
When using tf.keras.layers.BatchNormalization() as a kernel_constraint in a Conv layer with mixed precision enabled, the model fails to train: apply_gradients raises a ValueError about a dtype conversion (see the traceback below).

Describe the expected behavior
Using tf.keras.layers.BatchNormalization() as a kernel_constraint in a Conv layer should behave the same whether or not mixed precision is enabled.

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Colab link.

I’ve included a few notes in comments to show that this issue is isolated to conv layers when using mixed precision.
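For reference, here is a minimal sketch of the kind of setup that triggers the failure (layer sizes, input shape, optimizer, and data are illustrative placeholders, not the exact code from the Colab):

    import tensorflow as tf
    from tensorflow.keras import mixed_precision

    mixed_precision.set_global_policy('mixed_float16')  # removing this line makes training work

    # Any callable can be passed as a kernel_constraint; here a BatchNormalization
    # layer is used as that callable.
    bn_constraint = tf.keras.layers.BatchNormalization()

    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            32, 3, activation='relu', input_shape=(28, 28, 1),
            kernel_constraint=bn_constraint),  # fails only for conv layers under mixed precision
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, dtype='float32'),
    ])

    model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

    x = tf.random.normal((64, 28, 28, 1))
    y = tf.random.uniform((64,), maxval=10, dtype=tf.int32)
    model.fit(x, y, epochs=1)  # raises the ValueError shown in the traceback below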

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

It looks like the issue is in the loss_scale_optimizer.

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:757 train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:498 minimize
        return self.apply_gradients(grads_and_vars, name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py:712 apply_gradients
        args=(grads_and_vars, name, experimental_aggregate_gradients))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2941 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2948 _merge_call
        return merge_fn(self._strategy, *args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py:745 _apply_gradients_cross_replica  **
        do_not_apply_fn)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/smart_cond.py:59 smart_cond
        name=name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py:538 new_func
        return func(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py:1180 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/cond_v2.py:89 cond_v2
        op_return_value=pred)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:990 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py:732 apply_fn
        args=(grads, wrapped_vars, name, experimental_aggregate_gradients))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py:755 _apply_gradients
        experimental_aggregate_gradients=experimental_aggregate_gradients)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:635 apply_gradients
        "name": name,
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2941 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2948 _merge_call
        return merge_fn(self._strategy, *args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:683 _distributed_apply  **
        var, apply_grad_to_update_var, args=(grad,), group=False))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2494 update
        return self._update(var, fn, args, kwargs, group)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3431 _update
        return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3437 _update_non_slot
        result = fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:661 apply_grad_to_update_var  **
        return var.assign(var.constraint(var))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:237 assign
        name, read_value)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:209 _apply_assign_update
        assign_op = update_fn(value, use_locking, name, False)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/resource_variable_ops.py:882 assign
        value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:1509 convert_to_tensor
        (dtype.name, value.dtype.name, value))

    ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'cond_1/SGD/SGD/update/batch_normalization_2/FusedBatchNormV3:0' shape=(3, 3, 1, 32) dtype=float16>

1 possible answer(s):

  1. Layers are not intended to be passed as a kernel_constraint, and the weights of the layer will not be trained when this is done (I'm not sure what your use case is). But it will typically work, since any callable can be passed. With mixed precision, weights are float32 but layer outputs are float16. The kernel_constraint's output dtype must be the same as the weight dtype, which is why you must pass dtype='float32' to the BatchNormalization layer.
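For anyone hitting the same error, a minimal sketch of the workaround described above, i.e. forcing the constraint's compute/output dtype to float32 so it matches the variable dtype (argument values are illustrative):

    import tensorflow as tf

    # Under the 'mixed_float16' policy, variables stay float32 while layer
    # outputs are float16. Giving the BatchNormalization layer dtype='float32'
    # makes its output float32, matching the kernel variable it constrains.
    bn_constraint = tf.keras.layers.BatchNormalization(dtype='float32')

    conv = tf.keras.layers.Conv2D(
        32, 3, activation='relu',
        kernel_constraint=bn_constraint)

Note that, as the answer points out, the BatchNormalization layer's own weights are not trained when used this way; it is only applied as a callable constraint to the kernel after each gradient update.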