fix(addmm): preserve float32 precision for alpha/beta in bf16/fp16 GEMM (#78960)
addmm cast alpha/beta to the tensor dtype (bf16/fp16) before passing them to cuBLAS, which lost significant scalar precision (e.g. alpha=2.9270 → bf16(2.921875), a relative error of about 0.17%). Use the MPTypeTrait<T>::Type pattern (same as baddbmm) to keep the scalars in float32 for half-precision types, matching PyTorch's opmath_type behavior.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Zhaowu Pan committed
6e096fef55b22b282af0ea92aacb32fbde9d433c
Parent: 39d7135
Committed by GitHub <noreply@github.com>
on 5/14/2026, 7:18:44 AM
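
The rounding figure quoted in the commit message can be reproduced with a small standalone sketch. This is not code from the patch: `round_to_bf16` and `to_fp32` are hypothetical helpers written for illustration, they emulate the casts with Python's `struct` module, and they assume finite inputs (NaN and infinity are not handled).

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    Hypothetical helper: bfloat16 is the upper 16 bits of the float32
    encoding, so rounding is done on the float32 bit pattern.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Bias is 0x7FFF, plus 1 when the lowest kept bit is set, which
    # gives round-to-nearest-even once the low 16 bits are dropped.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


alpha = 2.9270
alpha_bf16 = round_to_bf16(alpha)  # scalar cast to the tensor dtype
alpha_fp32 = to_fp32(alpha)        # scalar kept in float32

rel_err_bf16 = abs(alpha - alpha_bf16) / alpha
rel_err_fp32 = abs(alpha - alpha_fp32) / alpha

print(alpha_bf16)  # 2.921875
print(f"bf16 relative error: {rel_err_bf16:.4%}")
print(f"fp32 relative error: {rel_err_fp32:.2e}")
```

The bf16 cast leaves only 8 significant bits for the scalar, so every output element of the GEMM inherits that relative error, while keeping alpha and beta in float32 reduces the scalar error by several orders of magnitude.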