Source code for zfit.models.functor

"""
Functors are functions that typically take one or more other PDFs. Prominent examples are a sum, a convolution, etc.

A FunctorBase class is provided to make handling the models easier.

Their implementation is often non-trivial.
"""
#  Copyright (c) 2020 zfit

from collections import OrderedDict
from typing import List, Optional

import tensorflow as tf

from ..core.basepdf import BasePDF
from ..core.coordinates import convert_to_obs_str
from ..core.interfaces import ZfitPDF, ZfitModel, ZfitData
from ..core.parameter import convert_to_parameter
from ..core.space import supports
from ..models.basefunctor import FunctorMixin, extract_daughter_input_obs
from ..settings import ztypes, run
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.exception import (ModelIncompatibleError, ObsIncompatibleError, NormRangeUnderdefinedError,
                              AnalyticIntegralNotImplementedError, SpecificFunctionNotImplementedError)
from ..util.warnings import warn_advanced_feature
from ..z.random import counts_multinomial


class BaseFunctor(FunctorMixin, BasePDF):
    """Base class for pdfs that are composed of one or more daughter pdfs.

    The daughter pdfs are stored (as a container) in `self.pdfs` and are handed to the parent classes as
    `models`. After construction, the normalization range is taken from the explicitly given one or, if none
    was given, inferred from the daughters' spaces.
    """

    def __init__(self, pdfs, name="BaseFunctor", **kwargs):
        """Store the daughter pdfs and initialize the normalization range.

        Args:
            pdfs: A single pdf or an iterable of pdfs; converted to a container.
            name (str): Name of the functor.
            **kwargs: Forwarded unchanged to the parent classes' `__init__`.

        Raises:
            NormRangeUnderdefinedError: if no `norm_range` can be determined (see
                `_set_norm_range_from_daugthers`).
        """
        self.pdfs = convert_to_container(pdfs)
        super().__init__(models=self.pdfs, name=name, **kwargs)
        self._set_norm_range_from_daugthers()

    def _set_norm_range_from_daugthers(self):
        """Set the normalization range, falling back to the daughters' spaces if none has limits.

        Raises:
            NormRangeUnderdefinedError: if neither the explicitly set `norm_range` nor the one extracted from
                the daughter spaces has limits.
        """
        norm_range = super().norm_range
        if not norm_range.limits_are_set:
            # No explicit limits: try to derive a common range from the daughters' spaces.
            norm_range = extract_daughter_input_obs(obs=norm_range,
                                                    spaces=[model.space for model in self.models])
        if not norm_range.limits_are_set:
            raise NormRangeUnderdefinedError(
                f"Daughter pdfs {self.pdfs} do not agree on a `norm_range` and/or no `norm_range` "
                "has been explicitly set.")
        self.set_norm_range(norm_range)

    @property
    def pdfs_extended(self):
        """List of booleans, one per daughter pdf, telling whether that daughter is extended."""
        return [pdf.is_extended for pdf in self.pdfs]

    @property
    def _models(self) -> List[ZfitModel]:
        """The daughter models of this functor, i.e. `self.pdfs`."""
        return self.pdfs
class SumPDF(BaseFunctor):

    def __init__(self, pdfs: List[ZfitPDF], fracs: Optional[ztyping.ParamTypeInput] = None,
                 obs: ztyping.ObsTypeInput = None, name: str = "SumPDF"):
        """Create the sum of the `pdfs` with `fracs` as coefficients or the yields, if extended pdfs are given.

        If *all* pdfs are extended, the fracs is optional and the (normalized) yields will be used as fracs.
        If fracs is given, this will be used as the fractions, regardless of whether the pdfs have a yield or not.

        The parameters of the SumPDF are the fractions that are used to multiply the output of each daughter pdf.
        They can be accessed with `pdf.params` and have names f"frac_{i}" with i starting from 0 and going to the
        number of pdfs given.

        To get the component outputs of this pdf, e.g. to plot it, use `pdf.params.values()` to iterate through
        the fracs and `pdfs` to get the pdfs. For example

        .. code-block:: python

            for pdf, frac in zip(sumpdf.pdfs, sumpdf.params.values()):
                frac_integral = pdf.integrate(...) * frac

        Args:
            pdfs (pdf): The pdfs to be added. At least two are required.
            fracs (iterable): coefficients for the linear combination of the pdfs. Optional if *all* pdfs are
                extended.

                - len(frac) == len(pdfs) - 1 results in the interpretation of a non-extended pdf.
                  The last coefficient will equal to 1 - sum(frac)
                - len(frac) == len(pdfs): the fracs will be used as is and no normalization attempt is taken.
            obs: Observables of the sum. If None, the observables of the first pdf are used for the
                compatibility check. All pdfs must have the same set of observables.
            name (str): Name of the pdf.

        Raises:
            ValueError: if fewer than two pdfs are given.
            ObsIncompatibleError: if the pdfs do not all have the same observables.
            ModelIncompatibleError: if no fracs are given and not all pdfs are extended, or if the number of
                fracs is neither `len(pdfs)` nor `len(pdfs) - 1`.
        """
        # Check user input
        self._fracs = None
        pdfs = convert_to_container(pdfs)
        self.pdfs = pdfs
        if len(pdfs) < 2:
            raise ValueError(f"Cannot build a sum of a single pdf {pdfs}")
        common_obs = obs if obs is not None else pdfs[0].obs
        common_obs = convert_to_obs_str(common_obs)
        if not all(frozenset(pdf.obs) == frozenset(common_obs) for pdf in pdfs):
            raise ObsIncompatibleError("Currently, sums are only supported in the same observables")

        # check if all extended
        are_extended = [pdf.is_extended for pdf in pdfs]
        all_extended = all(are_extended)
        no_extended = not any(are_extended)

        fracs = convert_to_container(fracs)
        if fracs:  # not None or empty list
            fracs = [convert_to_parameter(frac) for frac in fracs]
        elif not all_extended:
            raise ModelIncompatibleError(f"Not all pdf {pdfs} are extended and no fracs {fracs} are provided.")

        if not no_extended and fracs:
            warn_advanced_feature(f"This SumPDF is built with fracs {fracs} and (some or all) pdf extended {pdfs}."
                                  f" This will ignore the yields of the already extended pdfs and the result will"
                                  f" be a not extended SumPDF.", identifier='sum_extended_frac')

        # catch if args don't fit known case
        if fracs:
            if len(fracs) not in (len(pdfs), len(pdfs) - 1):
                raise ModelIncompatibleError(f"If all PDFs are not extended {pdfs}, the fracs {fracs} have to be of"
                                             f" the same length as pdf or one less.")
            elif len(fracs) == len(pdfs) - 1:
                # create the missing frac so that all fracs sum up to one
                def remaining_frac_func():
                    return tf.constant(1., dtype=ztypes.float) - tf.add_n(fracs)

                remaining_frac = convert_to_parameter(remaining_frac_func, dependents=fracs)
                if run.numeric_checks:
                    # `message` must be passed by keyword: the second positional argument of
                    # `assert_non_negative` is `message`, the third is `summarize`.
                    tf.debugging.assert_non_negative(
                        remaining_frac,
                        message=f"The remaining fraction is negative, the sum of fracs is > 1. Fracs: {fracs}")
                # IMPORTANT! Build a *new* list instead of appending to `fracs`: the closure above captures
                # `fracs`, so mutating it would make the remaining frac depend on itself (recursion).
                fracs_cleaned = fracs + [remaining_frac]
            else:
                fracs_cleaned = fracs
            param_fracs = fracs_cleaned

        # for the extended case, take the yields, normalize them, in case no fracs are given.
        if all_extended and not fracs:
            yields = [pdf.get_yield() for pdf in pdfs]

            def sum_yields_func():
                return tf.reduce_sum(
                    input_tensor=[tf.convert_to_tensor(value=y, dtype_hint=ztypes.float) for y in yields])

            sum_yields = convert_to_parameter(sum_yields_func, dependents=yields)
            # `yield_=yield_` binds the current loop value (avoids the late-binding closure pitfall).
            yield_fracs = [convert_to_parameter(lambda yield_=yield_: yield_ / sum_yields, dependents=yield_)
                           for yield_ in yields]
            fracs_cleaned = None
            param_fracs = yield_fracs

        self.pdfs = pdfs
        self._fracs = param_fracs
        self._original_fracs = fracs_cleaned

        params = OrderedDict()
        for i, frac in enumerate(param_fracs):
            params[f'frac_{i}'] = frac

        super().__init__(pdfs=pdfs, obs=obs, params=params, name=name)
        if all_extended and not fracs_cleaned:
            # TODO(SUM): why not the public `self.set_yield(sum_yields)`?
            self._set_yield_inplace(sum_yields)

    @property
    def fracs(self):
        """The fractions multiplying each daughter pdf (given, completed, or derived from the yields)."""
        return self._fracs

    def _apply_yield(self, value: float, norm_range: ztyping.LimitsType, log: bool):
        """Return `value` unchanged if all daughters are extended, otherwise defer to the parent class."""
        if all(self.pdfs_extended):
            return value
        else:
            return super()._apply_yield(value=value, norm_range=norm_range, log=log)

    def _unnormalized_pdf(self, x):
        """Sum of the (normalized) daughter pdfs evaluated at `x`, each weighted with its frac."""
        # TODO: cleanup component ranges
        prob = tf.math.accumulate_n([pdf.pdf(x) * frac for pdf, frac in zip(self.pdfs, self.fracs)])
        return prob

    @supports(multiple_limits=True)
    def _integrate(self, limits, norm_range):
        """Frac-weighted sum of the daughter integrals over `limits`.

        The `norm_range` is deliberately NOT propagated to the daughters.
        """
        pdfs = self.pdfs
        fracs = self.fracs
        integrals = [frac * pdf.integrate(limits=limits)  # do NOT propagate the norm_range!
                     for pdf, frac in zip(pdfs, fracs)]
        # TODO(SUM): broadcast integrals instead?
        integral = tf.math.accumulate_n(integrals)
        return integral

    @supports(multiple_limits=True)
    def _analytic_integrate(self, limits, norm_range):
        """Frac-weighted sum of the daughters' analytic integrals over `limits`.

        Raises:
            AnalyticIntegralNotImplementedError: if any daughter does not implement an analytic integral.
        """
        pdfs = self.pdfs
        fracs = self.fracs
        try:
            integrals = [frac * pdf.analytic_integrate(limits=limits)  # do NOT propagate the norm_range!
                         for pdf, frac in zip(pdfs, fracs)]
        except AnalyticIntegralNotImplementedError as error:
            raise AnalyticIntegralNotImplementedError(
                f"analytic_integrate of pdf {self.name} is not implemented in this"
                f" SumPDF, as at least one sub-pdf does not implement it.") from error

        # TODO(SUM): broadcast integrals instead?
        integral = tf.math.accumulate_n(integrals)
        return integral

    @supports(multiple_limits=True)
    def _partial_integrate(self, x, limits, norm_range):
        """Frac-weighted sum of the daughters' partial integrals over `limits` at `x`."""
        pdfs = self.pdfs
        fracs = self.fracs

        partial_integral = [pdf.partial_integrate(x=x, limits=limits) * frac  # do NOT propagate the norm_range!
                            for pdf, frac in zip(pdfs, fracs)]
        partial_integral = tf.math.accumulate_n(partial_integral)
        return partial_integral

    @supports(multiple_limits=True)
    def _partial_analytic_integrate(self, x, limits, norm_range):
        """Frac-weighted sum of the daughters' partial analytic integrals over `limits` at `x`.

        Raises:
            AnalyticIntegralNotImplementedError: if any daughter does not implement it.
        """
        pdfs = self.pdfs
        fracs = self.fracs
        try:
            partial_integral = [pdf.partial_analytic_integrate(x=x, limits=limits) * frac
                                # do NOT propagate the norm_range!
                                for pdf, frac in zip(pdfs, fracs)]
        except AnalyticIntegralNotImplementedError as error:
            raise AnalyticIntegralNotImplementedError(
                f"partial_analytic_integrate of pdf {self.name} is not implemented in this"
                " SumPDF, as at least one sub-pdf does not implement it.") from error
        partial_integral = tf.math.accumulate_n(partial_integral)
        return partial_integral

    @supports(multiple_limits=True)
    def _sample(self, n, limits):
        """Sample from the sum by sampling each daughter and concatenating along axis 0.

        If `n` is a string, it is forwarded unchanged to every daughter. Otherwise the total count `n` is
        split among the daughters with a multinomial draw using `self.fracs` as probabilities.
        """
        if isinstance(n, str):
            n = [n] * len(self.pdfs)
        else:
            n = tf.unstack(counts_multinomial(total_count=n, probs=self.fracs), axis=0)

        samples = []
        for pdf, n_sample in zip(self.pdfs, n):
            sub_sample = pdf.sample(n=n_sample, limits=limits)
            if isinstance(sub_sample, ZfitData):
                sub_sample = sub_sample.value()
            samples.append(sub_sample)
        sample = tf.concat(samples, axis=0)
        return sample
class ProductPDF(BaseFunctor):
    """Product of several daughter pdfs.

    TODO: compose of smaller Product PDF by disasembling components subsets of obs
    """

    def __init__(self, pdfs: List[ZfitPDF], obs: ztyping.ObsTypeInput = None, name="ProductPDF"):
        """Build the product of `pdfs`; all arguments are forwarded to `BaseFunctor`."""
        super().__init__(pdfs=pdfs, obs=obs, name=name)

    def _unnormalized_pdf(self, x: ztyping.XType):
        """Element-wise product (over axis 0 of the stacked values) of the daughters' unnormalized pdfs."""
        factors = [daughter.unnormalized_pdf(x) for daughter in self.pdfs]
        return tf.math.reduce_prod(input_tensor=factors, axis=0)

    def _pdf(self, x, norm_range):
        """Product of the daughters' normalized pdfs, if no entry of `_model_same_obs` is truthy.

        NOTE(review): `_model_same_obs` presumably flags daughters that share observables — TODO confirm.
        In that case this shortcut is not used and SpecificFunctionNotImplementedError is raised
        (presumably so a generic fallback is taken). `norm_range` is not forwarded to the daughters.
        """
        if any(self._model_same_obs):
            raise SpecificFunctionNotImplementedError
        factors = [daughter.pdf(x=x) for daughter in self.pdfs]
        return tf.reduce_prod(input_tensor=factors, axis=0)