import ast
import math
import random
import operator as op


# Hack: a str subclass whose `!=` always returns False, so any check of the
# form `expected != actual` treats it as matching every type (a wildcard).
class AnyType(str):
    def __ne__(self, __value: object) -> bool:
        return False


# Our any instance wants to be a wildcard string.
# NOTE: this intentionally shadows the builtin `any` at module level; the name
# is kept because other modules may import it.
any = AnyType("*")

# AST operator node type -> implementation. Boolean/logical operators return
# 1/0 so every result stays numeric.
operators = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.FloorDiv: op.floordiv,
    ast.Pow: op.pow,
    ast.BitXor: op.xor,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
    ast.Mod: op.mod,
    ast.BitAnd: op.and_,
    ast.BitOr: op.or_,
    ast.Invert: op.invert,
    ast.And: lambda a, b: 1 if a and b else 0,
    ast.Or: lambda a, b: 1 if a or b else 0,
    ast.Not: lambda a: 0 if a else 1,
    ast.RShift: op.rshift,
    ast.LShift: op.lshift,
}

# TODO: restructure args to provide more info, generate hint based on args to save duplication
# name -> {"args": (min_count, max_count or None for unbounded), "call": impl, "hint": autocomplete text}
functions = {
    "round": {
        "args": (1, 2),
        "call": lambda a, b=None: round(a, b),
        "hint": "number, dp? = 0",
    },
    "ceil": {
        "args": (1, 1),
        "call": lambda a: math.ceil(a),
        "hint": "number",
    },
    "floor": {
        "args": (1, 1),
        "call": lambda a: math.floor(a),
        "hint": "number",
    },
    "min": {
        "args": (2, None),
        "call": lambda *args: min(*args),
        "hint": "...numbers",
    },
    "max": {
        "args": (2, None),
        "call": lambda *args: max(*args),
        "hint": "...numbers",
    },
    "randomint": {
        "args": (2, 2),
        "call": lambda a, b: random.randint(a, b),
        "hint": "min, max",
    },
    "randomchoice": {
        "args": (2, None),
        "call": lambda *args: random.choice(args),
        "hint": "...numbers",
    },
    "sqrt": {
        "args": (1, 1),
        "call": lambda a: math.sqrt(a),
        "hint": "number",
    },
    "int": {
        "args": (1, 1),
        "call": lambda a=None: int(a),
        "hint": "number",
    },
    "iif": {
        "args": (3, 3),
        "call": lambda a, b, c=None: b if a else c,
        "hint": "value, truepart, falsepart",
    },
}

# Entries for the "pysssss.autocomplete" widget option; caretOffset -1 places
# the caret between the parentheses of the inserted "name()".
autocompleteWords = [
    {
        "text": name,
        "value": f"{name}()",
        "showValue": False,
        "hint": f"{spec['hint']}",
        "caretOffset": -1,
    }
    for name, spec in functions.items()
]

# Comparison operator node type -> implementation (used for chained compares).
_comparisons = {
    ast.Eq: op.eq,
    ast.NotEq: op.ne,
    ast.Gt: op.gt,
    ast.GtE: op.ge,
    ast.Lt: op.lt,
    ast.LtE: op.le,
}


class MathExpression:
    """Node that evaluates a restricted arithmetic expression.

    The expression is parsed with :mod:`ast` and interpreted node by node (no
    ``eval``). Supported constructs:

    * numeric literals;
    * the operators in ``operators``, including chained comparisons;
    * the functions in ``functions``;
    * the inputs ``a``/``b``/``c``;
    * ``x.width`` / ``x.height`` of image/latent inputs;
    * ``NodeName.widget`` references to other nodes' widget values.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "expression": ("STRING", {
                    "multiline": True,
                    "dynamicPrompts": False,
                    "pysssss.autocomplete": {
                        "words": autocompleteWords,
                        "separator": "",
                    },
                }),
            },
            "optional": {
                "a": (any, ),
                "b": (any,),
                "c": (any, ),
            },
            "hidden": {"extra_pnginfo": "EXTRA_PNGINFO", "prompt": "PROMPT"},
        }

    RETURN_TYPES = ("INT", "FLOAT", )
    FUNCTION = "evaluate"
    CATEGORY = "utils"
    OUTPUT_NODE = True

    @classmethod
    def IS_CHANGED(s, expression, **kwargs):
        # NaN never compares equal to itself, so expressions that use a random
        # function presumably never look "unchanged" and are re-run each time.
        if "random" in expression:
            return float("nan")
        return expression

    def get_widget_value(self, extra_pnginfo, prompt, node_name, widget_name):
        """Return the value of widget ``widget_name`` on the node ``node_name``.

        A workflow node matches ``node_name`` in either of two ways:

        * its "Node name for S&R" property (falling back to its type) equals it;
        * its title equals it.

        The value is read from that node's ``inputs`` entry in ``prompt``.
        ``extra_pnginfo`` and ``prompt`` may be ``None``; the node then counts
        as not found.

        Raises:
            ValueError: the widget was converted to an input (its value is a list).
            NameError: no node matches, or the node has no such widget.
        """
        workflow = (extra_pnginfo or {}).get("workflow") or {}
        node_id = None
        for node in workflow.get("nodes", []):
            properties = node.get("properties") or {}
            names = [properties.get("Node name for S&R", node.get("type"))]
            if "title" in node:
                names.append(node["title"])
            if node_name in names:
                node_id = node["id"]
                break

        if node_id is None:
            raise NameError(f"Node not found: {node_name}.{widget_name}")

        # A node may appear in the workflow but not in the prompt; treat that
        # as a missing widget instead of failing with a bare KeyError.
        values = (prompt or {}).get(str(node_id)) or {}
        inputs = values.get("inputs") or {}
        if widget_name not in inputs:
            raise NameError(f"Widget not found: {node_name}.{widget_name}")

        value = inputs[widget_name]
        if isinstance(value, list):
            raise ValueError("Converted widgets are not supported via named reference, use the inputs instead.")
        return value

    def get_size(self, target, property):
        """Return the "width" or "height" (any other value means height) of ``target``.

        A dict with a "samples" key is treated as a latent. Its sizes come from
        ``samples.shape[3]`` (width) and ``shape[2]`` (height), multiplied by 8
        (presumably the VAE downscale factor). Anything else is treated as an
        image tensor using ``shape[2]`` (width) and ``shape[1]`` (height).
        This presumably means (batch, height, width, channels) — confirm
        against callers.
        """
        if isinstance(target, dict) and "samples" in target:
            # Latent
            if property == "width":
                return target["samples"].shape[3] * 8
            return target["samples"].shape[2] * 8
        # Image
        if property == "width":
            return target.shape[2]
        return target.shape[1]

    def evaluate(self, expression, prompt, extra_pnginfo=None, a=None, b=None, c=None):
        """Evaluate ``expression`` and return it as both INT and FLOAT outputs.

        Args:
            expression: source text; newlines are treated as spaces.
            prompt: the prompt dict used for ``NodeName.widget`` lookups.
            extra_pnginfo: dict holding the "workflow", or None.
            a, b, c: optional inputs (numbers, images or latents).

        Returns:
            ``{"ui": {"value": [result]}, "result": (int(result), float(result))}``

        Raises:
            SyntaxError, NameError, TypeError, ValueError, NotImplementedError:
                for malformed or unsupported expressions.

        NOTE(review): ``**`` with large operands can take a very long time; the
        expression is presumably authored by the workflow owner.
        """
        if extra_pnginfo is None:
            extra_pnginfo = {}
        expression = expression.replace('\n', ' ').replace('\r', '')
        tree = ast.parse(expression, mode='eval').body
        lookup = {"a": a, "b": b, "c": c}

        def to_number(value):
            # Ints stay exact; anything else is coerced to float.
            return value if isinstance(value, int) else float(value)

        def eval_binop(node):
            left = eval_expr(node.left)
            right = eval_expr(node.right)
            return operators[type(node.op)](to_number(left), to_number(right))

        def eval_boolop(node):
            # `x and y and z` carries three values; fold over all of them.
            fn = operators[type(node.op)]
            result = to_number(eval_expr(node.values[0]))
            for value in node.values[1:]:
                result = fn(result, to_number(eval_expr(value)))
            return result

        def eval_compare(node):
            # Chained comparisons follow Python semantics:
            # `x < y < z` means `x < y and y < z`.
            left = eval_expr(node.left)
            for cmp_node, comparator in zip(node.ops, node.comparators):
                right = eval_expr(comparator)
                compare = _comparisons.get(type(cmp_node))
                if compare is None:
                    raise NotImplementedError(
                        "Operator " + cmp_node.__class__.__name__ + " not supported.")
                if not compare(left, right):
                    return 0
                left = right
            return 1

        def eval_attribute(node):
            if not isinstance(node.value, ast.Name):
                raise SyntaxError(
                    "Only simple references like a.width or NodeName.widget are supported")
            name = node.value.id
            if name in lookup and node.attr in ("width", "height"):
                if lookup[name] is None:
                    raise NameError(f"Input not connected: {name}")
                return self.get_size(lookup[name], node.attr)
            # Anything else is a reference to another node's widget.
            return self.get_widget_value(extra_pnginfo, prompt, name, node.attr)

        def eval_name(node):
            if node.id not in lookup:
                raise NameError(f"Name not found: {node.id}")
            val = lookup[node.id]
            if val is None:
                raise NameError(f"Input not connected: {node.id}")
            if isinstance(val, (int, float, complex)):
                return val
            raise TypeError(
                f"Complex types (LATENT/IMAGE) need to reference their width/height, e.g. {node.id}.width")

        def eval_call(node):
            name = node.func.id if isinstance(node.func, ast.Name) else type(node.func).__name__
            if not isinstance(node.func, ast.Name) or name not in functions:
                raise NameError(f"Invalid function call: {name}")
            if node.keywords:
                raise SyntaxError(f"Invalid function call: {name} does not accept keyword arguments")
            fn = functions[name]
            min_args, max_args = fn["args"]
            count = len(node.args)
            if count < min_args or (max_args is not None and count > max_args):
                to_err = " or more" if max_args is None else f" to {max_args}"
                raise SyntaxError(
                    f"Invalid function call: {name} requires {min_args}{to_err} arguments")
            return fn["call"](*[eval_expr(arg) for arg in node.args])

        def eval_expr(node):
            # ast.Constant covers every literal on Python 3.8+; ast.Num and
            # Constant.n are deprecated and removed in Python 3.14.
            if isinstance(node, ast.Constant):
                return node.value
            if isinstance(node, ast.BinOp):
                return eval_binop(node)
            if isinstance(node, ast.BoolOp):
                return eval_boolop(node)
            if isinstance(node, ast.UnaryOp):
                return operators[type(node.op)](eval_expr(node.operand))
            if isinstance(node, ast.Attribute):
                return eval_attribute(node)
            if isinstance(node, ast.Name):
                return eval_name(node)
            if isinstance(node, ast.Call):
                return eval_call(node)
            if isinstance(node, ast.Compare):
                return eval_compare(node)
            raise TypeError(node)

        r = eval_expr(tree)
        return {"ui": {"value": [r]}, "result": (int(r), float(r),)}


NODE_CLASS_MAPPINGS = {
    "MathExpression|pysssss": MathExpression,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MathExpression|pysssss": "Math Expression 🐍",
}