MyPy checking for torch.Tensor subclasses broken in 1.8

🐛 Bug

In PyTorch 1.7, it was possible to subclass torch.Tensor and have mypy correctly type-check class attributes annotated with the subclass. In PyTorch 1.8, this no longer works: mypy resolves such class attributes to torch.Tensor instead of the subclass.

To Reproduce

Steps to reproduce the behavior:

Code for both environments:

# file mypy_test.py
from typing import cast, Any, TYPE_CHECKING

import torch


class Tensor(torch.Tensor):
    """Type wrapper for torch.Tensor"""
    def __new__(cls, *args: Any, **kwargs: Any) -> "Tensor":
        return cast(Tensor, torch.as_tensor(*args, **kwargs))


GLOBAL: Tensor = Tensor([0, 0], dtype=torch.float32)


class Foo:
    a: Tensor = GLOBAL

    def foo(self) -> Tensor:
        return cast(Tensor, self.a)

if TYPE_CHECKING:
    reveal_type(GLOBAL)
    reveal_type(Foo.a)

if __name__ == "__main__":
    # this works
    bar: Tensor = GLOBAL
    bar = Foo().foo()
    # this doesn't work in 1.8 but does in 1.7
    bar = Foo().a

Running in a PyTorch 1.7 environment works as expected:

$ pip install torch==1.7.1
$ mypy mypy_test.py
mypy_test.py:22: note: Revealed type is 'mypy_test.Tensor'
mypy_test.py:23: note: Revealed type is 'mypy_test.Tensor'

Running in PyTorch 1.8 fails:

$ pip install torch==1.8.0
$ mypy mypy_test.py
mypy_test.py:22: note: Revealed type is 'mypy_test.Tensor'
mypy_test.py:23: note: Revealed type is 'torch.tensor.Tensor'
mypy_test.py:30: error: Incompatible types in assignment (expression has type "torch.tensor.Tensor", variable has type "mypy_test.Tensor")
Found 1 error in 1 file (checked 1 source file)

This error does not happen with other (non-torch) custom types used in the same pattern.
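For comparison, here is a minimal sketch of the same pattern with a non-torch type. `Base` and `Wrapper` are hypothetical stand-ins invented for this illustration, not part of any library. Per the observation above, mypy would be expected to keep `Foo.a` typed as `Wrapper` here; the sketch itself only demonstrates that the pattern runs.

```python
from typing import Any, cast


class Base:
    """Hypothetical stand-in for a third-party base class (not torch)."""

    def __init__(self, *values: Any) -> None:
        self.values = values


class Wrapper(Base):
    """Type wrapper for Base, mirroring the Tensor wrapper in the report."""

    def __new__(cls, *args: Any, **kwargs: Any) -> "Wrapper":
        # Like torch.as_tensor in the report, this returns a plain Base
        # instance that is only *typed* as Wrapper via cast.
        return cast(Wrapper, Base(*args, **kwargs))


GLOBAL: Wrapper = Wrapper(0, 0)


class Foo:
    a: Wrapper = GLOBAL

    def foo(self) -> Wrapper:
        return self.a


# The assignment that fails under mypy with torch 1.8 in the report;
# with a non-torch base it is expected to type-check.
bar: Wrapper = Foo().a
```

Swapping `Base` for torch.Tensor (and `Base(...)` for `torch.as_tensor(...)`) gives the failing case from the report, which suggests the difference lies in how torch 1.8 describes `torch.Tensor` to the type checker rather than in the wrapper pattern itself.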

Expected behavior

mypy reports no errors, and Foo.a is revealed as 'mypy_test.Tensor', as it is with PyTorch 1.7.

Environment

Tested under multiple environments:

PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.0.1
CMake version: version 3.19.3

Python version: 3.9 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.8.0
[pip3] torchvision==0.8.2
[conda] Could not collect

PyTorch version: 1.8.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.5

Python version: 3.7 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: GPU 0: Quadro RTX 6000
Nvidia driver version: 450.102.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.4.3
[pip3] numpy==1.16.1
[pip3] numpy-quaternion==2020.11.2.17.0.49
[pip3] numpy-stl==2.10.1
[pip3] torch==1.8.0+cu111
[pip3] torchgeometry==0.1.2
[pip3] torchvision==0.9.0+cu111
[conda] msgpack-numpy             0.4.4.3                  pypi_0    pypi
[conda] numpy                     1.16.1                   pypi_0    pypi
[conda] numpy-quaternion          2020.11.2.17.0.49          pypi_0    pypi
[conda] numpy-stl                 2.10.1                   pypi_0    pypi
[conda] torch                     1.8.0+cu111              pypi_0    pypi
[conda] torchgeometry             0.1.2                    pypi_0    pypi
[conda] torchvision               0.9.0+cu111              pypi_0    pypi

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411
