Fitting multiple datasets using MultiFitter
Often we have multiple datasets for a single Miyake event. Instead of fitting a model to each dataset individually, we can use a MultiFitter object to fit all the datasets at once. Here is how:
import numpy as np
import matplotlib.pyplot as plt
import ticktack
import jax.numpy as jnp
from jax import jit
from ticktack import fitting
import matplotlib as mpl
mpl.style.use('seaborn-colorblind') # named 'seaborn-v0_8-colorblind' in Matplotlib 3.6 and later
First, we need to make a directory and place our datasets inside it. Each dataset should be a CSV file with the same format as example.csv.
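Before fitting, it can help to confirm that every file in the directory really does share the template's layout. The following is a minimal sketch using only the standard library; the helper names `read_header` and `check_dataset_dir`, the directory name, and the column names are hypothetical placeholders and are not taken from ticktack or from example.csv.

```python
import csv
import tempfile
from pathlib import Path


def read_header(csv_path):
    """Return the first row (the column names) of a CSV file, or [] if it is empty."""
    with open(csv_path, newline="") as handle:
        return next(csv.reader(handle), [])


def check_dataset_dir(directory, template):
    """Map each CSV file name in `directory` to True if its header matches `template`."""
    expected = read_header(template)
    return {
        path.name: read_header(path) == expected
        for path in sorted(Path(directory).glob("*.csv"))
    }


# Demonstration with throwaway files. The column names below are placeholders,
# not necessarily the ones used in example.csv.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    template = tmp / "example.csv"
    template.write_text("year,value,error\n774,0.1,1.5\n")

    data_dir = tmp / "datasets"
    data_dir.mkdir()
    (data_dir / "good.csv").write_text("year,value,error\n774,0.2,1.4\n")
    (data_dir / "bad.csv").write_text("date,value\n774,0.3\n")

    report = check_dataset_dir(data_dir, template)

print(report)
```

Any file flagged `False` would need its columns renamed or reordered before being placed alongside the other datasets.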
Example 1: Fit a sinusoid model to the 775 CE Miyake event with MCMC
mf, sampler = fitting.fit_event(year=775, # the year of the event
                                params=jnp.array([775.1, -1., 1.6, np.log10(6.5)]), # start date, log10 duration, phase, log10 area
                                low_bounds=jnp.array([770., -1.5, 0, -2]),
                                up_bounds=jnp.array([780., 1.5, 11, 1.5]),
                                burnin=500, production=500)
samples = sampler.copy()
samples[:,1] = 10**samples[:,1] # duration not log duration
samples[:,-1] = 10**samples[:,-1] # area not log area
labels = ["start date (yr)", "duration (yr)", r"$\phi$ (yr)", "Q (atoms/cm$^2$ yr/s)"] # raw string so \p is not read as an escape
mf.chain_summary(samples, walkers=8, labels=labels); # by default walkers = dim(params) * 2
This is only a rough estimate of the posterior of these parameters. We can get a much better corner plot by increasing the number of burn-in and production steps (say, to 1000 each).
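The log-to-linear conversion used above, together with a simple numeric summary of the posterior, can be sketched with NumPy alone. The chain here is synthetic random data standing in for real MCMC output (an assumption made purely for illustration), and the column order is assumed to follow the `params` array: start date, log10 duration, phase, log10 area.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4000

# Synthetic stand-in for a chain; columns are
# start date, log10 duration, phase, log10 area.
chain = np.column_stack([
    rng.normal(775.1, 0.1, n_samples),
    rng.normal(-1.0, 0.2, n_samples),
    rng.normal(1.6, 0.3, n_samples),
    rng.normal(np.log10(6.5), 0.05, n_samples),
])

# Work on a copy so the original chain keeps its log-scaled columns.
samples = chain.copy()
samples[:, 1] = 10 ** samples[:, 1]    # duration, not log10 duration
samples[:, -1] = 10 ** samples[:, -1]  # area, not log10 area

# Median and 16th/84th percentiles of each parameter.
lower, median, upper = np.percentile(samples, [16, 50, 84], axis=0)

names = ["start date", "duration", "phase", "area"]
for name, lo, med, hi in zip(names, lower, median, upper):
    print(f"{name}: {med:.3f} (+{hi - med:.3f} / -{med - lo:.3f})")
```

Because the transform is monotonic, converting the samples first and then taking percentiles gives the same intervals as taking percentiles in log space and converting afterwards.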
We can also define a custom parametric production model to fit with a MultiFitter. Here we use a simple flat production function with a super-Gaussian spike. This model has only three parameters: the start date of the event, the duration of the event, and the spike production.
Example 2: Fit a custom model to the 775 CE Miyake event with MCMC
def super_gaussian(t, start_time, duration, area):
    middle = start_time + duration / 2.
    height = area / duration
    return height * jnp.exp(- ((t - middle) / (1. / 1.93516 * duration)) ** 16.)

def super_gaussian_pulse(t, *args):
    start_time, duration, area = jnp.array(list(args)).reshape(-1)
    height = super_gaussian(t, start_time, duration, area)
    production = 1.76 + height # 1.76 is the default steady state production rate
    return production

mf, sampler = fitting.fit_event(year=775,
                                params=jnp.array([775.1, 0.3, 6.5]), # start date, duration, area
                                low_bounds=jnp.array([770., 0., 0.]),
                                up_bounds=jnp.array([780., 5., 15.]),
                                burnin=500, production=500)
# duration and area are sampled in linear units for this model, so no log conversion is needed
labels = ["start date (yr)", "duration (yr)", "Q (atoms/cm$^2$ yr/s)"]
mf.chain_summary(sampler, walkers=6, labels=labels);
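One way to see what the `area` parameter of the custom model means is to integrate the spike numerically. The sketch below re-implements `super_gaussian` with plain NumPy (a swap made only so the check has no JAX dependency) and compares the integral with `area`. The constant 1.93516 appears to equal 2Γ(17/16), which is the width normalization that would make a super-Gaussian of exponent 16 integrate to exactly `height * duration`; that interpretation is inferred from the numbers, not stated in the source.

```python
import math

import numpy as np


def super_gaussian(t, start_time, duration, area):
    """NumPy copy of the spike defined above: centred on the middle of the event."""
    middle = start_time + duration / 2.0
    height = area / duration
    return height * np.exp(-((t - middle) / (duration / 1.93516)) ** 16.0)


# Same starting values as the fit above: start date, duration, area.
start_time, duration, area = 775.1, 0.3, 6.5

# Fine, evenly spaced grid over the prior range of the start date.
t = np.linspace(770.0, 780.0, 200001)
dt = t[1] - t[0]

pulse = super_gaussian(t, start_time, duration, area)
integral = float(np.sum(pulse) * dt)  # simple Riemann sum

# Analytic normalization of exp(-|x / w|**16) is 2 * w * Gamma(1 + 1/16).
norm_const = 2.0 * math.gamma(1.0 + 1.0 / 16.0)

print("integral of spike:", integral)
print("2 * Gamma(17/16): ", norm_const)
```

If the integral comes out close to `area`, then the fitted value can be read as the total excess production over the event, with `area / duration` as the plateau height of the spike.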
