I'm trying to convert a float32 to an fp8 float. There's some detail in the code that is causing a discrepancy in the results. I can't figure out what it is, could someone help me?
def float32_to_ofp8_and_back(value: float, encoding: str) -> float:
import numpy as np
# Determine the sign bit
S = 0 if value >= 0 else 1
value = abs(value)
# Calculate the exponent and mantissa for normalization
exponent = int(np.floor(np.log2(value))) if value != 0 else 0
mantissa = value / (2 ** exponent) - 1 if value != 0 else 0
# Select encoding and set bias and mantissa bits
if encoding == "E4M3":
bias = 7
exponent_bits = 4
mantissa_bits = 3
elif encoding == "E5M2":
bias = 15
exponent_bits = 5
mantissa_bits = 2
else:
raise ValueError("Unsupported encoding")
# Adjust the exponent and mantissa
E = exponent + bias
M = int(mantissa * (2 ** mantissa_bits))
# Convert back to float
if E == 0 and M > 0: # Subnormal
converted_value = (-1)**S * 2**(1-bias) * (0 + 2**(-mantissa_bits) * M)
else: # Normal
converted_value = (-1)**S * 2**(E-bias) * (1 + 2**(-mantissa_bits) * M)
# Return as float32 with reduced precision
return np.float32(converted_value)
# Results are incorrect:
assert float32_to_ofp8_and_back(value = np.pi, encoding = "E4M3") == 3.25 # but returns 3
assert float32_to_ofp8_and_back(value = np.sqrt(7), encoding = "E4M3") == 2.75 # but returns 2.5
assert float32_to_ofp8_and_back(value = np.sqrt(6), encoding = "E4M3") == 2.50 # but returns 2.25
assert float32_to_ofp8_and_back(value = np.sqrt(6), encoding = "E5M2") == 2.50 # but returns 2.00
# Results are correct:
assert float32_to_ofp8_and_back(value = np.pi, encoding = "E5M2") == 3.0
assert float32_to_ofp8_and_back(value = np.sqrt(7), encoding = "E5M2") == 2.50
I am trying to test it against: square root of 7, square root of 6, and pi.