test_variant_consistency_jit tests fail on CPU for min & max when dtype is bfloat16 & dim argument is passed

🐛 Bug

Currently, no OpInfos exist for min & max.
I added OpInfos for them in #51244, but the following JIT-related tests fail even though their respective CPU eager & CUDA counterparts pass:

TestCommonCPU.test_variant_consistency_jit_min_reduction_with_dim_cpu_bfloat16
TestCommonCPU.test_variant_consistency_jit_max_reduction_with_dim_cpu_bfloat16

In #51244, these test names correspond to the dim argument being passed in a call to min & max.
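
For context, here is a minimal standalone sketch of roughly what the failing case exercises (an approximation for illustration only, not the actual harness code from test_ops.py; the shapes and the helper name fn are made up):

    import torch

    def fn(x):
        # torch.max/min with a dim argument return (values, indices)
        values, indices = torch.max(x, 1)
        return values

    scripted = torch.jit.script(fn)

    x = torch.randn(5, 5, dtype=torch.bfloat16, requires_grad=True)
    out = scripted(x)
    # check_against_reference computes gradients via torch.autograd.grad;
    # in the failing tests this is the step where the
    # "not implemented for 'BFloat16'" error surfaces.
    torch.autograd.grad(out.sum(), x, allow_unused=True)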

Failure causes

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 295, in instantiated_test
    raise rte
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 290, in instantiated_test
    result = test_fn(self, *args)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 268, in test_wrapper
    return test(*args, **kwargs)
  File "test_ops.py", line 303, in test_variant_consistency_jit
    no_grad=not test_backward)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_jit.py", line 77, in check_against_reference
    allow_unused=allow_unused)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: "method_name" not implemented for 'BFloat16'

The call flow doesn’t even reach the dispatcher; the exception seems to be thrown from within pybind11:

Thread 1 "python3" hit Catchpoint 1 (exception thrown), 0x00007fffd8dd7762 in __cxa_throw () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
(gdb) bt
#0  0x00007fffd8dd7762 in __cxa_throw () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#1  0x00007fffcc9b8501 in pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}::operator()(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&) const (this=0x6e20b938, s=...)
    at /home/pytorch/third_party/pybind11/include/pybind11/pybind11.h:1908
#2  0x00007fffcc9fdcd9 in pybind11::detail::argument_loader<pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&>::call_impl<torch::jit::Value* const&, pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&, 0ul, pybind11::detail::void_type>(pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&, std::integer_sequence<unsigned long, 0ul>, pybind11::detail::void_type&&) && (
    this=0x7fffffffa0f0, f=...) at /home/pytorch/third_party/pybind11/include/pybind11/cast.h:2042
#3  0x00007fffcc9f0b6e in pybind11::detail::argument_loader<pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&>::call<torch::jit::Value* const&, pybind11::detail::void_type, pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&>(pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&) && (this=0x7fffffffa0f0, f=...)
    at /home/pytorch/third_party/pybind11/include/pybind11/cast.h:2014
#4  0x00007fffcc9e6937 in pybind11::cpp_function::initialize<pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}, torch::jit::Value* const&, pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::return_value_policy>(pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&&, torch::jit::Value* const& (*)(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::return_value_policy const&)::{lambda(pybind11::detail::function_call&)#3}::operator()(pybind11::detail::function_call) const (this=0x0, call=...) at /home/pytorch/third_party/pybind11/include/pybind11/pybind11.h:192
#5  0x00007fffcc9e69a2 in pybind11::cpp_function::initialize<pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}, torch::jit::Value* const&, pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::return_value_policy>(pybind11::make_iterator<(pybind11::return_value_policy)6, torch::jit::Value* const*, torch::jit::Value* const*, torch::jit::Value* const&>(torch::jit::Value* const*, torch::jit::Value* const*)::{lambda(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&)#2}&&, torch::jit::Value* const& (*)(pybind11::detail::iterator_state<torch::jit::Value* const*, torch::jit::Value* const*, false, (pybind11::return_value_policy)6>&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::return_value_policy const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) () at /home/pytorch/third_party/pybind11/include/pybind11/pybind11.h:170
#6  0x00007fffcc21169e in pybind11::cpp_function::dispatcher (self=0x7ffec9710330, args_in=0x7ffec9710820, kwargs_in=0x0)
    at /home/pytorch/third_party/pybind11/include/pybind11/pybind11.h:767
#7  0x00000000005f4249 in PyCFunction_Call ()
#8  0x00000000005f46d6 in _PyObject_MakeTpCall ()
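
(For reference, a backtrace like the one above can be captured by setting a gdb catchpoint on C++ exceptions before running one of the failing tests. The invocation below is only a sketch and assumes pytest is available:)

    gdb --args python3 -m pytest test_ops.py -k test_variant_consistency_jit_max_reduction_with_dim_cpu_bfloat16
    (gdb) catch throw
    (gdb) run
    (gdb) bt    # once the catchpoint is hit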

To Reproduce

  1. The failures can be reproduced with the current code simply by adding OpInfos for min & max in common_methods_invocations.py. For example:

    OpInfo('max',
           op=torch.max,
           variant_test_name='binary',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_binary,
           assert_autodiffed=True,
           supports_tensor_out=True),
    OpInfo('max',
           op=torch.max,
           variant_test_name='reduction_with_dim',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
           skips=(
               # Skip right now as it fails due to a pybind error
               SkipInfo('TestCommon', 'test_variant_consistency_jit',
                        device_type='cpu', dtypes=[torch.bfloat16]),
           )),
    OpInfo('max',
           op=torch.max,
           variant_test_name='reduction_no_dim',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_reduction_no_dim),
    OpInfo('min',
           op=torch.min,
           variant_test_name='binary',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_binary,
           assert_autodiffed=True,
           supports_tensor_out=True),
    OpInfo('min',
           op=torch.min,
           variant_test_name='reduction_with_dim',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
           skips=(
               # Skip right now as it fails due to a pybind error
               SkipInfo('TestCommon', 'test_variant_consistency_jit',
                        device_type='cpu', dtypes=[torch.bfloat16]),
           )),
    OpInfo('min',
           op=torch.min,
           variant_test_name='reduction_no_dim',
           dtypes=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCPU=all_types_and(torch.float16, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
           test_inplace_grad=False,
           sample_inputs_func=sample_inputs_max_min_reduction_no_dim),
  2. Add the sample input functions these OpInfos reference, such as:

    def sample_inputs_max_min_binary(op_info, device, dtype, requires_grad):
        args_for_binary_op = (
            ((S, S, S), (S, S, S)),
            ((S, S, S), (S,)),
            ((S,), (S, S, S)),
            ((S, 1, S), (S, S)),
            ((), ()),
            ((S, S, S), ()),
            ((), (S, S, S)),
        )
        inputs = [SampleInput(make_tensor(input_tensor, device, dtype,
                                          low=None, high=None,
                                          requires_grad=requires_grad),
                              args=(make_tensor(other_tensor, device, dtype,
                                                low=None, high=None,
                                                requires_grad=requires_grad),))
                  for input_tensor, other_tensor in args_for_binary_op]
        return inputs

    def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad):
        args_for_reduction_with_dim = (
            ((S, S, S), (1,)),
            ((S, S, S), (1, True)),
            ((), (0,)),
            ((), (0, True)),
        )
        inputs = [SampleInput(make_tensor(input_tensor, device, dtype,
                                          low=None, high=None,
                                          requires_grad=requires_grad),
                              args=args)
                  for input_tensor, args in args_for_reduction_with_dim]
        return inputs

    def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad):
        inputs = []
        inputs.append(SampleInput(make_tensor((S, S, S), device, dtype,
                                              low=None, high=None,
                                              requires_grad=requires_grad)))
        inputs.append(SampleInput(make_tensor((), device, dtype,
                                              low=None, high=None,
                                              requires_grad=requires_grad)))
        return inputs
  3. Run the aforementioned tests, for example with the invocation sketched below.
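
     Assuming pytest is available, something like the following (run from the directory containing test_ops.py) selects just the two failing tests:

     pytest test_ops.py -k "test_variant_consistency_jit_max_reduction_with_dim_cpu_bfloat16 or test_variant_consistency_jit_min_reduction_with_dim_cpu_bfloat16"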

Expected behavior

JIT tests should pass, as their eager counterparts do.

Environment

CI, as well as locally:

PyTorch version: 1.8.0a0+4757e71
Is debug build: True
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)

cc @gmagogsfm

1 possible answer(s):

  1. @imaginary-person Thanks for the update! Yes, there are known problems with the BFloat16 operator in JIT; it's not fully supported. We have issue #48978 (I see you've commented on it already) to figure out a way to clean up the test. You should be good to skip the BFloat16 dtype for the test_variant_consistency_jit tests without worry.