Optimizing an analysis with differentiable p-values
A tutorial on making an analysis pipeline differentiable end-to-end with jax, pyhf, and smooth: first by optimizing a simple cut on one observable, then by training a neural network observable, both by minimizing the expected p-value of the resulting statistical model.
import jax
import jax.numpy as jnp
import numpy as np
import pyhf
# use jax backend of pyhf
pyhf.set_backend(pyhf.tensor.jax_backend())
import smooth
import smooth.infer
nBg = 8000
nSig = 300

# toy data: gaussian-distributed signal and background in one observable x
background = np.random.normal(40, 10, nBg)
signal = np.random.normal(50, 5, nSig)

def generate_data():
    return signal, background
import matplotlib.pyplot as plt

bins = np.linspace(0, 80, 40)
plt.figure(dpi=120)
ax = plt.gca()
plt.hist([background, signal], bins=bins, stacked=True, label=["background", "signal"])
ax.set(xlabel='x', ylabel='frequency')
plt.legend();
def preprocess(data_generator):
    def counts(cut_param):
        # apply a relaxed (differentiable) cut to each sample, then sum the
        # per-event weights to get expected yields
        s, b = data_generator()
        s_counts = smooth.cut(s, '>', cut_param).sum()
        b_counts = smooth.cut(b, '>', cut_param).sum()
        return jnp.array([s_counts]), jnp.array([b_counts])

    return counts
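The differentiability comes from smooth.cut, which replaces the hard step function of a selection with a smooth per-event weight. A minimal sketch of one way such a relaxed cut can be implemented (a sigmoid in the cut variable; the relaxed_cut name and slope parameter are illustrative, not part of smooth's API):

def relaxed_cut(data, cut_val, slope=1.0):
    # per-event weights that approach {0, 1} as slope -> 0,
    # recovering a hard data > cut_val selection
    return jax.nn.sigmoid((data - cut_val) / slope)

# the summed "yield" now varies smoothly (and differentiably) with cut_val
relaxed_cut(jnp.asarray(background), 40.0).sum()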
def simple_histosys(data_with_cuts, uncert):
    """
    Makes a histosys model with up/down yields of +- bkg*(uncert/2).
    """

    def from_spec(yields):
        s, b = yields
        bup, bdown = b * (1 + (uncert / 2)), b * (1 - (uncert / 2))
        spec = {
            "channels": [
                {
                    "name": "smoothcut",
                    "samples": [
                        {
                            "name": "signal",
                            "data": s,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None}
                            ],
                        },
                        {
                            "name": "bkg",
                            "data": b,
                            "modifiers": [
                                {
                                    "name": "artificial_histosys",
                                    "type": "histosys",
                                    "data": {"lo_data": bdown, "hi_data": bup},
                                }
                            ],
                        },
                    ],
                },
            ],
        }
        return pyhf.Model(spec)

    def model_maker(cut_pars):
        s, b = data_with_cuts(cut_pars)
        # make statistical model with pyhf
        m = from_spec([s, b])
        # background-only parameters: nominal values with the poi (mu) set to 0
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray(nompars).at[m.config.poi_index].set(0.0)
        return m, bonlypars

    return model_maker
data_with_cuts = preprocess(generate_data)
model_maker = simple_histosys(data_with_cuts, uncert=0.05)
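As a quick check, the model maker can be called directly at some cut value (40 here, chosen arbitrarily) to inspect the expected data under the background-only hypothesis:

m, bonlypars = model_maker(40.0)
# expected event counts (plus auxiliary data) at the background-only parameters
print(m.expected_data(bonlypars))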
loss = smooth.infer.expected_pvalue_upper_limit(
    model_maker, solver_kwargs=dict(pdf_transform=True)
)

# the expected p-value and its gradient with respect to the cut position, at mu=1, cut=40
jax.value_and_grad(loss, argnums=1)(1.0, 40.0)

# fix the cut and view the p-value as a function of the signal strength mu
def loss2(cut):
    return lambda mu: loss(mu, cut)
mus = np.linspace(0, 1, 100)
l = [loss2(40.0)(mu) for mu in mus]

plt.figure(dpi=120)
plt.plot(mus, l)
min(l)
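One hypothetical way to read an expected upper limit off this scan: invert the p-value curve at p = 0.05 with a simple interpolation (this assumes the curve is monotonically decreasing in mu over the scanned range):

# np.interp expects increasing x-values, so reverse the decreasing p-value curve
p_vals = np.asarray(l)
mu_up = np.interp(0.05, p_vals[::-1], mus[::-1])
print("expected 95% CL upper limit on mu:", mu_up)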
def scn(c):
    # expected p-value at mu=1 as a function of the cut position
    return loss(1.0, c)

cuts_smooth = np.linspace(20, 70, 100)
significances_smooth = np.asarray([scn(cut) for cut in cuts_smooth])
def significance(S, B):
    return jnp.sqrt(2 * ((S + B) * jnp.log(1 + S / B) - S))
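This is the median discovery significance in the Asimov approximation (Cowan, Cranmer, Gross, Vitells): $Z_A = \sqrt{2\,\big((S+B)\ln(1+S/B) - S\big)}$, which reduces to the familiar $S/\sqrt{B}$ in the limit $S \ll B$.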
def objective(data_with_cuts):
    def get_significance_smooth(cut_pars):
        '''
        Our loss function! Takes in cut parameters, returns the Asimov significance.
        '''
        s, b = data_with_cuts(cut_pars)
        return significance(s[0], b[0])

    return get_significance_smooth
get = objective(preprocess(generate_data))
cuts = np.linspace(20, 70, 500)
# convert the Asimov significance Z into a one-sided p-value
significances = [1 - pyhf.tensorlib.normal_cdf(get(cut)) for cut in cuts]
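For orientation, a significance of Z = 2 maps to a one-sided p-value of roughly 0.023:

# p = 1 - Phi(Z), here for Z = 2
print(1 - pyhf.tensorlib.normal_cdf(2.0))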
plt.figure(dpi=120)
plt.plot(cuts_smooth, significances_smooth, label='scan of $p_{\mu}$')
plt.xlabel("cut position c")
plt.ylabel("$p_{\mu}$")
plt.axvline(cuts_smooth[np.argmin(significances_smooth)], linestyle=':', label='optimal cut')
print("the optimal cut is c =", cuts_smooth[np.argmin(significances_smooth)])
plt.legend()
plt.savefig('pvalue_scan.png', bbox_inches='tight')
# zoom in around the optimum and compare with the Asimov significance scan
plt.figure(dpi=120)
plt.xlim([40, 60])
plt.ylim([0, 0.05])
plt.plot(cuts, significances, c="C2", label='scan of Asimov significance')
plt.plot(cuts_smooth, significances_smooth, label='scan of $p_{\mu}$')
plt.xlabel("cut position c")
plt.ylabel("$p_{\mu}$")
plt.axvline(cuts_smooth[np.argmin(significances_smooth)], linestyle=':', label='optimal cut')
plt.legend();
def blobs(NMC=500, sig_mean=[-1, 1], b1_mean=[2.5, 2], b_mean=[1, -1], b2_mean=[-2.5, -1.5]):
    def generate_blobs():
        # 2D gaussian blobs: one signal sample, plus nominal/up/down
        # background samples for a shape systematic
        bkg_up = np.random.multivariate_normal(b1_mean, [[1, 0], [0, 1]], size=(NMC,))
        bkg_down = np.random.multivariate_normal(b2_mean, [[1, 0], [0, 1]], size=(NMC,))
        bkg_nom = np.random.multivariate_normal(b_mean, [[1, 0], [0, 1]], size=(NMC,))
        sig = np.random.multivariate_normal(sig_mean, [[1, 0], [0, 1]], size=(NMC,))
        return sig, bkg_nom, bkg_up, bkg_down

    return generate_blobs
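To get a feel for the setup, a quick scatter plot of one draw from the generator:

sig, bkg_nom, bkg_up, bkg_down = blobs()()
plt.figure(dpi=120)
for points, label in [(bkg_up, "bkg up"), (bkg_down, "bkg down"),
                      (bkg_nom, "bkg nominal"), (sig, "signal")]:
    plt.scatter(points[:, 0], points[:, 1], s=4, alpha=0.5, label=label)
plt.legend();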
def hists(data_generator, predict, bins, bandwidth, LUMI=10, sig_scale=2, bkg_scale=10):
    def hist_maker(nn):
        s, b_nom, b_up, b_down = data_generator()
        NMC = len(s)
        # push each sample through the network to get a 1D observable
        nn_s, nn_b_nom, nn_b_up, nn_b_down = (
            predict(nn, s).ravel(),
            predict(nn, b_nom).ravel(),
            predict(nn, b_up).ravel(),
            predict(nn, b_down).ravel(),
        )
        # smooth (KDE-based) histograms, scaled to the expected yields
        kde_counts = jnp.asarray([
            smooth.hist(nn_s, bins, bandwidth) * sig_scale / NMC * LUMI,
            smooth.hist(nn_b_nom, bins, bandwidth) * bkg_scale / NMC * LUMI,
            smooth.hist(nn_b_up, bins, bandwidth) * bkg_scale / NMC * LUMI,
            smooth.hist(nn_b_down, bins, bandwidth) * bkg_scale / NMC * LUMI,
        ])
        return kde_counts

    return hist_maker
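What makes these histograms differentiable is the kernel density estimate underneath. A minimal sketch of how such a binned KDE can work (an assumption about the idea behind smooth.hist, not its actual code): each event contributes the probability mass of a gaussian kernel that falls inside each bin.

from jax.scipy.stats import norm

def kde_hist_sketch(events, bins, bandwidth):
    # integrate a gaussian kernel centred on each event over every bin,
    # then sum the per-event contributions
    edges = jnp.asarray(bins)
    cdf_at_edges = norm.cdf(edges[:, None], loc=events, scale=bandwidth)
    return (cdf_at_edges[1:] - cdf_at_edges[:-1]).sum(axis=1)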
def nn_histosys(histogram_maker):
    def from_spec(yields):
        s, b, bup, bdown = yields
        spec = {
            "channels": [
                {
                    "name": "nn",
                    "samples": [
                        {
                            "name": "signal",
                            "data": s,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None}
                            ],
                        },
                        {
                            "name": "bkg",
                            "data": b,
                            "modifiers": [
                                {
                                    "name": "nn_histosys",
                                    "type": "histosys",
                                    "data": {
                                        "lo_data": bdown,
                                        "hi_data": bup,
                                    },
                                }
                            ],
                        },
                    ],
                },
            ],
        }
        return pyhf.Model(spec)

    def nn_model_maker(nn):
        yields = histogram_maker(nn)
        m = from_spec(yields)
        # background-only parameters: nominal values with the poi (mu) set to 0
        nompars = m.config.suggested_init()
        bonlypars = jnp.asarray(nompars).at[m.config.poi_index].set(0.0)
        return m, bonlypars

    return nn_model_maker
from jax.example_libraries import stax

# regression net: two hidden layers, sigmoid output in [0, 1]
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)
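A quick sanity check of the network's input/output shapes (the seed and dummy input here are arbitrary):

_, tmp_params = init_random_params(jax.random.PRNGKey(0), (-1, 2))
print(predict(tmp_params, np.ones((5, 2))).shape)  # (5, 1): one sigmoid output per event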
# choose hyperparams
bins = np.linspace(0, 1, 4)  # three bins over the network's sigmoid output
centers = bins[:-1] + np.diff(bins) / 2.0
bandwidth = 0.8 * 1 / (len(bins) - 1)  # kernel bandwidth of order the bin width

# compose functions to define the workflow
data = blobs()
hmaker = hists(data, predict, bins=bins, bandwidth=bandwidth)
model = nn_histosys(hmaker)
loss = smooth.infer.expected_pvalue_upper_limit(model, solver_kwargs=dict(pdf_transform=True))

_, network = init_random_params(jax.random.PRNGKey(13), (-1, 2))
# the expected p-value and its gradient with respect to the network weights, at mu=1
jax.value_and_grad(loss, argnums=1)(1.0, network)
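The gradient comes back as a pytree with the same structure as the network parameters; its shapes can be inspected with the standard jax tree utilities:

value, grads = jax.value_and_grad(loss, argnums=1)(1.0, network)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))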
from jax.example_libraries import optimizers
import time

opt_init, opt_update, opt_params = optimizers.adam(1e-3)

@jax.jit
def update_and_value(i, opt_state, mu):
    # one training step: evaluate the p-value and its gradient with respect
    # to the network weights, then update the optimizer state
    net = opt_params(opt_state)
    value, grad = jax.value_and_grad(loss, argnums=1)(mu, net)
    return opt_update(i, grad, opt_state), value, net
def train_network(N):
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time
        losses.append(value)
        metrics = {"loss": losses}
        yield network, metrics, epoch_time
maxN = 50 # make me bigger for better results!
# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
print(f"epoch {i}:", f'p_mu = {metrics["loss"][-1]:.5f}, took {epoch_time:.2f}s')