import jax
import jax.numpy as jnp

import numpy as np

import pyhf
# use the JAX backend of pyhf so model evaluations are differentiable
pyhf.set_backend(pyhf.tensor.jax_backend())

import smooth
import smooth.infer
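# `smooth` provides differentiable stand-ins for common analysis operations
# (cuts, histograms) and inference utilities built on top of pyhf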
# toy dataset: Gaussian-distributed signal and background in one variable
nBg  = 8000
nSig = 300

background = np.random.normal(40, 10, nBg)
signal = np.random.normal(50, 5, nSig)

def generate_data():
    return signal, background
import matplotlib.pyplot as plt

bins = np.linspace(0, 80, 40)
plt.figure(dpi=120)
ax = plt.gca()
plt.hist([background, signal], bins=bins, stacked=True, label=["background", "signal"])
ax.set(ylabel='frequency',xlabel='x')
plt.legend();
def preprocess(data_generator):
    """Turn a data generator into a function mapping a cut value to smooth (s, b) yields."""

    def counts(cut_param):
        s, b = data_generator()
        
        s_counts = smooth.cut(s,'>',cut_param).sum()
        b_counts = smooth.cut(b,'>',cut_param).sum()
        
        return jnp.array([s_counts]), jnp.array([b_counts])
    
    return counts
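
# For intuition: a minimal sketch of what a differentiable '>' cut could look
# like. This is an assumption for illustration -- the real smooth.cut may
# differ (e.g. in how the steepness is set). Each event gets a sigmoid weight
# instead of a hard 0/1 decision, so summed yields stay differentiable in the
# cut position.
def soft_cut_gt(data, cut_val, slope=1.0):
    return jax.nn.sigmoid(slope * (data - cut_val))

# e.g. soft_cut_gt(jnp.array([30., 50.]), 40.).sum() gives a smooth count of ~1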
def simple_histosys(data_with_cuts, uncert):
    """
    Makes a histosys model with up/down yields of +- bkg*(uncert/2).
    """

    def from_spec(yields):

        s, b = yields
        bup, bdown = b * (1 + (uncert / 2)), b * (1 - (uncert / 2))

        spec = {
            "channels": [
                {
                    "name": "smoothcut",
                    "samples": [
                        {
                            "name": "signal",
                            "data": s,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None}
                            ],
                        },
                        {
                            "name": "bkg",
                            "data": b,
                            "modifiers": [
                                {
                                    "name": "artificial_histosys",
                                    "type": "histosys",
                                    "data": {"lo_data": bdown, "hi_data": bup,},
                                }
                            ],
                        },
                    ],
                },
            ],
        }

        return pyhf.Model(spec)

    def model_maker(cut_pars):
        s, b = data_with_cuts(cut_pars)

        # make statistical model with pyhf
        m = from_spec([s, b])

        # background-only parameters: nominal values with the POI (mu) pinned to 0
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray(nompars)
        bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)
        return m, bonlypars

    return model_maker
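# Quick check (illustrative usage, not part of the workflow below): build a
# model at cut position c = 40 and inspect its background-only expected data.
m_check, bonly_check = simple_histosys(preprocess(generate_data), uncert=0.05)(40.0)
print(m_check.expected_data(bonly_check))  # [main-channel count, auxdata]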
data_with_cuts = preprocess(generate_data)
model_maker    = simple_histosys(data_with_cuts, uncert=0.05)
loss           = smooth.infer.expected_pvalue_upper_limit(model_maker,
                                                          solver_kwargs=dict(pdf_transform=True))
# p_mu and its gradient with respect to mu, evaluated at mu = 1, cut = 40

jax.value_and_grad(loss)(1.,40.)
(DeviceArray(0.00668244, dtype=float64),
 DeviceArray(-0.0461769, dtype=float64))
mus = np.linspace(0, 1, 100)
pvals = [loss(mu, 40.) for mu in mus]  # scan p_mu over mu at fixed cut c = 40

plt.figure(dpi=120)
plt.plot(mus, pvals)
plt.xlabel("$\mu$")
plt.ylabel("$p_{\mu}$")
min(pvals)
DeviceArray(0.00097437, dtype=float64)
def scn(c):
    # p_mu at mu = 1 as a function of the cut position c
    return loss(1., c)

cuts_smooth = np.linspace(20, 70, 100)
significances_smooth = np.asarray([scn(cut) for cut in cuts_smooth])

def objective(data_with_cuts):
    
    def significance(S, B):
        # median discovery significance in the Asimov approximation:
        # Z = sqrt(2*((S+B)*ln(1+S/B) - S))  [Cowan et al., arXiv:1007.1727]
        return jnp.sqrt(2*((S+B)*jnp.log(1+S/B)-S))
    
    def get_significance_smooth(cut_pars):
        '''
        Our loss function! Takes in cut parameters, returns significance
        '''
        s, b = data_with_cuts(cut_pars)
        return significance(s[0], b[0])
    
    return get_significance_smooth

get = objective(preprocess(generate_data))

cuts = np.linspace(20, 70, 500)
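# convert the Asimov significance Z into a p-value via p = 1 - Phi(Z), so it
# can be overlaid on the p_mu scan above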
significances = [1 - pyhf.tensorlib.normal_cdf(get(cut)) for cut in cuts]
import matplotlib.pyplot as plt


plt.figure(dpi=120)
plt.plot(cuts_smooth, significances_smooth, label='scan of $p_{\mu}$')
plt.xlabel("cut position c")
plt.ylabel("$p_{\mu}$")

plt.axvline(cuts_smooth[np.argmin(significances_smooth)],linestyle=':', label='optimal cut')

print("the optimal cut is c =", cuts_smooth[np.argmin(significances_smooth)])
plt.legend()
plt.savefig('pmu_scan.png', bbox_inches='tight')
the optimal cut is c = 47.27272727272727
plt.figure(dpi=120)
plt.xlim([40,60])
plt.ylim([0,0.05])
plt.plot(cuts, significances, c="C2", label='Asimov significance scan (as $p$-value)')

plt.plot(cuts_smooth, significances_smooth, label='scan of $p_{\mu}$')
plt.xlabel("cut position c")
plt.ylabel("$p_{\mu}$")

plt.axvline(cuts_smooth[np.argmin(significances_smooth)],linestyle=':', label='optimal cut')

plt.legend();
def blobs(NMC=500, sig_mean=[-1, 1], b1_mean=[2.5, 2], b_mean=[1, -1], b2_mean=[-2.5, -1.5]):
    
    def generate_blobs():
        bkg_up = np.random.multivariate_normal(b1_mean, [[1, 0], [0, 1]], size=(NMC,))
        bkg_down = np.random.multivariate_normal(b2_mean, [[1, 0], [0, 1]], size=(NMC,))
        bkg_nom = np.random.multivariate_normal(b_mean, [[1, 0], [0, 1]], size=(NMC,))
        sig = np.random.multivariate_normal(sig_mean, [[1, 0], [0, 1]], size=(NMC,))
        
        return sig, bkg_nom, bkg_up, bkg_down
    
    return generate_blobs
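
# A quick look at one draw of the blobs (illustrative only): the signal sits
# apart from the nominal background, with the up/down systematic variations
# appearing as shifted background blobs.
sig, b_nom, b_up, b_down = blobs()()
plt.figure(dpi=120)
for pts, lbl in [(sig, "signal"), (b_nom, "bkg (nominal)"),
                 (b_up, "bkg (up)"), (b_down, "bkg (down)")]:
    plt.scatter(pts[:, 0], pts[:, 1], s=2, alpha=0.4, label=lbl)
plt.legend();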
def hists(data_generator, predict, bins, bandwidth, LUMI=10, sig_scale=2, bkg_scale=10):
    """Returns a function mapping network parameters to smooth (s, b, b_up, b_down) yields."""

    def hist_maker(nn):
        s, b_nom, b_up, b_down = data_generator()
        NMC = len(s)
        
        nn_s, nn_b_nom, nn_b_up, nn_b_down = (
            predict(nn, s).ravel(),
            predict(nn, b_nom).ravel(),
            predict(nn, b_up).ravel(),
            predict(nn, b_down).ravel(),
        )
             
        kde_counts = jax.numpy.asarray([
            smooth.hist(nn_s, bins, bandwidth) * sig_scale / NMC * LUMI,
            smooth.hist(nn_b_nom, bins, bandwidth) * bkg_scale / NMC * LUMI,
            smooth.hist(nn_b_up, bins, bandwidth) * bkg_scale / NMC * LUMI,
            smooth.hist(nn_b_down, bins, bandwidth) * bkg_scale / NMC * LUMI,
        ])
        
        return kde_counts
    
    return hist_maker
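
# For intuition: a rough sketch of a binned KDE ("bKDE"), assuming smooth.hist
# behaves roughly like this (the real implementation may differ). Each event
# spreads a Gaussian kernel whose mass is integrated over every bin, so the
# bin counts are differentiable in the event positions.
from jax.scipy.stats import norm

def soft_hist(events, bins, bandwidth):
    edges = jnp.asarray(bins).reshape(-1, 1)
    cdf = norm.cdf(edges, loc=events, scale=bandwidth)  # (n_edges, n_events)
    return (cdf[1:] - cdf[:-1]).sum(axis=1)             # one smooth count per bin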
def nn_histosys(histogram_maker):
    
    def from_spec(yields):
        
        s, b, bup, bdown = yields
        
        spec = {
            "channels": [
                {
                    "name": "nn",
                    "samples": [
                        {
                            "name": "signal",
                            "data": s,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None}
                            ],
                        },
                        {
                            "name": "bkg",
                            "data": b,
                            "modifiers": [
                                {
                                    "name": "nn_histosys",
                                    "type": "histosys",
                                    "data": {
                                        "lo_data": bdown,
                                        "hi_data": bup,
                                    },
                                }
                            ],
                        },      
                    ],
                },
            ],
        }

        return pyhf.Model(spec)
    
    def nn_model_maker(nn):
        yields = histogram_maker(nn)
        m = from_spec(yields)
        # background-only parameters: nominal values with the POI (mu) pinned to 0
        nompars = m.config.suggested_init()
        bonlypars = jax.numpy.asarray(nompars)
        bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)
        return m, bonlypars
    
    return nn_model_maker
import jax.experimental.stax as stax

# dense network with a sigmoid output, mapping 2D events to a 1D observable in [0, 1]
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)

# choose hyperparams: 3 bins on [0, 1], with the KDE bandwidth set to 80% of the bin width
bins = np.linspace(0, 1, 4)
centers   = bins[:-1] + np.diff(bins)/2.
bandwidth = 0.8 * 1/(len(bins)-1)

# compose functions to define workflow
data   = blobs()
hmaker = hists(data,predict,bins=bins,bandwidth=bandwidth)
model  = nn_histosys(hmaker)
loss   = smooth.infer.expected_pvalue_upper_limit(model, solver_kwargs=dict(pdf_transform=True))


_, network = init_random_params(jax.random.PRNGKey(13), (-1, 2))
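
# differentiate the expected p-value with respect to the network parameters
# (argnums=1); the gradient comes back as a pytree with the same nested
# structure as `network`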
jax.value_and_grad(loss, argnums=1)(1.0, network)
(DeviceArray(0.03419097, dtype=float64),
 [(DeviceArray([[-7.4747222e-05, -5.6944977e-05,  1.7446939e-05, ...,
                 -2.4487432e-05,  6.9174617e-05, -3.5466583e-05],
                [ 1.2072978e-04,  2.1503494e-05, -6.7116935e-06, ...,
                  7.1674481e-06, -9.7872784e-05, -1.3499416e-06]],            dtype=float32),
   DeviceArray([-4.6593788e-05, -2.4979156e-05, -1.8001270e-05, ...,
                 2.8620339e-05,  4.2742522e-05, -1.7730363e-05],            dtype=float32)),
  (),
  (DeviceArray([[-7.4381433e-07,  8.7469658e-08,  6.7387840e-10, ...,
                 -7.2362553e-08,  1.6154515e-06,  7.4637381e-07],
                [ 1.1795539e-07,  1.8024228e-07,  2.0555500e-09, ...,
                  0.0000000e+00,  1.3165605e-07, -6.8112592e-11],
                [-3.9922452e-07,  1.8413182e-10,  1.4307927e-10, ...,
                  1.6287530e-08, -5.0056929e-07,  4.5510586e-07],
                ...,
                [-3.5405691e-07, -1.1326374e-06, -1.4052329e-08, ...,
                  8.9071946e-08, -1.8953529e-06,  1.2026198e-06],
                [ 8.3317434e-07,  2.1319472e-06,  1.9097971e-08, ...,
                 -2.9123530e-09,  3.0891329e-06,  1.2002033e-07],
                [ 2.0260664e-07, -3.5106780e-08, -3.7393555e-10, ...,
                  0.0000000e+00,  8.3604625e-08, -9.1771426e-09]],            dtype=float32),
   DeviceArray([ 8.1915132e-06,  4.7153902e-05,  3.5778328e-07, ...,
                -6.6411008e-07,  8.6970671e-05,  1.8307695e-05],            dtype=float32)),
  (),
  (DeviceArray([[-2.1727724e-06],
                [-2.5328352e-05],
                [-1.3518069e-05],
                ...,
                [-8.2271417e-06],
                [-2.6815096e-04],
                [ 1.0678576e-05]], dtype=float32),
   DeviceArray([-0.00260607], dtype=float32)),
  ()])
import jax.experimental.optimizers as optimizers
import time

opt_init, opt_update, opt_params = optimizers.adam(1e-3)


def train_network(N):
    cls_vals = []
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []
    
    # parameter update function
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss, argnums=1)(mu, net)
        return opt_update(i, grad, opt_state), value, net
    
    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i,state,1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time
maxN = 50 # make me bigger for better results!

# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(f"epoch {i}:", f'p_mu = {metrics["loss"][-1]:.5f}, took {epoch_time:.2f}s')
epoch 0: p_mu = 0.03416, took 1.71s
epoch 1: p_mu = 0.01664, took 2.05s
epoch 2: p_mu = 0.00774, took 2.03s
epoch 3: p_mu = 0.00423, took 2.05s
epoch 4: p_mu = 0.00264, took 2.05s
epoch 5: p_mu = 0.00185, took 2.06s
epoch 6: p_mu = 0.00142, took 2.11s
epoch 7: p_mu = 0.00116, took 2.02s
epoch 8: p_mu = 0.00100, took 2.01s
epoch 9: p_mu = 0.00089, took 2.06s
epoch 10: p_mu = 0.00081, took 2.02s
epoch 11: p_mu = 0.00076, took 2.12s
epoch 12: p_mu = 0.00072, took 2.13s
epoch 13: p_mu = 0.00069, took 2.13s
epoch 14: p_mu = 0.00068, took 2.05s
epoch 15: p_mu = 0.00067, took 2.07s
epoch 16: p_mu = 0.00066, took 2.03s
epoch 17: p_mu = 0.00064, took 2.06s
epoch 18: p_mu = 0.00063, took 2.06s
epoch 19: p_mu = 0.00062, took 2.05s
epoch 20: p_mu = 0.00060, took 2.06s
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
(traceback truncated: the training loop was interrupted by hand after epoch 20)