Is the XLA-HLO different for each GPU device?

81 Views Asked by At

I use the code below to get the number of FLOPs for each equation in Jaxpr.eqns. However, when I run the same code on different GPU devices, I get different FLOP counts. For example, on an A100 80GB GPU I get more FLOPs than on an RTX 3090 24GB GPU. Does the HLO cost analysis module take device-specific costs into account?

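# Assumes Jaxpr, ClosedJaxpr and Var (jax.core in older JAX versions),
# xb / xc (JAX's xla_bridge / xla_client modules) and a jaxpr_to_hlo helper
# (assumed to be the one from Alpa's utilities) are imported by this module.
def eqn_flops(eqn):  # illustrative name for the enclosing function
    """Return the FLOPs XLA's HLO cost analysis reports for one equation."""
    # Keep only variable inputs; literal operands stay inside the equation.
    new_inv = [inv for inv in eqn.invars if isinstance(inv, Var)]
    # Wrap the single equation in its own jaxpr and lower it to an HLO module.
    jaxpr = Jaxpr([], new_inv, eqn.outvars, [eqn])
    closed_jaxpr = ClosedJaxpr(jaxpr, [])
    hlo_module = jaxpr_to_hlo("tmp", closed_jaxpr, [
        False,  # one flag per input (assumed: no buffers are donated)
    ] * len(jaxpr.invars)).get_module()

    # Run XLA's HLO cost analysis on the module using the GPU backend.
    backend = xb.get_backend("gpu")
    properties = xc._xla.hlo_module_cost_analysis(  # pylint: disable=protected-access
        backend, hlo_module)
    return properties.get("flops", 0.0)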
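A minimal sketch for comparing devices, assuming a recent JAX release: the public ahead-of-time API exposes XLA's cost analysis both before compilation (Lowered.cost_analysis(), which, as far as I can tell from the JAX source, calls the same hlo_module_cost_analysis helper on the unoptimized HLO) and after compilation (Compiled.cost_analysis(), which analyzes the optimized HLO that XLA produced for that particular device). The helper name hlo_flops and its optimized flag are hypothetical; the structure returned by cost_analysis() is documented as unstable across versions, hence the defensive unpacking.

```python
import jax
import jax.numpy as jnp


def hlo_flops(fn, *args, optimized=False):
    """Hypothetical helper: FLOPs reported by XLA's cost analysis for fn(*args).

    optimized=False analyzes the HLO before XLA's backend passes run;
    optimized=True analyzes the compiled, device-specific HLO.
    Returns None when the backend does not provide a cost analysis.
    """
    lowered = jax.jit(fn).lower(*args)
    stage = lowered.compile() if optimized else lowered
    analysis = stage.cost_analysis()
    # Depending on the JAX version this is a dict, a list holding one dict,
    # or None (e.g. Lowered.cost_analysis() on some plugin-based backends).
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else None
    if not analysis:
        return None
    return float(analysis.get("flops", 0.0))


a = jnp.ones((4, 8), jnp.float32)
b = jnp.ones((8, 16), jnp.float32)
for optimized in (False, True):
    print(jax.default_backend(), optimized,
          hlo_flops(jax.lax.dot, a, b, optimized=optimized))
```

Running this on both GPUs and comparing the optimized=False and optimized=True numbers may help show whether the difference already exists before XLA's GPU-specific passes run, or only appears in the compiled HLO.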

I would like to see the source code of the HLO cost analysis module involved here, or any hints about where to find it.
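As a device-independent cross-check, the matmul FLOPs can also be counted directly from the jaxpr. This sketch assumes the common convention of 2 FLOPs per multiply-accumulate, so a dot_general whose output has N elements and whose contraction size is K costs 2 * N * K FLOPs. analytic_dot_flops is a hypothetical helper: it ignores every primitive other than dot_general, and it counts a nested jaxpr (for example a scan body) only once rather than once per iteration.

```python
import math

import jax
import jax.numpy as jnp


def _eqns(jaxpr_like):
    # A plain Jaxpr has .eqns; otherwise unwrap a ClosedJaxpr via .jaxpr.
    return jaxpr_like.eqns if hasattr(jaxpr_like, "eqns") else jaxpr_like.jaxpr.eqns


def analytic_dot_flops(jaxpr_like) -> int:
    """Hypothetical helper: 2 * output_size * contraction_size per dot_general."""
    total = 0
    for eqn in _eqns(jaxpr_like):
        if "jaxpr" in eqn.params:
            # pjit / closed_call and similar: count the inner body once.
            total += analytic_dot_flops(eqn.params["jaxpr"])
        elif eqn.primitive.name == "dot_general":
            # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
            (lhs_contract, _), _ = eqn.params["dimension_numbers"]
            lhs_shape = eqn.invars[0].aval.shape
            out_size = math.prod(eqn.outvars[0].aval.shape)
            k = math.prod(lhs_shape[d] for d in lhs_contract)
            total += 2 * out_size * k
    return total


a = jnp.ones((4, 8), jnp.float32)
b = jnp.ones((8, 16), jnp.float32)
print(analytic_dot_flops(jax.make_jaxpr(jnp.dot)(a, b)))
```

Because this count depends only on shapes, it gives the same value on the A100 and the RTX 3090, which makes it a reference point for judging the numbers the cost analysis reports on each device.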