I'm using pandera for schema validation. My plan was to start with a few simple checks on the raw data, and later extend the schema to do more and drop more invalid rows. However, even in this first phase the validation doesn't behave the way I expect.
Below is my sample code (imports shown for completeness):
In [1]: import numpy as np, pandas as pd, pandera as pa
In [2]: from typing import Dict, List, Tuple, Type
In [4]: tmp_df: pd.DataFrame = pd.DataFrame({"colA": [0.0, np.nan, 20.0, 30, 0.1, 0.2, 0.3, 0.5], "colB": [-1, 200, 300, np.nan, 800, -2, 100, 5], "target": [500, 550, np.nan, 600, 450, 0, 800, 1000]})
In [5]: tmp_df
Out[5]:
   colA   colB  target
0   0.0   -1.0   500.0
1   NaN  200.0   550.0
2  20.0  300.0     NaN
3  30.0    NaN   600.0
4   0.1  800.0   450.0
5   0.2   -2.0     0.0
6   0.3  100.0   800.0
7   0.5    5.0  1000.0
In [7]: schema_cols: List[Tuple[str, Type]] = [("colA", float), ("colB", int), ("target", float)]
In [8]: schema_cols
Out[8]: [('colA', float), ('colB', int), ('target', float)]
In [11]: schema_dict: Dict[str, pa.Column] = {tup[0]: pa.Column(dtype=tup[1], nullable=True, required=True, name=tup[0], drop_invalid_rows=True) for tup in schema_cols}
In [12]: schema_dict
Out[12]:
{'colA': <Schema Column(name=colA, type=DataType(float64))>,
'colB': <Schema Column(name=colB, type=DataType(int64))>,
'target': <Schema Column(name=target, type=DataType(float64))>}
# In the actual code there are other dtypes as well, but I want to add checks only for the float and int columns
In [15]: for key, val in schema_dict.items():
    ...:     if val.dtype.type == "float64" or val.dtype.type == "int64":
    ...:         if val.checks:
    ...:             val.checks.append(pa.Check.ge(min_value=0))
    ...:         else:
    ...:             val.checks = [pa.Check.ge(min_value=0)]
    ...:
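(For reference, I believe the same columns-with-checks mapping could also be built in one step; this is just a sketch equivalent to the loop above, using the same schema_cols and imports, and it produces the same behaviour for me:)

# Sketch: attach the >= 0 check while constructing the columns,
# only for the float/int entries of schema_cols.
schema_dict = {
    name: pa.Column(
        dtype=dtype,
        checks=[pa.Check.ge(min_value=0)] if dtype in (float, int) else None,
        nullable=True,
        required=True,
        name=name,
        drop_invalid_rows=True,
    )
    for name, dtype in schema_cols
}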
In [16]: schema: pa.DataFrameSchema = pa.DataFrameSchema(columns=schema_dict, drop_invalid_rows=True, coerce=True, strict="filter", add_missing_columns=False,)
In [17]: schema
Out[17]: <Schema DataFrameSchema(columns={'colA': <Schema Column(name=colA, type=DataType(float64))>, 'colB': <Schema Column(name=colB, type=DataType(int64))>, 'target': <Schema Column(name=target, type=DataType(float64))>}, checks=[], index=None, coerce=True, dtype=None, strict=filter, name=None, ordered=False, unique_column_names=False, metadata=None, add_missing_columns=False)>
In [18]: for key, val in schema.columns.items():
    ...:     print(val.dtype.type)
    ...:
float64
int64
float64
In [19]: schema.validate(check_obj=tmp_df, lazy=True)
Out[19]:
   colA   colB  target
0   0.0   -1.0   500.0
1   NaN  200.0   550.0
2  20.0  300.0     NaN
4   0.1  800.0   450.0
5   0.2   -2.0     0.0
6   0.3  100.0   800.0
7   0.5    5.0  1000.0
Here, only the row where colB (the int column) is NaN gets dropped (index 3); the rows with NaN in colA or target are kept. Since nullable=True is set for every column, I was expecting either that no rows would be dropped at all, or that the NaN rows would be dropped for every column.
What is wrong with my usage?
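For convenience, here is the session condensed into a standalone script (the same code as above, nothing new, assuming pandas, numpy, and pandera are installed):

import numpy as np
import pandas as pd
import pandera as pa
from typing import Dict, List, Tuple, Type

tmp_df: pd.DataFrame = pd.DataFrame(
    {
        "colA": [0.0, np.nan, 20.0, 30, 0.1, 0.2, 0.3, 0.5],
        "colB": [-1, 200, 300, np.nan, 800, -2, 100, 5],
        "target": [500, 550, np.nan, 600, 450, 0, 800, 1000],
    }
)

schema_cols: List[Tuple[str, Type]] = [("colA", float), ("colB", int), ("target", float)]

# Build one nullable, droppable Column per entry in schema_cols.
schema_dict: Dict[str, pa.Column] = {
    name: pa.Column(dtype=dtype, nullable=True, required=True, name=name, drop_invalid_rows=True)
    for name, dtype in schema_cols
}

# Add the >= 0 check only to the float/int columns.
for key, val in schema_dict.items():
    if val.dtype.type == "float64" or val.dtype.type == "int64":
        if val.checks:
            val.checks.append(pa.Check.ge(min_value=0))
        else:
            val.checks = [pa.Check.ge(min_value=0)]

schema: pa.DataFrameSchema = pa.DataFrameSchema(
    columns=schema_dict,
    drop_invalid_rows=True,
    coerce=True,
    strict="filter",
    add_missing_columns=False,
)

# Only the row with NaN in colB (index 3) is dropped; the NaN rows for colA and target remain.
print(schema.validate(check_obj=tmp_df, lazy=True))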