How to properly inherit from dataclass with base instance re-usage


Consider the following pseudocode, where each derived field is computed from the base fields:

@dataclass
class Base:
  prop1: str
  prop2: str


@dataclass
class Derived1(Base):
  isValid: bool = self.prop2.casefold() == 'valid'

@dataclass
class Derived2(Base):
  isOpen: bool = self.prop1.casefold() == 'open'
  isShared: bool


In this particular example the base class has just 2 props, but imagine it has 773 props, just for argument's sake. Something, somewhere (you don't care what or where), returns you an instance of Base. Now you want to promote it to Derived1, keeping all the existing values. Normally, as far as I know, if I want an instance of Derived1 or Derived2, I'd have to initialize that instance and then manually write assignments in Derived's __init__ or __post_init__ for all 773 props. It is madness. After some research I found a great suggestion that solves this: instead, I use a normal class and update the props from the existing base instance via the __dict__ trick:

class Derived(Base):
    isOpen: bool = False

    def __init__(self, base: Base):
        self.__dict__.update(base.__dict__)
        self.isOpen = (str(base.currStatus).casefold() == 'open')

While this seems to work just fine, I keep getting a hard error from pylint saying that there needs to be a call to super. And this is what I'm confused about: do I need a super here? The issue is that with super I'm back to square one, where I have to re-init all 700+ props manually... which is what I wanted to avoid in the first place.

This whole thing seems incredibly silly: this is Python we are talking about, and I can't easily see it having this problem when it comes to classes (which some other languages solved 30 years ago).


There is 1 solution below

OysterShucker On

Use asdict, and give the class what dictionary unpacking requires: a keys() method and __getitem__. The mixin below already does this, and a little more.

# mixins.py
from __future__  import annotations
from dataclasses import asdict
from typing      import Any, Iterable
import json


class Dataclass_mi:
  """ Dataclass_mi: dataclass mixin for dictionary behavior
    supports:
      *) .keys(), .values(), .items()
      *) dictionary unpacking
      *) pretty-print self as json
      *) return self as dict
      *) reassign any/all fields in one call
        +) optional __post_init__ call after reassignment
  """
  @property
  def asdict(self) -> dict:
    return asdict(self)
    
  def items(self) -> Iterable:
    return asdict(self).items()
    
  def values(self) -> Iterable:
    return asdict(self).values()
      
  # required for dictionary unpacking
  def keys(self) -> Iterable:
    return asdict(self).keys()
  
  # required for dictionary unpacking  
  # an error will be raised if this key does not exist
  def __getitem__(self, key:str) -> Any:
    return getattr(self, key)
  
  # pretty-print self as json
  def __str__(self) -> str:
    return json.dumps(asdict(self), indent=4, default=str)
  
  # reassign fields from kwargs and optionally call __post_init__
  def __call__(self, initvars:dict|None=None, **kwargs) -> Dataclass_mi:
    # if key exists, set value. otherwise, ignore.
    for key, value in kwargs.items():
      if hasattr(self, key):
        setattr(self, key, value)
        
    # if you want __post_init__ to be called, but it doesn't accept arguments -
    # pass initvars as an empty dict
    if initvars is not None: 
      self.__post_init__(**initvars)
        
    # let this be an inline method        
    return self
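
As a side note, the two methods marked "required for dictionary unpacking" are all that `**` looks for. A minimal sketch, independent of dataclasses, using a hypothetical `Pair` class and `show` function that are not part of the mixin:

```python
class Pair:
    """Hypothetical class, used only to illustrate the unpacking protocol."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    # ** asks the object for its keys first ...
    def keys(self):
        return ['a', 'b']

    # ... then looks up each key with obj[key]
    def __getitem__(self, key):
        return getattr(self, key)


def show(a, b):
    return f'a={a}, b={b}'


# unpacking into a call
result = show(**Pair(1, 2))
print(result)  # a=1, b=2

# unpacking into a dict literal works the same way
merged = {**Pair(1, 2), 'c': 3}
print(merged)  # {'a': 1, 'b': 2, 'c': 3}
```
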
Classes:
from dataclasses import dataclass
from mixins import Dataclass_mi
   
@dataclass
class Base(Dataclass_mi):
  prop1: str
  prop2: str


@dataclass
class Derived1(Base):
  isValid: bool = False

  def __post_init__(self) -> None:
    self.isValid = self.prop2.lower() == 'valid'


@dataclass
class Derived2(Base):
  isShared: bool 
  isOpen  : bool = False
  
  def __post_init__(self) -> None:
    self.isOpen = self.prop1.lower() == 'open'


Usage:
base = Base('open', 'valid')

# init
derived = Derived2(**base, isShared=False)

print(derived)

# customize
derived({}, prop1='closed')

print(derived)
Output:
{
    "prop1": "open",
    "prop2": "valid",
    "isShared": false,
    "isOpen": true
}
{
    "prop1": "closed",
    "prop2": "valid",
    "isShared": false,
    "isOpen": false
}
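
If you would rather not add a mixin to the base class, the same promotion can probably be done with a small helper built on `dataclasses.fields`. This is only a sketch: `promote` is a hypothetical name, not something from the standard library or from the mixin above. It assumes every init field of the base is also accepted by the derived class's `__init__`, and it copies values by reference (a shallow copy). Fields declared with `init=False` are skipped, since passing them to `__init__` would raise a TypeError.

```python
from dataclasses import dataclass, fields


@dataclass
class Base:
    prop1: str
    prop2: str


@dataclass
class Derived1(Base):
    isValid: bool = False

    def __post_init__(self) -> None:
        self.isValid = self.prop2.casefold() == 'valid'


def promote(cls, base, **extra):
    """Hypothetical helper: build `cls` from the init fields of `base`."""
    # Collect only the fields that __init__ accepts; values are not copied.
    values = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
    # Extra keyword arguments fill in (or override) derived-only fields.
    values.update(extra)
    return cls(**values)


base = Base('open', 'valid')
derived = promote(Derived1, base)
print(derived)  # Derived1(prop1='open', prop2='valid', isValid=True)
```

Because the derived class is still a regular dataclass with its generated `__init__`, there is no hand-written `__init__` for pylint to complain about, and the number of base props does not matter.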