Source code for zfit.util.graph

#  Copyright (c) 2020 zfit

from typing import List

import tensorflow as tf

from zfit.util.temporary import TemporarilySet


def all_parents(op, current_obs=None):
    """Collect every operation reachable by following the inputs of `op`.

    The graph is walked with an explicit stack and a single visited set, so
    each operation is expanded at most once and the depth of the graph is not
    limited by the interpreter's recursion limit.

    Args:
        op: Object exposing `inputs`, an iterable whose elements each have an
            `op` attribute (the operation producing that input).
        current_obs: Optional collection of operations to treat as already
            seen. They are neither returned nor expanded. It is not modified.

    Returns:
        A set with all operations found upstream of `op`. `op` itself is only
        contained if it is reachable from its own inputs.
    """
    # Copy so that the caller's collection is never mutated.
    seen = set() if current_obs is None else set(current_obs)
    parents = set()
    pending = [op]
    while pending:
        node = pending.pop()
        for input_ in node.inputs:
            parent = input_.op
            if parent in seen:
                continue
            # Mark on discovery so shared ancestors are queued only once.
            seen.add(parent)
            parents.add(parent)
            pending.append(parent)
    return parents
def get_dependents_auto(tensor: tf.Tensor, candidates: List[tf.Tensor]) -> List[tf.Tensor]:
    """Return the nodes in `candidates` that `tensor` depends on.

    The operations upstream of `tensor.op` are collected with `all_parents`;
    a candidate is kept when its own `op` is among them.

    Args:
        tensor: Tensor whose upstream operations are searched. Must expose an
            `op` attribute.
        candidates: Tensors to check. Each must expose an `op` attribute.

    Returns:
        The elements of `candidates`, in their given order, whose operation is
        upstream of `tensor`.

    Raises:
        ValueError: If collecting the upstream operations raises a
            `RuntimeError` (which includes `RecursionError`). The original
            error is attached as the cause.
    """
    try:
        dependent_ops = all_parents(tensor.op)
    except RuntimeError as error:
        # Each fragment ends with a space so the joined message reads correctly.
        raise ValueError(
            "Tensor too deeply nested, recursion limit exceeded. In the future, "
            "implementation will be different and any dependents can be found. "
            "Currently, specify dependents explicitly if needed. "
            "Original Error: {}".format(error)
        ) from error
    return [cand for cand in candidates if cand.op in dependent_ops]
class JIT:
    """Helpers around the JIT-type switches of ``FunctionWrapperRegistry``.

    The state itself is the ``do_jit_types`` mapping on
    ``z.zextension.FunctionWrapperRegistry``; the methods here read or update
    that mapping. ``zfit`` is imported inside the methods, not at module level.
    """

    def _temporarily_allowed(self, new_values):
        """Wrap `new_values` in a ``TemporarilySet`` bound to the allowed mapping.

        The getter hands out a copy of the current mapping and the setter
        forwards to `_update_allowed`.
        """

        def current_state():
            return self._get_allowed().copy()

        def apply_state(jit_types):
            self._update_allowed(jit_types)

        return TemporarilySet(value=new_values, setter=apply_state, getter=current_state)

    def _set_all(self, enable: bool = True):
        """Return a ``TemporarilySet`` mapping every currently known key to `enable`."""
        switched = {jit_type: enable for jit_type in self._get_allowed()}
        return self._temporarily_allowed(switched)

    def _set_default(self):
        """Return a ``TemporarilySet`` holding a copy of the registry defaults.

        Keys that are currently known but absent from the defaults are looked
        up on the copy and stored back.
        """
        from zfit import z

        defaults = z.zextension.FunctionWrapperRegistry._DEFAULT_DO_JIT_TYPES.copy()
        for jit_type in self._get_allowed():
            if jit_type in defaults:
                continue
            # Deliberate self-assignment: presumably `defaults` is a default
            # dict, so the lookup materialises the default value explicitly
            # -- TODO confirm the type of _DEFAULT_DO_JIT_TYPES.
            defaults[jit_type] = defaults[jit_type]
        return self._temporarily_allowed(defaults)

    def _update_allowed(self, update_jit):
        """Merge `update_jit` into the registry's ``do_jit_types`` mapping."""
        from zfit import z

        registry = z.zextension.FunctionWrapperRegistry
        registry.do_jit_types.update(update_jit)

    def _get_allowed(self):
        """Return the registry's ``do_jit_types`` mapping itself (not a copy)."""
        from zfit import z

        registry = z.zextension.FunctionWrapperRegistry
        return registry.do_jit_types

    @property
    def experimental_is_eager(self):
        """Return the ``'graph'`` entry of ``run.mode``.

        NOTE(review): the property name suggests "eager", yet the value
        returned is the 'graph' entry unchanged -- confirm this is intended.
        """
        from ..settings import run

        graph_mode = run.mode['graph']
        return graph_mode
# Module-level instance of `JIT`, created once when this module is imported.
jit = JIT()  # singleton