Converting RDD-based flattening logic to DataFrame approach in PySpark

41 Views Asked by At

I'm working on a PySpark application where I need to flatten nested JSON structures. Currently, my code relies on RDD transformations, but I need to move to a DataFrame-centric approach because of limitations of the RDD-based approach. Here's the relevant code snippet:

def flatten(df: DataFrame) -> DataFrame:
    """Parse the ``JsonDocument`` column of ``df`` and return a flattened DataFrame.

    Every row's ``JsonDocument`` value is decoded with ``bson.json_util.loads``
    and normalised with ``json_dict_conv``.  One schema covering all documents
    is obtained by reducing the per-document schemas with ``type_merger``;
    each document is then recast to that schema, loaded into a nested
    DataFrame and passed to ``DataframeFlattener``.

    Args:
        df: DataFrame exposing a ``JsonDocument`` column.

    Returns:
        Whatever ``DataframeFlattener(...).flatten()`` returns for the nested
        DataFrame.
    """
    session = SparkSession.builder.getOrCreate()

    def _decode(row):
        # Extended-JSON text -> Python objects -> plain, Spark-friendly values.
        return json_dict_conv(bson.json_util.loads(row["JsonDocument"]))

    documents = df.rdd.map(_decode)

    # ``documents`` is used twice below (schema inference, then recasting)
    # and is not cached in this function.
    merged_schema = documents.map(parse_nested_schema).reduce(type_merger)
    final_schema = remove_null_from_schema(merged_schema)

    def _conform(document):
        return recast_nested_structure(document, final_schema)

    nested_df = session.createDataFrame(documents.map(_conform), final_schema)

    # DataframeFlattener is a custom class defined outside this snippet.
    return DataframeFlattener(nested_df).flatten()

The function above uses the following helper functions:

def json_dict_conv(x):
    """Recursively convert a decoded BSON/JSON value into plain Python values.

    Conversions applied:
      * ``int`` instances (including subclasses) -> built-in ``int``.
        Note that ``bool`` is an ``int`` subclass, so ``True``/``False``
        come back as ``1``/``0``.
      * ``bson.objectid.ObjectId`` and ``uuid.UUID`` -> ``str``.
      * ``Decimal128`` -> the result of its ``to_decimal()`` method.
      * ``dict`` / ``list`` -> same container type with keys, values and
        items converted recursively.

    Anything else is returned unchanged.
    """
    # Integers are handled first so that int subclasses collapse to ``int``.
    if isinstance(x, int):
        return int(x)

    # Identifier-like objects are represented by their string form.
    if isinstance(x, (bson.objectid.ObjectId, uuid.UUID)):
        return str(x)

    if isinstance(x, dict):
        converted = {}
        for key, value in x.items():
            new_key = json_dict_conv(key)
            converted[new_key] = json_dict_conv(value)
        return converted

    if isinstance(x, list):
        return list(map(json_dict_conv, x))

    if isinstance(x, Decimal128):
        return x.to_decimal()

    return x

def parse_nested_schema(x):
    """Derive a Spark data type describing the Python value ``x``.

    * A non-empty ``dict`` becomes a ``StructType`` with one nullable field
      per key; ``-`` in a key is replaced by ``_`` in the field name.
    * A non-empty ``list`` becomes an ``ArrayType`` whose element type is the
      ``type_merger`` reduction of the types of all its items.
    * An empty ``dict`` is typed as ``T._infer_type(None)``; an empty list
      and every other value are typed by ``T._infer_type`` directly.
    """
    if isinstance(x, dict):
        if not x:
            # Nothing to describe: use the type inferred for ``None``.
            return T._infer_type(None)
        struct_fields = []
        for key, value in x.items():
            field_name = key.replace("-", "_")
            struct_fields.append(
                T.StructField(field_name, parse_nested_schema(value), True)
            )
        return T.StructType(struct_fields)

    if isinstance(x, list) and x:
        # Type every item first, then fold them into one element type.
        item_types = [parse_nested_schema(item) for item in x]
        return T.ArrayType(reduce(type_merger, item_types))

    # Scalars and empty lists are left to Spark's own inference.
    return T._infer_type(x)

def type_merger(a, b):
    """Merge two Spark data types into one type that can hold both.

    Rules, in order:
      1. ``NullType`` on either side yields the other side.
      2. If either side is exactly an ``ArrayType``, the other side is
         wrapped in an array if needed and the element types are merged.
      3. If either side is exactly a ``StructType``, a non-struct side is
         wrapped in a struct with a single field named ``""``; fields are
         then matched case-insensitively by name.  Fields of ``a`` come
         first (keeping ``a``'s spelling of the name), followed by fields
         that only ``b`` has.
      4. Two values of the same class are merged with ``T._merge_type``.
      5. Anything else is delegated to ``promote_type``.
    """
    if isinstance(a, NullType):
        return b
    if isinstance(b, NullType):
        return a

    kind_a = type(a)
    kind_b = type(b)

    if kind_a is ArrayType or kind_b is ArrayType:
        left = a if isinstance(a, ArrayType) else ArrayType(a)
        right = b if isinstance(b, ArrayType) else ArrayType(b)
        merged_element = type_merger(left.elementType, right.elementType)
        return ArrayType(merged_element, True)

    if kind_a is StructType or kind_b is StructType:
        left = a if isinstance(a, StructType) else StructType([StructField("", a, True)])
        right = b if isinstance(b, StructType) else StructType([StructField("", b, True)])

        # Lower-cased name -> data type for the right-hand struct.
        right_types = {}
        for field in right.fields:
            right_types[field.name.lower()] = field.dataType

        # Lower-cased name -> spelling to emit; the left-hand spelling wins.
        spelling = {name.lower(): name for name in right.names}
        for name in left.names:
            spelling[name.lower()] = name

        merged_fields = []
        for field in left.fields:
            key = field.name.lower()
            other_type = right_types.get(key, NullType())
            merged_fields.append(
                StructField(spelling[key], type_merger(field.dataType, other_type))
            )

        # Append the fields that exist only on the right-hand side.
        already_present = {field.name.lower() for field in merged_fields}
        for key, data_type in right_types.items():
            if key not in already_present:
                merged_fields.append(StructField(spelling[key], data_type))
        return StructType(merged_fields)

    if kind_a == kind_b:
        return T._merge_type(a, b)
    return promote_type(a, b)

def remove_null_from_schema(schema):
    """Return ``schema`` with every ``NullType`` replaced by ``StringType``.

    Structs and arrays are rebuilt recursively; every rebuilt struct field is
    marked nullable.  Any other type is returned as-is.
    """
    if isinstance(schema, StructType):
        rebuilt_fields = []
        for field in schema.fields:
            cleaned_type = remove_null_from_schema(field.dataType)
            rebuilt_fields.append(StructField(field.name, cleaned_type, True))
        return StructType(rebuilt_fields)

    if isinstance(schema, ArrayType):
        cleaned_element = remove_null_from_schema(schema.elementType)
        return ArrayType(cleaned_element)

    # Leaf type: only NullType needs a stand-in.
    return StringType() if isinstance(schema, NullType) else schema

def _normalize_field_key(name):
    """Return ``name`` with ``-`` replaced by ``_`` and lower-cased."""
    return name.replace("-", "_").lower()


def recast_nested_structure(data, schema):
    """Coerce a parsed JSON value so that it conforms to ``schema``.

    Args:
        data: A value produced by ``json_dict_conv`` (scalar, dict, list or
            ``None``).
        schema: The Spark data type the value has to match.

    Returns:
        ``None`` for ``None`` input; ``int``/``float``/``str`` for
        ``LongType``/``DoubleType``/``StringType``; a dict keyed by schema
        field name for ``StructType``; a list for ``ArrayType``; otherwise
        ``data`` unchanged.

    Raises:
        ValueError / TypeError: propagated from ``int()`` / ``float()`` when
            the value cannot be converted.
    """
    if data is None:
        return None
    if isinstance(schema, LongType):
        return int(data)
    if isinstance(schema, DoubleType):
        return float(data)
    if isinstance(schema, StringType):
        return str(data)
    if isinstance(schema, StructType):
        if not isinstance(data, dict):
            # type_merger wraps a non-struct type that must coexist with a
            # struct in a field named "".  Recast the value to that field's
            # type so it matches the schema instead of passing it through raw.
            for field in schema.fields:
                if field.name == "":
                    return {"": recast_nested_structure(data, field.dataType)}
            return {"": data}

        # Schema field names may differ from the document's keys:
        # parse_nested_schema replaces "-" with "_" and type_merger matches
        # names case-insensitively.  An exact key match is always preferred;
        # the normalised lookup is only a fallback (built lazily, first key
        # wins on collisions) so such keys are not silently turned into None.
        normalized = None
        result = {}
        for field in schema.fields:
            if field.name in data:
                value = data[field.name]
            else:
                if normalized is None:
                    normalized = {}
                    for key, item in data.items():
                        if isinstance(key, str):
                            normalized.setdefault(_normalize_field_key(key), item)
                value = normalized.get(_normalize_field_key(field.name))
            result[field.name] = recast_nested_structure(value, field.dataType)
        return result
    if isinstance(schema, ArrayType):
        if not isinstance(data, (list, tuple)):
            # A lone value where an array is expected becomes a 1-element list.
            return [recast_nested_structure(data, schema.elementType)]
        return [recast_nested_structure(item, schema.elementType) for item in data]
    return data

I tried to reproduce the RDD map-reduce approach with a pandas UDF that operates on the whole column, applying Python's built-in `map` and `reduce` with the same functions shown above.

Thanks in advance, and any tips to improve the efficiency of the code are welcome.

0

There are 0 best solutions below