The function common_precision takes two NumPy arrays, say x and y. I want to make sure that they end up with the same precision, and that it is the higher of the two. It seems that relational comparison of dtypes does something close to what I want, but:
- I don't know what it actually compares
- It thinks that numpy.int64 < numpy.float16, which I'm not sure I agree with
def common_precision(x, y):
    if x.dtype > y.dtype:
        y = y.astype(x.dtype)
    else:
        x = x.astype(y.dtype)
    return (x, y)
Edited:
Thanks to kennytm's answer I found that NumPy's find_common_type does exactly what I wanted.
def common_precision(self, x, y):
    dtype = np.find_common_type([x.dtype, y.dtype], [])
    if x.dtype != dtype: x = x.astype(dtype)
    if y.dtype != dtype: y = y.astype(dtype)
    return x, y
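As far as I know, find_common_type was deprecated in NumPy 1.25 and removed in NumPy 2.0, with promote_types and result_type suggested as replacements. Here is a minimal sketch of the same idea on top of np.promote_types. The name common_precision_promoted is only an illustration, not something from the original code.

```python
import numpy as np

def common_precision_promoted(x, y):
    # Hypothetical helper name. promote_types returns the smallest dtype
    # to which both input dtypes can be safely cast.
    dtype = np.promote_types(x.dtype, y.dtype)
    # copy=False hands back the input array itself when no conversion is needed.
    return x.astype(dtype, copy=False), y.astype(dtype, copy=False)

x = np.arange(3, dtype=np.int64)
y = np.ones(3, dtype=np.float16)
x2, y2 = common_precision_promoted(x, y)
print(x2.dtype, y2.dtype)
```

For int64 and float16 the promoted dtype is float64, since neither of the two inputs can hold the other's values safely.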
x.dtype > y.dtype means that y.dtype can be cast to x.dtype (and x.dtype != y.dtype). float16 and int64 are simply incompatible: neither can be safely cast to the other. You could extract some information from each dtype and determine your comparison scheme based on this.
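To make the comparison semantics concrete, here is a small sketch. It checks the ordering against np.can_cast and then reads the kind and itemsize attributes of a dtype, which are one possible choice of information to build a custom scheme on. The describe helper is hypothetical and only there for illustration.

```python
import numpy as np

i8 = np.dtype("int64")
f2 = np.dtype("float16")
f8 = np.dtype("float64")

# a > b holds when b can be safely cast to a and the two dtypes differ.
print(f8 > f2, np.can_cast(f2, f8))

# Neither direction is a safe cast for int64 and float16,
# so the dtypes are unordered with respect to each other.
print(i8 < f2, i8 > f2)

def describe(dt):
    # Hypothetical helper: returns the kind character ('i', 'u', 'f', ...)
    # and the size in bytes of one element.
    dt = np.dtype(dt)
    return dt.kind, dt.itemsize

print(describe("float16"), describe("int64"))
```

With the kind and the item size in hand you could, for example, rank floats above integers and break ties by size, if that matches what "highest precision" means for your data.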