Update a class method's default argument when the class is called by another package


I need to update a kwarg for a class method which is called by another package.

I interact with PyJWT's encode, which much later calls cryptography's load_pem_private_key (https://cryptography.io/en/latest/hazmat/primitives/asymmetric/serialization/#cryptography.hazmat.primitives.serialization.load_pem_public_key)

I need to change the default of load_pem_private_key's unsafe_skip_rsa_key_validation argument to True. How can this be done?

I have looked into dependency injection, but it looks like I would still need to instantiate the class to inject my new method.

Answer by jsbueno

This can be done with "monkey-patching" and functools.partial.

Monkey patching is the term for replacing existing functions or other values in third-party code with code under our control, at runtime.

functools.partial, on the other hand, returns a function-like object with one or more of the function's parameters pre-filled, which effectively work as "new defaults".
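To illustrate the "new defaults" behaviour, here is a minimal sketch. The load_key function and its skip_validation parameter are hypothetical stand-ins, not part of cryptography; the point is only that a keyword pre-filled by partial can still be overridden by the caller.

```python
from functools import partial


def load_key(data, password=None, *, skip_validation=False):
    """Hypothetical loader standing in for a third-party function."""
    return {"data": data, "password": password, "skip_validation": skip_validation}


# Pre-fill the keyword argument; this acts like a new default value.
load_key_fast = partial(load_key, skip_validation=True)

# Callers that do not pass the keyword get the pre-filled value.
result = load_key_fast(b"pem-bytes")

# Callers that pass the keyword explicitly still win over the pre-filled value.
overridden = load_key_fast(b"pem-bytes", skip_validation=False)
```

Because explicit keywords take precedence, code that already passes the argument itself keeps working unchanged after the replacement.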

Since you don't need to revert to the old value, a plain assignment (with the = operator) is enough.

You just have to take care with how the imports are done, so that the function is properly monkeypatched. If you import the function into your current module and it is used in another one, the patching has no effect: you have only replaced the function in your current module.

So it is important to replace the function in the cryptography.hazmat.primitives.serialization module itself and, if the module that calls "add...key" imports it with from crypt...serialization import add...key, to replace it in that target module as well. With both replacements in place you do not depend on the import order of modules in your system: if the target module is already imported, the replacement is made there; if it is imported later, it picks up the replaced function from the cryptography package:
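The import pitfall can be reproduced with nothing but the standard library. This is a sketch under stated assumptions: fake_lib and by_name are hypothetical in-memory modules built for the demonstration, standing in for the third-party library and for a consumer that uses a from-import.

```python
import sys
import types
from functools import partial

# Hypothetical stand-in for the third-party library module.
fake_lib = types.ModuleType("fake_lib")
exec(
    "def load_key(data, skip_validation=False):\n"
    "    return skip_validation\n",
    fake_lib.__dict__,
)
sys.modules["fake_lib"] = fake_lib  # make it importable by name

# Hypothetical consumer module that imports the function by name.
by_name = types.ModuleType("by_name")
exec(
    "from fake_lib import load_key\n"
    "def encode(data):\n"
    "    return load_key(data)\n",
    by_name.__dict__,
)

# Patch only the library module.
patched = partial(fake_lib.load_key, skip_validation=True)
fake_lib.load_key = patched

# The consumer still holds its own reference to the original function.
unpatched_result = by_name.encode(b"x")

# Patching the name inside the consumer module makes the change visible there.
by_name.load_key = patched
patched_result = by_name.encode(b"x")
```

Here unpatched_result is False and patched_result is True: the from-import copied a reference into the consumer's namespace, so that name has to be replaced too.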

from functools import partial

from cryptography.hazmat.primitives import serialization

# import target_module  # <- the module that will call the function

# Wrap the original so unsafe_skip_rsa_key_validation defaults to True
new_func = partial(serialization.load_pem_private_key, unsafe_skip_rsa_key_validation=True)
# Replace the function in the cryptography package itself
serialization.load_pem_private_key = new_func
# target_module.load_pem_private_key = new_func

(If the target module does not import the function by name, but calls it through the module with dotted notation, like serialization.load_pem_private_key(...), there is no need to patch it there: it will always use the replaced version.)
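The dotted-notation case can be checked the same way. Again a sketch only: fake_serialization and dotted_consumer are hypothetical in-memory modules, not real packages.

```python
import sys
import types
from functools import partial

# Hypothetical stand-in for the library module that owns the function.
fake_serialization = types.ModuleType("fake_serialization")
exec(
    "def load_key(data, skip_validation=False):\n"
    "    return skip_validation\n",
    fake_serialization.__dict__,
)
sys.modules["fake_serialization"] = fake_serialization

# Hypothetical consumer that imports the module and uses dotted access.
dotted_consumer = types.ModuleType("dotted_consumer")
exec(
    "import fake_serialization\n"
    "def encode(data):\n"
    "    return fake_serialization.load_key(data)\n",
    dotted_consumer.__dict__,
)

before = dotted_consumer.encode(b"x")

# A single replacement on the library module is enough in this case.
fake_serialization.load_key = partial(
    fake_serialization.load_key, skip_validation=True
)

after = dotted_consumer.encode(b"x")
```

The attribute is looked up on the module object at every call, so before is False and after is True without touching the consumer.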