File size: 2,434 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from collections import defaultdict
from typing import Optional
from modules.errors import log


def patch(key, obj, field, replacement, add_if_not_exists:bool = False):
    """Replace an attribute (typically a function) on a module or a class.

    The value being replaced is recorded in the module-level ``originals`` registry; it can be
    retrieved via original(key, obj, field) and restored via undo(key, obj, field).
    If this caller (key) has already patched the same (obj, field), an error is logged, the
    previously recorded original is kept (so undo() still restores the true original), and the
    new replacement is installed on top of the old one. Use undo() first to avoid the error.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string
        replacement: the new function
        add_if_not_exists: when True, install the replacement even if obj has no attribute
            named field; the recorded original is then None
    Returns:
        the original function (None if the attribute did not exist and was added), or None when
        obj has no such attribute and add_if_not_exists is False (nothing is patched then)
    """
    patch_key = (obj, field)
    saved = originals[key]
    if patch_key in saved:
        log.error(f"Patch already applied: field={field}")
        # Keep the recorded original: overwriting it with the currently installed replacement
        # would make the real original unrecoverable via undo()/original().
        setattr(obj, field, replacement)
        return saved[patch_key]
    if not hasattr(obj, field) and not add_if_not_exists:
        # Modules and classes carry __name__; fall back to the type name for other objects.
        obj_name = getattr(obj, '__name__', type(obj).__name__)
        log.error(f"Patch no attribute: type={type(obj)} name='{obj_name}' field='{field}'")
        return None
    original_func = getattr(obj, field, None)
    saved[patch_key] = original_func
    setattr(obj, field, replacement)
    return original_func


def undo(key, obj, field):
    """Undo a replacement made by patch().

    Restores the value recorded by patch() for (obj, field) under this key and removes the
    record. If the recorded value is None (patch() found no attribute, e.g. it was added via
    add_if_not_exists), the attribute is deleted instead of being set back to None.
    NOTE(review): an attribute whose genuine original value was None is indistinguishable from
    a missing one here and will also be deleted.
    If no patch is recorded, an error is logged and nothing is changed.
    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string
    Returns:
        Always None
    """
    patch_key = (obj, field)
    if patch_key not in originals[key]:
        log.error(f"Patch no patch to undo: field={field}")
        return None
    original_func = originals[key].pop(patch_key)
    if original_func is None:
        # The attribute did not exist before patching; remove it entirely.
        delattr(obj, field)
    else:
        setattr(obj, field, original_func)
    return None


def original(key, obj, field):
    """Look up the value that patch() replaced.

    Returns whatever was recorded for (obj, field) under ``key`` -- the attribute as it was
    before patching, or None if it did not exist -- and None when no such patch is recorded.
    """
    recorded = originals[key]
    return recorded.get((obj, field))


def patch_method(cls, key:Optional[str]=None):
    """Decorator factory: replace the attribute of ``cls`` named like the decorated function.

    Calls patch() with the decorated function as the replacement; if ``cls`` has no attribute
    of that name, patch() logs an error and nothing is installed.
    Arguments:
        cls: the class (or module) to patch
        key: patch owner key; defaults to the decorated function's ``__module__``
    Returns:
        a decorator that patches and then returns the function unchanged
    """
    def decorator(func):
        patch(func.__module__ if key is None else key, cls, func.__name__, func)
        # Return the function so the decorated name is not rebound to None in the caller's module.
        return func
    return decorator


def add_method(cls, key:Optional[str]=None):
    """Decorator factory: install the decorated function on ``cls``, adding it if absent.

    Like patch_method(), but passes add_if_not_exists=True to patch(), so the function is
    installed even when ``cls`` has no attribute of that name (the recorded original is None).
    Arguments:
        cls: the class (or module) to patch
        key: patch owner key; defaults to the decorated function's ``__module__``
    Returns:
        a decorator that patches and then returns the function unchanged
    """
    def decorator(func):
        patch(func.__module__ if key is None else key, cls, func.__name__, func, add_if_not_exists=True)
        # Return the function so the decorated name is not rebound to None in the caller's module.
        return func
    return decorator


# Registry filled by patch(): originals[key][(obj, field)] -> value before patching
# (None when the attribute did not exist). Each key gets its own dict on first access.
originals: "defaultdict[object, dict]" = defaultdict(dict)