From c31d0ea1d639ad8b4a7fba7e60ff70cb4f581581 Mon Sep 17 00:00:00 2001
From: Kaining Zhong
Date: Tue, 3 Feb 2026 01:49:50 +0000
Subject: [PATCH] Fix exp2f_rcp to properly handle nan and 0xFE cases

Signed-off-by: Kaining Zhong
---
 tests/cpp/test_common.h                | 10 +++++++---
 transformer_engine/common/util/ptx.cuh | 10 +++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 082677c978..5bb6400629 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -425,10 +425,14 @@ inline fp8e8m0 float_to_e8m0(float val) {
 }
 
 inline float exp2f_rcp(fp8e8m0 biased_exp) {
-  if (biased_exp == 0) {
-    return 1.0f;
+  int32_t int_val = 0;
+  if (biased_exp == 255) {
+    int_val = 0x7fffffff;
+  } else if (biased_exp == 254) {
+    int_val = 0x00400000;
+  } else {
+    int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
   }
-  int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
   float fp32_val = *reinterpret_cast<float*>(&int_val);
   return fp32_val;
 }
diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh
index 9bcf6e2289..5367d7e781 100644
--- a/transformer_engine/common/util/ptx.cuh
+++ b/transformer_engine/common/util/ptx.cuh
@@ -328,9 +328,13 @@ constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;
 
 __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
-  return (biased_exp == 0) ? 1
-                           : __int_as_float((254 - biased_exp)
-                                            << FP32_MANTISSA_BITS);  // 127 - (biased_exp - 127)
+  // Handle the special case of NaN.
+  if (biased_exp == 255) return __int_as_float(0x7fffffff);
+  // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which needs the first bit of
+  // the mantissa to be 1, which can't be obtained by shifting `FP32_MANTISSA_BITS` bits to the left.
+  if (biased_exp == 254) return __int_as_float(0x00400000);
+  // Fast calculation when the unbiased exp is in [-127, 126], and only the exponent part is used to express the reciprocal.
+  return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
 }
 
 __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {