import itertools
from typing import Iterable, Callable, Optional

import numdifftools
import tensorflow as tf

from ..settings import ztypes
from ..util.container import convert_to_container

def poly_complex(*args, real_x=False):
    """Complex polynomial with the last arg being x.

    Evaluates ``sum_p args[p] * x**p`` where ``x`` is the last positional argument; every power of
    ``x`` is converted to complex before being multiplied by its coefficient.

    Args:
        *args (tf.Tensor or equ.): Coefficients of the polynomial in increasing order of power
            (the first one is the constant term), followed by ``x`` as the very last argument.
        real_x (bool): If True, x is assumed to be real and the powers are computed with ``tf.pow``;
            otherwise they are computed with ``z.nth_pow``.

    Returns:
        tf.Tensor: the (complex) value of the polynomial.

    Raises:
        ValueError: if fewer than two positional arguments (at least one coefficient and x) are given.
    """
    from .. import z

    if len(args) < 2:
        raise ValueError("poly_complex needs at least one coefficient followed by x, "
                         "got {} positional argument(s).".format(len(args)))
    coefficients = list(args)
    x = coefficients.pop()
    # Previously `real_x is not None`, which is True for both True and False and made the flag a no-op.
    pow_func = tf.pow if real_x else z.nth_pow
    return tf.add_n([coef * z.to_complex(pow_func(x, p)) for p, coef in enumerate(coefficients)])

def interpolate(t, c):
    """Multilinear interpolation on a rectangular grid of arbitrary number of dimensions.

    Args:
        t (tf.Tensor): Grid (of rank N) holding the values at the integer grid points.
        c (tf.Tensor): Tensor of coordinates for which the interpolation is performed, one row per point
            and one column per grid dimension, expressed in units of grid indices.

    Returns:
        tf.Tensor: 1D tensor of interpolated value, one per row of ``c``.
    """
    rank = len(t.get_shape())
    # Integer index of the lower corner of the grid cell that contains each point.
    lower_corner = tf.cast(tf.floor(c), tf.int32)
    # Pad the grid by one cell on every side (replicating the edge values) so that the corner indices
    # lower_corner and lower_corner + 1 stay inside the tensor; the +1 shift in gather_nd below
    # accounts for this padding.
    t_padded = tf.pad(tensor=t, paddings=rank * [[1, 1]], mode='SYMMETRIC')
    weighted_values = []
    for vertex in itertools.product([0, 1], repeat=rank):
        corner = lower_corner + tf.constant(vertex, dtype=tf.int32)
        # Multilinear weight of this corner: product over dimensions of (1 - distance to the corner).
        weight = tf.reduce_prod(input_tensor=1. - tf.abs(c - tf.cast(corner, dtype=ztypes.float)), axis=1)
        corner_value = tf.gather_nd(t_padded, corner + 1)
        weighted_values.append(weight * corner_value)
    return tf.reduce_sum(input_tensor=tf.stack(weighted_values), axis=0)

def numerical_gradient(func: Callable, params: Iterable["zfit.Parameter"]) -> tf.Tensor:
    """Calculate numerically the gradients of func() with respect to params.

    The derivatives are estimated with ``numdifftools.Gradient``, executed through ``tf.py_function``.
    While it runs, the parameters are assigned the trial values chosen by numdifftools; their original
    values are restored afterwards (also if the computation raises).

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.

    Returns:
        tf.Tensor: gradients, a float64 1D tensor with one entry per parameter.
    """
    params = convert_to_container(params)

    def wrapped_func(param_values):
        # numdifftools calls this with trial points: push them into the parameters and evaluate func.
        for param, value in zip(params, param_values):
            param.assign(value)
        return func().numpy()

    param_vals = tf.stack(params)
    original_vals = [param.read_value() for param in params]
    n_params = len(params)
    gradient_func = numdifftools.Gradient(wrapped_func)
    try:
        gradient = tf.py_function(gradient_func, inp=[param_vals], Tout=tf.float64)
        # Force a 1D result, in case numdifftools squeezes the single-parameter case to a scalar.
        gradient = tf.reshape(gradient, (n_params,))
    finally:
        for param, val in zip(params, original_vals):
            param.set_value(val)
    return gradient

def numerical_value_gradients(func: Callable, params: Iterable["zfit.Parameter"]) -> [tf.Tensor, tf.Tensor]:
    """Calculate numerically the gradients of func() with respect to params, also returns the value of func().

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.

    Returns:
        tuple(tf.Tensor, tf.Tensor): value, gradient
    """
    # Evaluate first, at the current parameter values, before numerical_gradient perturbs them.
    value = func()
    gradient = numerical_gradient(func=func, params=params)
    return value, gradient

def numerical_hessian(func: Callable, params: Iterable["zfit.Parameter"], hessian=None) -> tf.Tensor:
    """Calculate numerically the hessian matrix of func with respect to params.

    The derivatives are estimated with ``numdifftools.Hessian`` (or ``numdifftools.Hessdiag``), executed
    through ``tf.py_function``. While it runs, the parameters are assigned the trial values chosen by
    numdifftools; their original values are restored afterwards (also if the computation raises).

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.
        hessian (str or None): If ``'diag'``, only the diagonal of the hessian is computed; any other
            value computes the full matrix.

    Returns:
        tf.Tensor: float64 hessian matrix of shape (n_params, n_params), or its diagonal of shape
            (n_params,) if ``hessian == 'diag'``.
    """
    params = convert_to_container(params)

    def wrapped_func(param_values):
        # numdifftools calls this with trial points: push them into the parameters and evaluate func.
        for param, value in zip(params, param_values):
            param.assign(value)
        return func().numpy()

    param_vals = tf.stack(params)
    original_vals = [param.read_value() for param in params]
    n_params = param_vals.shape[0]

    diagonal_only = hessian == 'diag'
    if diagonal_only:
        # TODO: maybe add a step (e.g. step=1e-4) to remove numerical problems?
        hesse_func = numdifftools.Hessdiag(wrapped_func)
        result_shape = (n_params,)
    else:
        # TODO: maybe set a base_step (e.g. base_step=1e-4)?
        hesse_func = numdifftools.Hessian(wrapped_func)
        result_shape = (n_params, n_params)

    try:
        computed_hessian = tf.py_function(hesse_func, inp=[param_vals], Tout=tf.float64)
        # Reshape (rather than set_shape) in case numdifftools squeezes the single-parameter result.
        computed_hessian = tf.reshape(computed_hessian, result_shape)
    finally:
        for param, val in zip(params, original_vals):
            param.set_value(val)
    return computed_hessian

def numerical_value_gradients_hessian(func: Callable, params: Iterable["zfit.Parameter"],
                                      hessian: Optional[str] = None) -> [tf.Tensor, tf.Tensor, tf.Tensor]:
    """Calculate numerically the gradients and hessian matrix of func() wrt params; also return func().

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.
        hessian (str or None): Passed on to ``numerical_hessian``; ``'diag'`` computes only the diagonal.

    Returns:
        tuple(tf.Tensor, tf.Tensor, tf.Tensor): value, gradient and hessian matrix
    """
    # Evaluate first, at the current parameter values, before the numerical routines perturb them.
    value = func()
    gradients = numerical_gradient(func=func, params=params)
    computed_hessian = numerical_hessian(func, params, hessian=hessian)
    return value, gradients, computed_hessian

def autodiff_gradient(func: Callable, params: Iterable["zfit.Parameter"]) -> tf.Tensor:
    """Calculate using autodiff the gradients of func() wrt params.

    Automatic differentiation (autodiff) is a way of retrieving the derivative of x wrt y. It works by consecutively
    applying the chain rule. All that is needed is that every operation knows its own derivative.
    TensorFlow implements this and anything using tf.* operations only can use this technique.

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.

    Returns:
        tf.Tensor: gradient, with the same structure as ``params`` (None for a parameter func does not
            depend on).
    """
    # Only a single gradient call is made, so a non-persistent tape suffices.
    with tf.GradientTape(persistent=False, watch_accessed_variables=False) as tape:
        # Record only the given params, not every variable func() happens to read.
        tape.watch(params)
        value = func()
    return tape.gradient(value, sources=params)

def autodiff_value_gradients(func: Callable, params: Iterable["zfit.Parameter"]) -> [tf.Tensor, tf.Tensor]:
    """Calculate using autodiff the gradients of func() wrt params; also return func().

    Automatic differentiation (autodiff) is a way of retrieving the derivative of x wrt y. It works by consecutively
    applying the chain rule. All that is needed is that every operation knows its own derivative.
    TensorFlow implements this and anything using tf.* operations only can use this technique.

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.

    Returns:
        tuple(tf.Tensor, tf.Tensor): value and gradient; the gradient has the same structure as
            ``params`` (None for a parameter func does not depend on).
    """
    # Only a single gradient call is made on this tape, so it does not need to be persistent.
    with tf.GradientTape(persistent=False, watch_accessed_variables=False) as tape:
        # Record only the given params, not every variable func() happens to read.
        tape.watch(params)
        value = func()
    gradients = tape.gradient(value, sources=params)
    return value, gradients

def autodiff_hessian(func: Callable, params: Iterable["zfit.Parameter"], hessian=None) -> tf.Tensor:
    """Calculate using autodiff the hessian matrix of func() wrt params.

    Automatic differentiation (autodiff) is a way of retrieving the derivative of x wrt y. It works by consecutively
    applying the chain rule. All that is needed is that every operation knows its own derivative.
    TensorFlow implements this and anything using tf.* operations only can use this technique.

    Args:
        func (Callable): Function without arguments that depends on params
        params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
            derivatives will be taken.
        hessian (str or None): If ``'diag'``, only the diagonal of the hessian is computed; any other
            value computes the full matrix.

    Returns:
        tf.Tensor: hessian matrix of shape (n_params, n_params), or its diagonal of shape (n_params,)
            if ``hessian == 'diag'``. Derivatives wrt parameters func does not depend on are zero.
    """
    params = convert_to_container(params)
    # Zeros instead of None for unconnected parameters, so the results can always be stacked.
    zero = tf.UnconnectedGradients.ZERO
    # The outer tape differentiates the gradients produced by the inner one. It is persistent because
    # the diagonal case calls `gradient` once per parameter.
    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as outer_tape:
        outer_tape.watch(params)
        with tf.GradientTape(watch_accessed_variables=False) as inner_tape:
            inner_tape.watch(params)
            value = func()
        # Computed inside the outer context so these ops are recorded for the second derivative.
        gradients = inner_tape.gradient(value, sources=params, unconnected_gradients=zero)
        # NOTE(review): stacking assumes every parameter is a scalar.
        gradients_tf = tf.stack(gradients)

    if hessian == 'diag':
        computed_hessian = tf.stack([outer_tape.gradient(grad, sources=param, unconnected_gradients=zero)
                                     for param, grad in zip(params, gradients)])
    else:
        computed_hessian = tf.stack(outer_tape.jacobian(gradients_tf, sources=params,
                                                        unconnected_gradients=zero))
    # Release the resources held by the persistent tape.
    del outer_tape
    return computed_hessian

[docs]def automatic_value_gradients_hessian(func: Callable = None, params: Iterable["zfit.Parameter"] = None,
hessian=None) -> [tf.Tensor, tf.Tensor, tf.Tensor]:
"""Calculate using autodiff the gradients and hessian matrix of func() wrt params; also return func().

Automatic differentiation (autodiff) is a way of retreiving the derivative of x wrt y. It works by consecutively
applying the chain rule. All that is needed is that every operation knows its own derivative.
TensorFlow implements this and anything using tf.* operations only can use this technique.

Args:
func (Callable): Function without arguments that depends on params
params (ZfitParameter): Parameters that func implicitly depends on and with respect to which the
derivatives will be taken.

Returns:
tuple(tf.Tensor, tf.Tensor, tf.Tensor): value, gradient and hessian matrix
"""
if params is None:
raise ValueError("Parameters have to be specified, are currently None.")
if func is None and value_grad_func is None:
ValueError("Either func or value_grad_func has to be specified.")

from .. import z
persistant = hessian == 'diag' or tf.executing_eagerly()  # currently needed, TODO: can we better parallelize that?
tape.watch(params)
else:
if hessian != 'diag':