SIGN IN SIGN UP

fix(addmm): preserve float32 precision for alpha/beta in bf16/fp16 GEMM (#78960)

addmm incorrectly cast alpha/beta to tensor dtype (bf16/fp16) before
passing to cuBLAS, causing significant scalar precision loss
(e.g. alpha=2.9270 → bf16(2.921875), losing 0.17%).

Use MPTypeTrait<T>::Type pattern (same as baddbmm) to keep scalars in
float32 for half-precision types, matching PyTorch's opmath_type behavior.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Z
Zhaowu Pan committed
6e096fef55b22b282af0ea92aacb32fbde9d433c
Parent: 39d7135
Committed by GitHub <noreply@github.com> on 5/14/2026, 7:18:44 AM