cublasZgemmStridedBatched produces incorrect results for some non-contiguous cases when trans='c'

🐛 Bug

import torch

A = torch.ones(1, 2, 1, device="cuda", dtype=torch.complex128)
B = torch.ones(1, 3, 1, device="cuda", dtype=torch.complex128).transpose(1, 2).conj()
C = torch.bmm(A, B)  # batched (1, 2, 1) x (1, 1, 3) -> (1, 2, 3)

This is expected to produce a tensor of size (1, 2, 3) filled with ones (the imaginary parts of the inputs are 0), but instead it produces:

tensor([[[1.+0.j, 0.+0.j, 0.+0.j],
         [1.+0.j, 0.+0.j, 0.+0.j]]], device='cuda:0', dtype=torch.complex128)

This happens when cublasZgemmStridedBatched is called with trans='c', i.e. when the conjugation is not materialized but is passed to cuBLAS as a conjugate-transpose operation (PR: #59380).
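To make the expected semantics concrete, here is a minimal CPU reference sketch of a column-major GEMM that honors 'n', 't' and 'c' the way cuBLAS operation flags do, applying the conjugation on read instead of materializing it. The names `op_elem`, `min_ld` and `ref_gemm` are hypothetical and not part of PyTorch or cuBLAS. The assertions encode the leading-dimension minimums cuBLAS requires: at least max(1, rows) for a non-transposed operand, and at least max(1, cols of op(X)) when the operand is transposed or conjugate-transposed, since it is then stored with its dimensions swapped.

```cpp
#include <algorithm>
#include <cassert>
#include <complex>
#include <cstdint>

using cd = std::complex<double>;

// Element (r, c) of op(X), where X is stored column-major with leading dimension ld.
// 'n': X itself, 't': transpose, 'c': conjugate transpose. For 'c' the conjugation
// is applied on read, so the conjugated matrix is never materialized.
cd op_elem(char trans, const cd* x, int64_t ld, int64_t r, int64_t c) {
  switch (trans) {
    case 't': case 'T': return x[c + r * ld];
    case 'c': case 'C': return std::conj(x[c + r * ld]);
    default:            return x[r + c * ld];
  }
}

// Smallest valid leading dimension for an operand whose op() result is rows x cols:
// it is stored as rows x cols for 'n', and as cols x rows for 't' and 'c'.
int64_t min_ld(char trans, int64_t rows, int64_t cols) {
  const bool notrans = (trans == 'n' || trans == 'N');
  return std::max<int64_t>(1, notrans ? rows : cols);
}

// Reference C = alpha * op(A) * op(B) + beta * C, all column-major;
// op(A) is m x k, op(B) is k x n, C is m x n.
void ref_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
              cd alpha, const cd* a, int64_t lda, const cd* b, int64_t ldb,
              cd beta, cd* c, int64_t ldc) {
  assert(lda >= min_ld(transa, m, k));
  assert(ldb >= min_ld(transb, k, n));
  assert(ldc >= std::max<int64_t>(1, m));
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      cd acc = 0.0;
      for (int64_t p = 0; p < k; ++p)
        acc += op_elem(transa, a, lda, i, p) * op_elem(transb, b, ldb, p, j);
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}
```

For example, with A a 2 x 1 column of ones and B a stored 3 x 1 column passed with transb = 'c', the result is a 2 x 3 block whose every entry is the conjugate of the corresponding B entry, which is the shape of the reproducer above.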

cc @ngimel @ezyang @anjali411 @dylanbespalko @mruberry @Lezcano @nikitaved @csarofeen @ptrblck @xwang233 @jianyuh @pearu @heitorschueroff @walterddr @IvanYashchuk


  1. I wrote a short standalone C++ program to investigate this but couldn't reproduce the problem there, so I looked into PR #59380.

    For complex<double>, at::cuda::blas::bgemm dispatches to this specialization:

    template <>
    void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
      // See Note [Writing Nondeterministic Operations]
      cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
      cublasOperation_t opa = _cublasOpFromChar(transa);
      cublasOperation_t opb = _cublasOpFromChar(transb);
      // Leading dimensions are adjusted before the cuBLAS call.
      _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
      TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
          handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
          lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
          reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
    }
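    The stridea/strideb/stridec and num_batches arguments describe the batched layout: batch i reads its matrices at element offset i * stride from each base pointer. Below is a minimal CPU sketch of that indexing, handling only non-transposed operands to stay short; `ref_gemm_strided_batched_nn` is a hypothetical helper, not a cuBLAS or PyTorch function.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <complex>
    #include <cstdint>

    using cd = std::complex<double>;

    // Reference for the strided-batched layout (non-transposed operands only):
    // for each batch i, C_i = alpha * A_i * B_i + beta * C_i in column-major order,
    // where A_i starts at a + i * stridea, B_i at b + i * strideb, C_i at c + i * stridec.
    void ref_gemm_strided_batched_nn(int64_t m, int64_t n, int64_t k, cd alpha,
                                     const cd* a, int64_t lda, int64_t stridea,
                                     const cd* b, int64_t ldb, int64_t strideb,
                                     cd beta, cd* c, int64_t ldc, int64_t stridec,
                                     int64_t num_batches) {
      // Leading-dimension minimums for non-transposed operands.
      assert(lda >= std::max<int64_t>(1, m));
      assert(ldb >= std::max<int64_t>(1, k));
      assert(ldc >= std::max<int64_t>(1, m));
      for (int64_t batch = 0; batch < num_batches; ++batch) {
        const cd* ab = a + batch * stridea;
        const cd* bb = b + batch * strideb;
        cd* cb = c + batch * stridec;
        for (int64_t j = 0; j < n; ++j) {
          for (int64_t i = 0; i < m; ++i) {
            cd acc = 0.0;
            for (int64_t p = 0; p < k; ++p) acc += ab[i + p * lda] * bb[p + j * ldb];
            cb[i + j * ldc] = alpha * acc + beta * cb[i + j * ldc];
          }
        }
      }
    }
    ```

    Each batch is an independent GEMM; only the pointer offsets differ, so any leading-dimension mistake is repeated identically in every batch.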

    where line 240 calls _cublasAdjustLdLevel3, which begins:

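    static void _cublasAdjustLdLevel3(
        char transa,
        char transb,
        int64_t m,
        int64_t n,
        int64_t k,
        int64_t* lda,
        int64_t* ldb,
        int64_t* ldc) {
      // Only 't'/'T' count as transposed here; 'c'/'C' falls through as non-transposed.
      bool transa_ = ((transa == 't') || (transa == 'T'));
      bool transb_ = ((transb == 't') || (transb == 'T'));
      // ... (remainder of the function omitted)
    }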

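    The transa_/transb_ flags only recognize 't' and 'T', so an operand passed with 'c' (conjugate transpose) is treated as non-transposed. These checks also need to cover 'c'/'C'.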
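    Why the wrong lda matters: cuBLAS still receives the conjugate-transpose operation, so it indexes the operand as stored with swapped dimensions, reading element (r, c) of op(A) at offset c + r * lda. Assuming the helper (whose remainder is not quoted above) resets lda to max(m, 1) in its non-transposed branch when k is 1, rows after the first are read past the end of the operand. The sketch lists those offsets for m = 3, k = 1; `op_offsets` and `in_bounds` are hypothetical helpers. With lda = 3 only one of the three elements lies inside a 3-element buffer, which is consistent with (though not proof of) only one column of the reproducer's output being correct; what the out-of-range reads return depends on whatever memory follows the buffer.

    ```cpp
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Memory offsets (in elements) that a column-major GEMM reads for op(X), listed
    // in column-major order of op(X), which is rows x cols. For 't'/'c' the operand is
    // indexed as stored cols x rows: op(X)(r, c) lives at c + r * ld.
    std::vector<int64_t> op_offsets(char trans, int64_t rows, int64_t cols, int64_t ld) {
      assert(rows >= 0 && cols >= 0 && ld >= 1);
      const bool notrans = (trans == 'n' || trans == 'N');
      std::vector<int64_t> offs;
      for (int64_t c = 0; c < cols; ++c)
        for (int64_t r = 0; r < rows; ++r)
          offs.push_back(notrans ? r + c * ld : c + r * ld);
      return offs;
    }

    // Number of those offsets that fall inside a buffer of `size` elements.
    int64_t in_bounds(const std::vector<int64_t>& offs, int64_t size) {
      int64_t count = 0;
      for (int64_t o : offs) count += (o >= 0 && o < size) ? 1 : 0;
      return count;
    }
    ```

    For op = 'c' with op(A) of size 3 x 1, lda = 1 yields offsets 0, 1, 2, while lda = 3 yields 0, 3, 6.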

    Note that lda and ldb need to be adjusted whenever their matrices are transposed or conjugate-transposed, because in both cases the operand is stored with its dimensions swapped.
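    A minimal sketch of the adjustment with 'c'/'C' treated like 't'/'T'. The size-1 rules are reconstructed from the cuBLAS leading-dimension constraints (lda at least max(1, m) for a non-transposed A, at least max(1, k) otherwise, and likewise for B and C), so they may differ in detail from the actual PyTorch helper; `adjust_ld_level3_sketch` and `is_trans_or_conj` are hypothetical names.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // True for 't'/'T' (transpose) and 'c'/'C' (conjugate transpose). In both cases the
    // operand is stored with its dimensions swapped relative to op(X).
    static bool is_trans_or_conj(char op) {
      return op == 't' || op == 'T' || op == 'c' || op == 'C';
    }

    // Sketch of a leading-dimension adjustment that treats 'c' like 't'. When a stored
    // operand has at most one column, its leading dimension is never used to step between
    // columns, so it is replaced by the smallest value cuBLAS accepts for that layout.
    static void adjust_ld_level3_sketch(char transa, char transb, int64_t m, int64_t n,
                                        int64_t k, int64_t* lda, int64_t* ldb, int64_t* ldc) {
      assert(lda != nullptr && ldb != nullptr && ldc != nullptr);
      const bool transa_ = is_trans_or_conj(transa);
      const bool transb_ = is_trans_or_conj(transb);

      if (n <= 1) *ldc = std::max<int64_t>(m, 1);  // C is m x n

      if (transa_) {  // A stored as k x m
        if (m <= 1) *lda = std::max<int64_t>(k, 1);
      } else {        // A stored as m x k
        if (k <= 1) *lda = std::max<int64_t>(m, 1);
      }

      if (transb_) {  // B stored as n x k
        if (k <= 1) *ldb = std::max<int64_t>(n, 1);
      } else {        // B stored as k x n
        if (n <= 1) *ldb = std::max<int64_t>(k, 1);
      }
    }
    ```

    With transa = 'c', m = 3, k = 1 and lda = 1, the sketch leaves lda unchanged, whereas the non-transposed branch would have replaced it with 3.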
