Extract all field names from nested dataclasses


I have a dataclass that contains within it another dataclass:

from dataclasses import dataclass

@dataclass
class A:
    var_1: str
    var_2: int

@dataclass
class B:
    var_3: float
    var_4: A

I would like to create a list of all field names: for attributes that aren't dataclasses, list the field name itself, and for attributes that are dataclasses, list the fields of that class instead. In this case the output would be ['var_3', 'var_1', 'var_2']. I know I can use dataclasses.fields to get the fields of a simple dataclass, but I can't work out how to do it recursively for nested dataclasses. Ideally I would like to do it by passing just the class type B (the same way you can pass a type to dataclasses.fields) rather than an instance of B. Is this possible?
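For reference, a flat call to dataclasses.fields() only sees the top level: for B it reports var_4 itself rather than the fields inside A, which is why recursion is needed. A minimal sketch, redefining the A and B classes from the question so it runs on its own:

```python
from dataclasses import dataclass, fields

@dataclass
class A:
    var_1: str
    var_2: int

@dataclass
class B:
    var_3: float
    var_4: A

# fields() accepts the class itself (or an instance) and returns a tuple of Field objects
top_level = [f.name for f in fields(B)]
print(top_level)  # ['var_3', 'var_4']

# The nested class is available on the Field object, but fields() does not descend into it
print(fields(B)[1].type is A)  # True
```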

Thank you!

Best answer

Use dataclasses.fields() to iterate over all the fields, making a list of their names.

Use dataclasses.is_dataclass() to tell whether a field's type is itself a dataclass. If so, recurse into it instead of adding the field's name to the list.

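from dataclasses import fields, is_dataclass

def all_fields(c: type) -> list[str]:  # list[str] needs Python 3.9+
    field_list = []
    for f in fields(c):
        # f.type holds the annotation as written, e.g. the class A for var_4
        if is_dataclass(f.type):
            # nested dataclass: collect its fields instead of its own name
            field_list.extend(all_fields(f.type))
        else:
            field_list.append(f.name)
    return field_list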
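One caveat, offered as a hedged extension rather than part of the original answer: if the module uses from __future__ import annotations, or a quoted forward reference such as var_4: "A", then f.type is the string "A" rather than the class. is_dataclass("A") returns False, so the function above would report 'var_4' instead of descending into A. The sketch below resolves the annotations with typing.get_type_hints() first, assuming the referenced classes are defined at module level where get_type_hints() can find them:

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import get_type_hints


def all_fields(c: type) -> list[str]:
    """Recursively collect field names, descending into nested dataclasses."""
    # get_type_hints() evaluates string annotations back into real types
    hints = get_type_hints(c)
    field_list: list[str] = []
    for f in fields(c):
        field_type = hints[f.name]
        # recurse only into dataclass *types*, not other annotations
        if isinstance(field_type, type) and is_dataclass(field_type):
            field_list.extend(all_fields(field_type))
        else:
            field_list.append(f.name)
    return field_list


@dataclass
class A:
    var_1: str
    var_2: int


@dataclass
class B:
    var_3: float
    var_4: "A"  # quoted forward reference: fields(B)[1].type is the string "A"


print(all_fields(B))  # ['var_3', 'var_1', 'var_2']
```

Because get_type_hints() looks names up in the defining module's globals, classes defined inside a function may need the localns argument passed explicitly.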