Troubleshooting FP8 Conversion Discrepancy from Float32


I'm trying to convert a float32 value to an FP8 float (E4M3 or E5M2) and back. Some detail in my code is causing a discrepancy in the results, and I can't figure out what it is. Could someone help me?

import numpy as np  # imported at module level so the asserts below can use np too


def float32_to_ofp8_and_back(value: float, encoding: str) -> float:
    # Determine the sign bit
    S = 0 if value >= 0 else 1
    value = abs(value)
    
    # Calculate the exponent and mantissa for normalization
    exponent = int(np.floor(np.log2(value))) if value != 0 else 0 
    mantissa = value / (2 ** exponent) - 1 if value != 0 else 0

    # Select encoding and set bias and mantissa bits
    if encoding == "E4M3":
        bias = 7
        exponent_bits = 4
        mantissa_bits = 3
    elif encoding == "E5M2":
        bias = 15
        exponent_bits = 5
        mantissa_bits = 2
    else:
        raise ValueError("Unsupported encoding")
    
    # Adjust the exponent and mantissa
    E = exponent + bias
    M = int(mantissa * (2 ** mantissa_bits))

    # Convert back to float
    if E == 0 and M > 0:  # Subnormal
        converted_value = (-1)**S * 2**(1-bias) * (0 + 2**(-mantissa_bits) * M)
    else:  # Normal
        converted_value = (-1)**S * 2**(E-bias) * (1 + 2**(-mantissa_bits) * M)
    
    # Return as float32 with reduced precision
    return np.float32(converted_value)

# Results are incorrect:
assert float32_to_ofp8_and_back(value=np.pi, encoding="E4M3") == 3.25  # but returns 3
assert float32_to_ofp8_and_back(value=np.sqrt(7), encoding="E4M3") == 2.75  # but returns 2.5
assert float32_to_ofp8_and_back(value=np.sqrt(6), encoding="E4M3") == 2.50  # but returns 2.25
assert float32_to_ofp8_and_back(value=np.sqrt(6), encoding="E5M2") == 2.50  # but returns 2.00

# Results are correct:
assert float32_to_ofp8_and_back(value=np.pi, encoding="E5M2") == 3.0
assert float32_to_ofp8_and_back(value=np.sqrt(7), encoding="E5M2") == 2.50

I am trying to test it against: square root of 7, square root of 6, and pi.
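
For comparison, here is a minimal sketch of the same quantization with the mantissa rounded to the nearest representable value rather than truncated, since `int(...)` in the code above always rounds toward zero. The function name `quantize_fp8_nearest` and the `MANTISSA_BITS` table are hypothetical and not part of my original code. The sketch assumes a finite, normal-range input and does not handle subnormals, saturation, NaN, or infinity. Under those assumptions it reproduces all six expected values listed above.

```python
import math

# Hypothetical lookup (not from the original code): encoding -> mantissa bits.
MANTISSA_BITS = {"E4M3": 3, "E5M2": 2}


def quantize_fp8_nearest(value: float, encoding: str) -> float:
    """Round a normal-range value to the nearest E4M3/E5M2 value (ties to even).

    Assumption: no subnormal, overflow, NaN, or infinity handling.
    """
    if encoding not in MANTISSA_BITS:
        raise ValueError("Unsupported encoding")
    mantissa_bits = MANTISSA_BITS[encoding]

    if value == 0:
        return 0.0
    sign = -1.0 if value < 0 else 1.0
    magnitude = abs(value)

    # frexp gives m in [0.5, 1) with magnitude == m * 2**e; shift to [1, 2).
    m, e = math.frexp(magnitude)
    exponent = e - 1
    fraction = m * 2 - 1  # fractional part of the significand, in [0, 1)

    # Python's round() rounds halves to the nearest even integer.
    M = round(fraction * 2 ** mantissa_bits)

    # Rounding up can overflow the mantissa field; carry into the exponent.
    if M == 2 ** mantissa_bits:
        M = 0
        exponent += 1

    return sign * 2.0 ** exponent * (1 + M / 2 ** mantissa_bits)


print(quantize_fp8_nearest(math.pi, "E4M3"))       # 3.25
print(quantize_fp8_nearest(math.sqrt(6), "E5M2"))  # 2.5
```

For pi in E4M3 the scaled mantissa is about 4.57, which truncates to 4 (giving 3.0) but rounds to 5 (giving 3.25). That is the same gap I see in the failing asserts.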
