# Source code for deepgraph.connectors_and_selectors

# Copyright (C) 2017-2025 by
# Dominik Traxl <dominik.traxl@posteo.org>
# All rights reserved.
# BSD-3-Clause License.

import inspect
import warnings
import ast
import textwrap
from collections import Counter

import numpy as np

from deepgraph._find_selected_indices import _find_selected_indices
from deepgraph.utils import _flatten


def output_names(*names):
    """Decorator to specify the output variable names of a connector or
    selector function.

    Use this decorator to explicitly define the names of the output
    variables returned by your connector or selector functions. This is
    necessary when the source code of the function is not available at
    runtime.

    Parameters
    ----------
    *names : str
        One or more positional string arguments representing the names of
        the output variables.

    Returns
    -------
    callable
        A decorator that attaches the specified output names to the
        decorated function via the `_output_names` attribute.

    Raises
    ------
    TypeError
        If any of the given names is not a string. This typically happens
        when the decorator is applied without parentheses
        (``@output_names`` instead of ``@output_names("dx")``).

    See also
    --------
    deepgraph.DeepGraph.create_edges
    deepgraph.DeepGraph.create_edges_ft

    Examples
    --------
    Specifying output variable names for a connector function:

    >>> import deepgraph as dg
    >>>
    >>> @dg.output_names("dx", "dt", "v")
    ... def velocity(x_s, x_t, time_s, time_t):
    ...     dx = x_t - x_s
    ...     dt = time_t - time_s
    ...     v = dx/dt
    ...     return dx, dt, v

    Specifying output variable names for a selector function:

    >>> import deepgraph as dg
    >>>
    >>> @dg.output_names("sources", "targets", "dx")
    ... def distance_selector(x_s, x_t, sources, targets):
    ...     dx = x_t - x_s
    ...     sources = sources[dx <= 5]
    ...     targets = targets[dx <= 5]
    ...     return sources, targets, dx

    """
    # Fail early on misuse: a bare ``@output_names`` would otherwise pass the
    # decorated function in as a "name" and silently replace it with the
    # inner ``decorator`` below.
    for name in names:
        if not isinstance(name, str):
            raise TypeError(
                "output_names() expects the output variable names as strings, "
                "got {!r}. Use it with parentheses, e.g. "
                "@output_names('dx').".format(name)
            )

    def decorator(func):
        func._output_names = names
        return func

    return decorator


class CreatorFunction:
    """Common wrapper around a user-supplied connector or selector function.

    On construction, the wrapped function's signature is inspected to
    determine which arguments are node features (names ending in ``_s`` or
    ``_t``), which are relations (every other argument except ``sources`` and
    ``targets``), and which variables the function returns.

    The class attributes below are shared registries: ``Connector`` and
    ``Selector`` instances append themselves and their input/output names to
    them, and computed relations are cached in ``stored_relations``.
    """

    # dict to store relations (relation name -> computed values)
    stored_relations = {}

    # Connector attributes
    c_instances = []
    c_input_features = []
    c_input_rs = []
    c_output_rs = []

    # Selector attributes
    s_instances = []
    s_input_features = []
    s_input_rs = []
    s_output_rs = []

    def __init__(self, fct):
        """Inspect `fct` and record its input features, input relations and
        output variable names.

        Parameters
        ----------
        fct : callable
            The connector or selector function to wrap.

        Raises
        ------
        OSError, TypeError
            If `fct` has no ``_output_names`` attribute and its source code
            cannot be retrieved.
        ValueError
            If the output names have to be parsed from the source code and
            the return statement does not consist of plain variable names.
        """
        assert callable(fct), "{} is not callable.".format(fct)

        # make function accessible via self.fct, give self.name
        self.fct = fct
        self.name = getattr(self.fct, "__name__", repr(self.fct))

        # find all input arguments
        input_args = list(inspect.signature(fct).parameters.keys())
        self.input_features = [x for x in input_args if x.endswith(("_s", "_t"))]
        self.input_rs = [
            x for x in input_args if x not in self.input_features and x not in ("sources", "targets")
        ]

        # find all output variables; explicitly attached names (see
        # ``output_names``) take precedence over parsing the source code
        try:
            output = list(getattr(fct, "_output_names"))
        except AttributeError:
            output = self._extract_return_variables()
        self.output_rs = [x for x in output if x != "sources" and x != "targets"]
        self.output = output

    @classmethod
    def assertions(cls, v, r_dtype_dic):
        """Check the registered connectors and selectors for consistency.

        Asserts that no output relation name occurs more than once among
        ``c_output_rs`` and ``s_output_rs``, and warns about keys of
        `r_dtype_dic` that are not among those output relations.

        Parameters
        ----------
        v : object
            Not used by the checks currently implemented.
        r_dtype_dic : dict
            Mapping whose keys are relation names.
        """
        # self.input_features of self.c_instances & self.s_instances
        # must be in v.columns.values
        # set(cls.c_input_features).issubset(v.columns.values)

        # connectors and selectors must have exclusive output relations
        # NOTE(review): Counter needs hashable items, so the registries are
        # presumably flattened (see ``flatten_variables``) before this is
        # called; confirm against the caller.
        rs = cls.c_output_rs + cls.s_output_rs
        count_rs = Counter(rs)
        if rs:
            msg = (
                "There are common output relations in "
                "connectors and/or selectors. \n"
                "[(relation, number of occurences)]: \n {}"
            )
            assert set(count_rs.values()) == {1}, msg.format([(r, nr) for r, nr in count_rs.items() if nr > 1])

        # dtypes for relations given which are not in any output
        unused_dtypes = set(r_dtype_dic.keys()).difference(rs)
        if len(unused_dtypes) != 0:
            warnings.warn(
                "There are dtypes given by 'r_dtype_dic' for which there is no"
                " output variable(s): \n {}".format(list(unused_dtypes)),
                UserWarning,
            )

    @classmethod
    def _registry_attribute_names(cls):
        """Return the names of the non-callable, non-dunder class attributes
        other than ``stored_relations`` (i.e. the ``c_*``/``s_*`` registries).
        """
        return [
            atr
            for atr in dir(cls)
            if not atr.startswith("__") and not atr == "stored_relations" and not callable(getattr(cls, atr))
        ]

    @classmethod
    def reset(cls, all_or_WS):
        """Clear the shared state.

        Parameters
        ----------
        all_or_WS : str
            ``"stored_relations"`` empties only the relation cache;
            ``"all"`` additionally replaces every registry list with an
            empty list. Any other value leaves the state untouched.
        """
        if all_or_WS == "stored_relations":
            cls.stored_relations = {}
        elif all_or_WS == "all":
            cls.stored_relations = {}
            for atr in cls._registry_attribute_names():
                setattr(cls, atr, [])

    @classmethod
    def flatten_variables(cls):
        """Replace every registry list by the result of ``_flatten`` applied
        to it."""
        for atr in cls._registry_attribute_names():
            setattr(cls, atr, _flatten(cls.__dict__[atr]))

    def _extract_return_variables(self):
        """Determine the output variable names by parsing the source code of
        the wrapped function.

        Returns
        -------
        list of str
            The variable names of the function's return statement.

        Raises
        ------
        OSError
            If the source code of the function cannot be retrieved.
        TypeError
            If the callable is of a kind that has no Python source code
            (e.g. a builtin or a callable object).
        ValueError
            If no return statement is found or it does not consist of plain
            variable names.
        """
        try:
            source = inspect.getsource(self.fct)
        except (OSError, TypeError) as e:
            # inspect.getsource raises OSError when the source is missing and
            # TypeError for callables without Python source; both are solved
            # by attaching the names explicitly, so give the same hint while
            # keeping the original exception type.
            msg = (
                f"Unable to retrieve the source code of the function '{self.name}'.\n\n"
                "To avoid this issue, decorate your function with @deepgraph.output_names so metadata is explicitly attached.\n"
                "For more information, use: `help(deepgraph.output_names)`.\n"
                "This ensures compatibility even in environments where source code is unavailable (e.g., interactive shells or compiled code).\n\n"
            )
            exc_type = OSError if isinstance(e, OSError) else TypeError
            raise exc_type(msg) from e

        return self._parse_return_variables(source)

    @staticmethod
    def _iter_own_scope(func_node):
        """Yield the AST nodes of `func_node` breadth-first, without
        descending into nested function, lambda or class definitions."""
        nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)
        pending = [func_node]
        position = 0
        while position < len(pending):
            node = pending[position]
            position += 1
            yield node
            pending.extend(
                child for child in ast.iter_child_nodes(node) if not isinstance(child, nested_scopes)
            )

    @staticmethod
    def _parse_return_variables(source):
        """Return the variable names of the first return statement found in
        the function defined by `source`.

        Only the function's own body is searched; return statements that
        belong to nested functions, lambdas or classes are ignored.

        Parameters
        ----------
        source : str
            Source code of a function definition (may be indented).

        Returns
        -------
        list of str

        Raises
        ------
        ValueError
            If there is no return statement, or the returned expression is
            neither a single variable name nor a tuple of variable names.
        """
        tree = ast.parse(textwrap.dedent(source))

        func_node = next(
            (node for node in tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))),
            None,
        )
        # fall back to the whole tree if the source is not a plain function
        # definition (e.g. a lambda expression)
        nodes = ast.walk(tree) if func_node is None else CreatorFunction._iter_own_scope(func_node)

        for node in nodes:
            if isinstance(node, ast.Return):
                value = node.value
                if isinstance(value, ast.Tuple):
                    if all(isinstance(elt, ast.Name) for elt in value.elts):
                        return [elt.id for elt in value.elts]
                    else:
                        raise ValueError("Return tuple must contain only variable names.")
                elif isinstance(value, ast.Name):
                    return [value.id]
                else:
                    raise ValueError("Return must consist of variable names only.")

        raise ValueError("No return statement found in the function.")


class Connector(CreatorFunction):
    """A wrapped connector function; registers itself in the ``c_*``
    registries of ``CreatorFunction``."""

    def __init__(self, fct):
        super().__init__(fct)

        # append to superclass attributes
        self.c_instances.append(self)
        self.c_input_features.append(self.input_features)
        self.c_input_rs.append(self.input_rs)
        self.c_output_rs.append(self.output_rs)

    def map(self, vi, sources, targets, dt_unit, ft_feature):
        """Evaluate the connector for the given pairs and store its output
        relations in ``CreatorFunction.stored_relations``.

        Parameters
        ----------
        vi : mapping
            Indexed by feature name; each entry must expose ``.values`` that
            can be indexed by `sources` / `targets`.
        sources, targets : index arrays
            Positional indices selecting the source and target values.
        dt_unit : str or None
            If not None, the ``"ft_r"`` output is divided by
            ``np.timedelta64(1, dt_unit)`` before being stored.
        ft_feature : sequence
            ``ft_feature[0]`` is the feature name used for the special input
            names ``ft_feature_s`` and ``ft_feature_t``.

        Raises
        ------
        KeyError
            If a requested input relation has not been computed yet.
        """
        # input value dict
        ivdic = {}

        # input features
        for feature in self.input_features:
            if feature == "ft_feature_s":
                ivdic[feature] = vi[ft_feature[0]].values[sources]
            elif feature == "ft_feature_t":
                ivdic[feature] = vi[ft_feature[0]].values[targets]
            else:
                if feature.endswith("_s"):
                    ivdic[feature] = vi[feature[:-2]].values[sources]
                elif feature.endswith("_t"):
                    ivdic[feature] = vi[feature[:-2]].values[targets]

        # input relations
        for r in self.input_rs:
            try:
                ivdic[r] = CreatorFunction.stored_relations[r]
            except KeyError:
                msg = (
                    "{} requests {}, which has not yet "
                    "been computed. Check the order of "
                    "your connectors and selectors.".format(self.name, r)
                )
                raise KeyError(msg)

        # evaluate
        output = self.fct(**ivdic)

        # store relations (a single return value is treated as a 1-tuple)
        if not isinstance(output, tuple):
            output = (output,)

        for i, r in enumerate(self.output_rs):
            if r == "ft_r" and dt_unit is not None:
                CreatorFunction.stored_relations[r] = output[i] / np.timedelta64(1, dt_unit)
            else:
                CreatorFunction.stored_relations[r] = output[i]


class Selector(CreatorFunction):
    """A wrapped selector function; registers itself in the ``s_*``
    registries of ``CreatorFunction``."""

    def __init__(self, fct):
        super().__init__(fct)

        # "ftt" is not a relation: select_and_store passes it to the fast
        # track selector separately, from ft_feature[1]
        if self.name == "_ft_selector":
            self.input_rs.remove("ftt")

        # append to superclass variables
        self.s_instances.append(self)
        self.s_input_features.append(self.input_features)
        self.s_input_rs.append(self.input_rs)
        self.s_output_rs.append(self.output_rs)

    def select_and_store(self, vi, sources, targets, ft_feature, dt_unit):
        """Evaluate the selector, store its output relations and shrink all
        stored relations to the selected pairs.

        Input relations that are not yet stored are requested from the
        registered connectors first (see ``request_r``).

        Parameters
        ----------
        vi : mapping
            Indexed by feature name; each entry must expose ``.values`` that
            can be indexed by `sources` / `targets`.
        sources, targets : index arrays
            Positional indices of the candidate pairs.
        ft_feature : sequence
            ``ft_feature[1]`` is passed as ``ftt`` to the selector named
            ``_ft_selector``; also forwarded to requested connectors.
        dt_unit : str or None
            Forwarded to requested connectors.

        Returns
        -------
        tuple
            The ``sources`` and ``targets`` returned by the selector.

        Raises
        ------
        KeyError
            If a requested input relation could not be computed.
        """
        # input value dict
        ivdic = {}

        # input features
        for feature in self.input_features:
            if feature.endswith("_s"):
                ivdic[feature] = vi[feature[:-2]].values[sources]
            elif feature.endswith("_t"):
                ivdic[feature] = vi[feature[:-2]].values[targets]

        # input relations
        for r in self.input_rs:
            if r not in CreatorFunction.stored_relations:
                self.request_r(r, vi, sources, targets, dt_unit, ft_feature)
            try:
                ivdic[r] = CreatorFunction.stored_relations[r]
            except KeyError:
                msg = (
                    "{} requests {}, which has not yet "
                    "been computed. Check the order of "
                    "your connectors and selectors.".format(self.name, r)
                )
                raise KeyError(msg)

        # input indices
        ivdic["sources"] = sources
        ivdic["targets"] = targets

        # for the fast track selector, we need the threshold value
        if self.name == "_ft_selector":
            ivdic["ftt"] = ft_feature[1]

        # select and return rs and new node indices
        output = self.fct(**ivdic)

        # output value dict
        ovdic = {}
        for i, name in enumerate(self.output):
            ovdic[name] = output[i]

        # assert that all output_rs have the same shape as the indices
        # PERFORMANCE
        for r in self.output_rs:
            assert len(ovdic[r]) == len(sources), "shape of {} has been modified in {}".format(r, self.name)

        # assert that new sources and target indices have same shape
        # PERFORMANCE
        assert len(ovdic["sources"]) == len(
            ovdic["targets"]
        ), "shape of reduced source and target indices must be the same."

        # store output rs of selectors in CreatorFunction.stored_relations
        for r in self.output_rs:
            CreatorFunction.stored_relations[r] = ovdic[r]

        # positional indices of selected pairs in the former indices
        if not len(ovdic["sources"]) == len(sources):
            index = _find_selected_indices(sources, targets, ovdic["sources"], ovdic["targets"])
        else:
            index = np.arange(len(sources))

        # shrink CreatorFunction.stored_relations by selected indices
        for r in CreatorFunction.stored_relations:
            CreatorFunction.stored_relations[r] = CreatorFunction.stored_relations[r][index]

        # return updated indices
        return ovdic["sources"], ovdic["targets"]

    @staticmethod
    def request_r(r, vi, sources, targets, dt_unit, ft_feature):
        """Evaluate every registered connector that outputs relation `r`, so
        that its results end up in ``CreatorFunction.stored_relations``."""
        # find the connector mapping to r, evaluate and store
        for connector in CreatorFunction.c_instances:
            if r in connector.output_rs:
                connector.map(vi, sources, targets, dt_unit, ft_feature)