# Tutorial
This tutorial demonstrates a simple application of BAT.jl: A Bayesian fit of a histogram with two Gaussian peaks.
You can also download this tutorial as a Jupyter notebook and a plain Julia source file.
Note: This tutorial is somewhat verbose, as it aims to be easy to follow for users who are new to Julia. For the same reason, we deliberately avoid making use of Julia features like closures, anonymous functions, broadcasting syntax, performance annotations, etc.
## Input Data Generation
First, let's generate some synthetic data to fit. We'll need the Julia standard-library packages "Random", "LinearAlgebra" and "Statistics", as well as the packages "Distributions" and "StatsBase":
using Random, LinearAlgebra, Statistics, Distributions, StatsBase
As the underlying truth of our input data/histogram, let us choose a non-normalized probability density composed of two Gaussian peaks with peak areas of 500 and 1000, means of -1.0 and 2.0, and a common standard deviation of 0.5. So our model parameters will be:
par_names = ["a_1", "a_2", "mu_1", "mu_2", "sigma"]
true_par_values = [500, 1000, -1.0, 2.0, 0.5]
We'll define a function that returns two Gaussian distributions, based on a specific set of parameters:
function model_distributions(parameters::AbstractVector{<:Real})
    return (
        Normal(parameters[3], parameters[5]),
        Normal(parameters[4], parameters[5])
    )
end
and then generate some synthetic data by drawing samples from the two Gaussian distributions, with the number of samples drawn from each given by the parameters a₁ and a₂:
data = vcat(
    rand(model_distributions(true_par_values)[1], Int(true_par_values[1])),
    rand(model_distributions(true_par_values)[2], Int(true_par_values[2]))
)
resulting in a vector of floating-point numbers:
typeof(data) == Vector{Float64}
true
Then we create a histogram of that data; this histogram will serve as the input for the Bayesian fit:
hist = append!(Histogram(-2:0.1:4), data)
StatsBase.Histogram{Int64,1,Tuple{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}
edges:
-2.0:0.1:4.0
weights: [4, 8, 19, 12, 22, 27, 35, 38, 38, 30 … 11, 7, 5, 0, 1, 1, 1, 0, 0, 0]
closed: left
isdensity: false
The fit function that describes such a histogram (depending on the model parameters) will be
function fit_function(x::Real, parameters::AbstractVector{<:Real})
    dists = model_distributions(parameters)
    return parameters[1] * pdf(dists[1], x) +
           parameters[2] * pdf(dists[2], x)
end
Using the Julia "Plots" package
using Plots
we can visually compare the histogram and the fit function, using the true values of the parameters:
plot(
    normalize(hist, mode=:density),
    st = :steps, label = "Data",
    title = "Data and True Statistical Model"
)
plot!(
    -4:0.01:4, x -> fit_function(x, true_par_values),
    label = "Truth"
)
savefig("tutorial-data-and-truth.pdf")
## Bayesian Fit
Now let's do a Bayesian fit of the generated histogram, using BAT.
In addition to the Julia packages loaded above, we need BAT itself, as well as IntervalSets:
using BAT, IntervalSets
### Likelihood Definition
First, we need to define a likelihood function for our problem. In BAT, all likelihood functions and priors are subtypes of `BAT.AbstractDensity`. We'll store the histogram that we want to fit in our likelihood density type, since accessing the histogram as a global variable would reduce performance:
struct HistogramLikelihood{H<:Histogram} <: AbstractDensity
    histogram::H
end
As a minimum, BAT requires methods of `BAT.nparams` and `BAT.unsafe_density_logval` to be defined for each subtype of `AbstractDensity`.

`BAT.nparams` simply needs to return the number of free parameters:
BAT.nparams(likelihood::HistogramLikelihood) = 5
`BAT.unsafe_density_logval` has to implement the actual log-likelihood function:
function BAT.unsafe_density_logval(
    likelihood::HistogramLikelihood,
    parameters::AbstractVector{<:Real},
    exec_context::ExecContext
)
    # Histogram counts for each bin as an array:
    counts = likelihood.histogram.weights

    # Histogram binning, has length (length(counts) + 1):
    binning = likelihood.histogram.edges[1]

    # Sum log-likelihood over bins:
    log_likelihood::Float64 = 0.0
    for i in eachindex(counts)
        bin_left, bin_right = binning[i], binning[i+1]
        bin_width = bin_right - bin_left
        bin_center = (bin_right + bin_left) / 2

        observed_counts = counts[i]

        # Simple mid-point rule integration of fit_function over the bin:
        expected_counts = bin_width * fit_function(bin_center, parameters)

        log_likelihood += logpdf(Poisson(expected_counts), observed_counts)
    end

    return log_likelihood
end
Methods of `BAT.unsafe_density_logval` may be "unsafe" insofar as the implementation is not required to check the length of the `parameters` vector or the validity of the parameter values; BAT takes care of that (assuming that the value returned by `BAT.nparams` is correct and that the prior only covers valid parameter values).
Note: Currently, implementations of `BAT.unsafe_density_logval` must be type-stable, to avoid triggering a Julia-internal error. The matter is under investigation. If the implementation of `BAT.unsafe_density_logval` is not type-stable, this will often result in an error like this:
Internal error: encountered unexpected error in runtime:
MethodError(f=typeof(Core.Compiler.fieldindex)(), args=(Random123.Philox4x{T, R} ...
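One way to check an implementation for type stability is Julia's built-in `@code_warntype` macro, shown here on `fit_function` as a stand-in (in a script, `InteractiveUtils` must be loaded first; the REPL loads it automatically):

using InteractiveUtils

# Inspect the inferred types; non-concrete results (reported as e.g. `Any`,
# highlighted in red in the REPL) indicate a type instability:
@code_warntype fit_function(0.0, true_par_values)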
The `exec_context` argument can be ignored in simple use cases; it is only of interest for `unsafe_density_logval` methods that internally use Julia's multi-threading and/or distributed code execution capabilities.
BAT itself also makes use of Julia's parallel programming facilities. BAT can calculate log-density values in parallel (e.g. for multiple MCMC chains) on multiple threads (already implemented), and support for distributed execution (on multiple hosts) is planned. By default, however, BAT will assume that implementations of `BAT.unsafe_density_logval` are not thread-safe. If your implementation is thread-safe (as is the case in the example above), you can advertise this fact to BAT:
BAT.exec_capabilities(::typeof(BAT.unsafe_density_logval), likelihood::HistogramLikelihood, parameters::AbstractVector{<:Real}) =
ExecCapabilities(0, true, 0, true)
BAT will then use multi-threaded log-likelihood evaluation where possible. Note that Julia starts only a single thread by default; you will need to set the environment variable `JULIA_NUM_THREADS` to configure the number of Julia threads.
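For example, starting Julia via `JULIA_NUM_THREADS=4 julia` makes four threads available; you can verify the thread count from within Julia:

# Number of threads available to this Julia session, as configured
# via the JULIA_NUM_THREADS environment variable at startup:
Threads.nthreads()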
Given our fit function and the histogram to fit, we'll define the likelihood as
likelihood = HistogramLikelihood(hist)
Main.ex-tutorial.HistogramLikelihood{StatsBase.Histogram{Int64,1,Tuple{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}}(StatsBase.Histogram{Int64,1,Tuple{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}
edges:
-2.0:0.1:4.0
weights: [4, 8, 19, 12, 22, 27, 35, 38, 38, 30 … 11, 7, 5, 0, 1, 1, 1, 0, 0, 0]
closed: left
isdensity: false)
### Prior Definition
For simplicity, we choose a flat prior, i.e. a normalized constant density:
prior = ConstDensity(
    HyperRectBounds(
        [
            0.0..10.0^4, 0.0..10.0^4,
            -2.0..0.0, 1.0..3.0,
            0.3..0.7
        ],
        reflective_bounds
    ),
    normalize
)
In general, BAT allows instances of any subtype of `AbstractDensity` to be used as a prior, as long as a sampler is defined for it. This way, users may implement complex application-specific priors. You can also use `convert(AbstractDensity, distribution)` to convert any continuous multivariate `Distributions.Distribution` to a `BAT.AbstractDensity` that can be used as a prior (or likelihood).
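As a minimal sketch of such a conversion (the distribution and its parameters below are purely illustrative and not part of this tutorial's model; `MvNormal` is constructed here from a mean vector and a vector of standard deviations):

# Purely illustrative: a multivariate normal with diagonal covariance,
# converted to a BAT density that could serve as a prior:
illustrative_dist = MvNormal([750.0, 750.0, 0.0, 2.0, 0.5], [100.0, 100.0, 0.5, 0.5, 0.1])
illustrative_prior = convert(AbstractDensity, illustrative_dist)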
### Bayesian Model Definition
Given the likelihood and prior definition, a `BAT.BayesianModel` is simply defined via
model = BayesianModel(likelihood, prior)
### Parameter Space Exploration via MCMC
We can now use Markov chain Monte Carlo (MCMC) to explore the space of possible parameter values for the histogram fit.
We'll use the Metropolis-Hastings algorithm and a multivariate t-distribution (ν = 1) as its proposal function:
algorithm = MetropolisHastings(MvTDistProposalSpec(1.0))
We also need to specify which random number generator and seed to use. BAT requires a counter-based RNG and partitions the RNG space over the MCMC chains. This way, a single RNG seed is sufficient for all chains, and results can be reproducible even under parallel execution. Let's choose a Philox4x RNG with a random seed:
rngseed = BAT.Philox4xSeed()
The algorithm, model and RNG seed specify the MCMC chains:
chainspec = MCMCSpec(algorithm, model, rngseed)
Let's use 4 MCMC chains and require 10^5 unique samples from each chain (after tuning/burn-in):
nsamples = 10^5
nchains = 4
BAT provides fine-grained control over the MCMC tuning algorithm, convergence test and the chain initialization and tuning/burn-in strategy (the values used here are the default values):
tuner_config = ProposalCovTunerConfig(
    λ = 0.5,
    α = 0.15..0.35,
    β = 1.5,
    c = 1e-4..1e2
)

convergence_test = BGConvergence(1.1)

init_strategy = MCMCInitStrategy(
    ninit_tries_per_chain = 8..128,
    max_nsamples_pretune = 25,
    max_nsteps_pretune = 250,
    max_time_pretune = Inf
)

burnin_strategy = MCMCBurninStrategy(
    max_nsamples_per_cycle = 1000,
    max_nsteps_per_cycle = 10000,
    max_time_per_cycle = Inf,
    max_ncycles = 30
)
Before running the Markov chains, let's set BAT's logging level to debug, to see what's going on in more detail (note: BAT's logging API will change in the future for better integration with the Julia v1 logging facilities):
BAT.Logging.set_log_level!(BAT, BAT.Logging.LOG_DEBUG)
Now we can generate a set of MCMC samples via `rand`:
samples, sampleids, stats, chains = rand(
    chainspec,
    nsamples,
    nchains,
    tuner_config = tuner_config,
    convergence_test = convergence_test,
    init_strategy = init_strategy,
    burnin_strategy = burnin_strategy,
    max_nsteps = 10000,
    max_time = Inf,
    granularity = 1,
    ll = BAT.Logging.LOG_INFO
)
Note: Reasonable default values are defined for all of the above. In many use cases, a simple
samples, sampleids, stats, chains =
    rand(MCMCSpec(MetropolisHastings(), model), nsamples, nchains)
may be sufficient.
Let's print some results:
println("Truth: $true_par_values")
println("Mode: $(stats.mode)")
println("Mean: $(stats.param_stats.mean)")
println("Covariance: $(stats.param_stats.cov)")
Truth: [500.0, 1000.0, -1.0, 2.0, 0.5]
Mode: [498.0598082015617, 999.0404429396266, -0.9945776591299531, 2.002663908037726, 0.49911629380054956]
Mean: [498.3928350826897, 1000.1758526981598, -0.9915499373702737, 2.002999757651183, 0.4975664939118484]
Covariance: [518.0167915798537 -7.2309968000202725 -0.03193193955549207 0.0015511904923729103 0.016406265992161353; -7.23099680002029 1042.108554095156 -0.006328099560325723 -0.006578951234396439 0.002570244006619241; -0.03193193955549208 -0.006328099560325721 0.0006173463049309837 4.201989470818367e-6 -3.6919197113793587e-5; 0.0015511904923728624 -0.006578951234396415 4.2019894708184785e-6 0.00024762227962288027 -3.6328353256593502e-6; 0.01640626599216136 0.0025702440066192483 -3.691919711379357e-5 -3.632835325659373e-6 9.456270262664063e-5]
`stats` contains some statistics collected during MCMC sample generation, e.g. the mean and covariance of the parameters and the mode. The same statistics can of course also be calculated afterwards, from the samples:
@assert vec(mean(samples.params, FrequencyWeights(samples.weight))) ≈ stats.param_stats.mean
@assert vec(var(samples.params, FrequencyWeights(samples.weight))) ≈ diag(stats.param_stats.cov)
@assert cov(samples.params, FrequencyWeights(samples.weight)) ≈ stats.param_stats.cov
We can also, e.g., get the Pearson auto-correlation of the parameters:
vec(cor(samples.params, FrequencyWeights(samples.weight)))
25-element Array{Float64,1}:
1.0
-0.009841692014177653
-0.056466250849516106
0.004331101450403988
0.07412725047635878
-0.009841692014177653
1.0
-0.007889552970893643
-0.012951060380573549
0.008187624874464406
⋮
-0.012951060380573549
0.010747208503865381
1.0
-0.023740557610138997
0.07412725047635878
0.008187624874464406
-0.15280159348700828
-0.023740557610138997
1.0
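Dropping the `vec` yields the same information as a 5×5 correlation matrix, which is often easier to read:

# Weighted Pearson correlation matrix of the parameters:
cor(samples.params, FrequencyWeights(samples.weight))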
## Visualization of Results
BAT.jl comes with an extensive set of plotting recipes for "Plots.jl". We can plot the marginalized distribution for a single parameter (e.g. parameter 3, i.e. μ₁):
plot(
    samples, 3,
    mean = true, std_dev = true, globalmode = true, localmode = true,
    nbins = 50, xlabel = par_names[3], ylabel = "P($(par_names[3]))",
    title = "Marginalized Distribution for mu_1"
)
savefig("tutorial-single-par.pdf")
or plot the marginalized distribution for a pair of parameters (e.g. parameters 3 and 5, i.e. μ₁ and σ), including information from the parameter stats:
plot(
    samples, (3, 5),
    mean = true, std_dev = true, globalmode = true, localmode = true,
    nbins = 50, xlabel = par_names[3], ylabel = par_names[5],
    title = "Marginalized Distribution for mu_1 and sigma"
)
plot!(stats, (3, 5))
savefig("tutorial-param-pair.pdf")
We can also create an overview plot of the marginalized distribution for all pairs of parameters:
plot(
    samples,
    mean = false, std_dev = false, globalmode = true, localmode = false,
    nbins = 50
)
savefig("tutorial-all-params.pdf")
## Integration with Tables.jl
BAT.jl supports the Tables.jl interface. Using a tables implementation like [TypedTables.jl](http://blog.roames.com/TypedTables.jl/stable/), the whole MCMC output (parameter vectors, weights, sample/chain numbers, etc.) can easily be combined into a single table:
using TypedTables
tbl = Table(samples, sampleids)
Table with 8 columns and 9483 rows:
params log_posterior log_prior weight chainid ⋯
┌───────────────────────────────────────────────────────────────────
1 │ [490.204, 980.0, -0… -176.701 -18.8907 2 6 ⋯
2 │ [512.784, 992.206, … -176.86 -18.8907 6 6 ⋯
3 │ [510.166, 981.349, … -176.817 -18.8907 1 6 ⋯
4 │ [493.005, 1005.02, … -175.134 -18.8907 3 6 ⋯
5 │ [503.889, 1003.42, … -174.583 -18.8907 1 6 ⋯
6 │ [511.166, 975.166, … -175.406 -18.8907 10 6 ⋯
7 │ [505.344, 993.121, … -175.786 -18.8907 6 6 ⋯
8 │ [482.257, 1007.72, … -174.112 -18.8907 1 6 ⋯
9 │ [491.073, 1019.96, … -174.493 -18.8907 8 6 ⋯
10 │ [502.594, 1016.6, -… -174.693 -18.8907 8 6 ⋯
11 │ [495.741, 988.057, … -174.258 -18.8907 11 6 ⋯
12 │ [530.43, 985.626, -… -174.859 -18.8907 4 6 ⋯
13 │ [516.224, 976.529, … -174.031 -18.8907 5 6 ⋯
14 │ [493.15, 1005.96, -… -173.365 -18.8907 2 6 ⋯
15 │ [480.093, 976.688, … -175.941 -18.8907 2 6 ⋯
16 │ [508.584, 997.481, … -176.416 -18.8907 2 6 ⋯
17 │ [483.116, 970.145, … -176.878 -18.8907 3 6 ⋯
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
We can now, e.g., find the sample with the maximum posterior value (i.e. the mode):
mode_log_posterior, mode_idx = findmax(tbl.log_posterior)
(-173.30488843260017, 4508)
And get row `mode_idx` of the table, with all information about the sample at the mode:
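tbl[mode_idx]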
## Comparison of Truth and Best Fit
As a final step, we retrieve the parameter values at the mode, representing the best-fit parameters:
fit_par_values = tbl[mode_idx].params
5-element Array{Float64,1}:
498.0598082015617
999.0404429396266
-0.9945776591299531
2.002663908037726
0.49911629380054956
And plot the truth, data, and best fit:
plot(
    normalize(hist, mode=:density),
    st = :steps, label = "Data",
    title = "Data, True Model and Best Fit"
)
plot!(-4:0.01:4, x -> fit_function(x, true_par_values), label = "Truth")
plot!(-4:0.01:4, x -> fit_function(x, fit_par_values), label = "Best fit")
savefig("tutorial-data-truth-bestfit.pdf")
This page was generated using Literate.jl.