Tutorial

This tutorial demonstrates a simple application of BAT.jl: A Bayesian fit of a histogram with two Gaussian peaks.

You can also download this tutorial as a Jupyter notebook and a plain Julia source file.

Note: This tutorial is somewhat verbose, as it aims to be easy to follow for users who are new to Julia. For the same reason, we largely avoid Julia features like closures, anonymous functions, broadcasting syntax, performance annotations, etc., and use them only where there is a specific reason to.

Input Data Generation

First, let's generate some synthetic data to fit. We'll need the Julia standard-library packages "Random", "LinearAlgebra" and "Statistics", as well as the packages "Distributions" and "StatsBase":

using Random, LinearAlgebra, Statistics, Distributions, StatsBase

As the underlying truth of our input data/histogram, let us choose the expected count to follow the sum of two Gaussian peaks with peak areas of 500 and 1000, means of -1.0 and 2.0, and a common standard deviation of 0.5. We can then draw the data from these two distributions:

data = vcat(
    rand(Normal(-1.0, 0.5), 500),
    rand(Normal( 2.0, 0.5), 1000)
)
1500-element Vector{Float64}:
 -0.9411166400318934
 -0.9589059755500976
 -1.1489688334533168
 -0.10117180993609187
 -1.4632431965209574
 -0.5565813545618588
 -0.09563101154101672
 -0.3219908437251632
 -1.1397945929041011
 -1.086203224625344
  ⋮
  1.5360832834208233
  1.2946613305718468
  2.3336353665508422
  1.8676080858396855
  2.3250772714720083
  2.100152512104538
  1.587457662817556
  1.8301847800172673
  2.4941695216110804
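The same data can also be generated with the standard library alone, since a draw from Normal(μ, σ) is just μ + σ·z with z standard normal. The following is a minimal sketch under that assumption; the seeded MersenneTwister and the names peak1, peak2 and data_sketch are illustrative additions, not part of the tutorial (which does not fix a seed).

```julia
using Random, Statistics

# Illustrative only: seeded RNG so the sketch is reproducible.
rng = MersenneTwister(1234)

# Normal(mu, sigma) draw == mu + sigma * (standard normal draw)
peak1 = -1.0 .+ 0.5 .* randn(rng, 500)
peak2 =  2.0 .+ 0.5 .* randn(rng, 1000)
data_sketch = vcat(peak1, peak2)

println(length(data_sketch))           # 1500
println(mean(peak1), " ", std(peak1))  # close to -1.0 and 0.5
```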

resulting in a vector of floating-point numbers:

typeof(data) == Vector{Float64}
true

Next, we'll create a histogram of that data; this histogram will serve as the input for the Bayesian fit:

hist = append!(Histogram(-2:0.1:4), data)
StatsBase.Histogram{Int64, 1, Tuple{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}
edges:
  -2.0:0.1:4.0
weights: [1, 8, 8, 20, 17, 32, 36, 42, 33, 38  …  8, 3, 0, 1, 0, 1, 0, 0, 0, 0]
closed: left
isdensity: false
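Here closed: left means every bin contains its lower edge but not its upper edge, so the 61 edges of -2:0.1:4 define 60 bins. A naive sketch of that counting rule follows; count_left_closed is a hypothetical helper written for illustration (it is not how StatsBase implements binning) and simply skips values outside the edge range.

```julia
# Hypothetical helper (not StatsBase): count values into left-closed bins
# [edges[i], edges[i+1]). Values outside [first(edges), last(edges)) are skipped.
function count_left_closed(values, edges)
    counts = zeros(Int, length(edges) - 1)
    for x in values
        for i in 1:length(edges)-1
            if edges[i] <= x < edges[i+1]
                counts[i] += 1
                break
            end
        end
    end
    return counts
end

edges = -2:0.1:4
# -2.0 and -1.95 land in bin 1, -1.85 in bin 2, 3.99 in the last bin;
# 4.0 (upper edge of the last bin) and 5.0 are outside and not counted.
counts = count_left_closed([-2.0, -1.95, -1.85, 3.99, 4.0, 5.0], edges)
println(length(counts))  # 60 bins for 61 edges
```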

Using the Julia "Plots" package

using Plots

we can plot the histogram:

plot(
    normalize(hist, mode=:density),
    st = :steps, label = "Data",
    title = "Data"
)
savefig("tutorial-data.pdf")

Data

Let's define our fit function, i.e. the function that we expect to describe the data histogram at each x-axis position x, depending on a given set p of model parameters:

function fit_function(p::NamedTuple{(:a, :mu, :sigma)}, x::Real)
    p.a[1] * pdf(Normal(p.mu[1], p.sigma), x) +
    p.a[2] * pdf(Normal(p.mu[2], p.sigma), x)
end

The fit parameters (model parameters) a (peak areas) and mu (peak means) are vectors. Parameter sigma (peak width) is a scalar, since we assume it is the same for both Gaussian peaks.
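Because each Gaussian pdf integrates to one, a[1] and a[2] are the expected total counts under each peak, and the integral of the fit function is a[1] + a[2]. A quick numerical check as a stdlib-only sketch; normal_pdf and fit_function_sketch are hypothetical hand-written stand-ins for the Distributions-based fit_function above.

```julia
# Hypothetical stand-in for pdf(Normal(mu, sigma), x), written out by hand.
normal_pdf(mu, sigma, x) = exp(-((x - mu) / sigma)^2 / 2) / (sigma * sqrt(2π))

# Hypothetical stand-in for the tutorial's fit_function.
function fit_function_sketch(p, x)
    p.a[1] * normal_pdf(p.mu[1], p.sigma, x) +
    p.a[2] * normal_pdf(p.mu[2], p.sigma, x)
end

p = (a = [500, 1000], mu = [-1.0, 2.0], sigma = 0.5)

# Midpoint-rule integral over a range wide enough to contain both peaks:
dx = 0.001
xs = (-10 + dx/2):dx:(10 - dx/2)
area = sum(fit_function_sketch(p, x) for x in xs) * dx
println(area)  # close to 500 + 1000 = 1500
```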

The true values for the model/fit parameters are the values we used to generate the data:

true_par_values = (a = [500, 1000], mu = [-1.0, 2.0], sigma = 0.5)

Let's visually compare the histogram and the fit function, using these true parameter values, to make sure everything is set up correctly:

plot(
    normalize(hist, mode=:density),
    st = :steps, label = "Data",
    title = "Data and True Statistical Model"
)
plot!(
    -4:0.01:4, x -> fit_function(true_par_values, x),
    label = "Truth"
)
savefig("tutorial-data-and-truth.pdf")

Data and True Statistical Model

Bayesian Fit

Now we'll perform a Bayesian fit of the generated histogram, using BAT, to infer the model parameters from the data histogram.

In addition to the Julia packages loaded above, we need BAT itself, as well as DensityInterface and IntervalSets:

using BAT, DensityInterface, IntervalSets

Likelihood Definition

First, we need to define the likelihood for our problem.

BAT expects likelihoods to implement the DensityInterface API. We can simply wrap a log-likelihood function with DensityInterface.logfuncdensity to make it compatible.

For performance reasons, functions should not access global variables directly. So we'll use an anonymous function inside of a let-statement to capture the value of the global variable hist in a local variable h (and to shorten function name fit_function to f, purely for convenience). DensityInterface.logfuncdensity then turns the log-likelihood function into a DensityInterface density object.

likelihood = let h = hist, f = fit_function
    # Histogram counts for each bin as an array:
    observed_counts = h.weights

    # Histogram binning:
    bin_edges = h.edges[1]
    bin_edges_left = bin_edges[1:end-1]
    bin_edges_right = bin_edges[2:end]
    bin_widths = bin_edges_right - bin_edges_left
    bin_centers = (bin_edges_right + bin_edges_left) / 2

    logfuncdensity(function (params)
        # Log-likelihood for a single bin:
        function bin_log_likelihood(i)
            # Simple mid-point rule integration of fit function `f` over bin:
            expected_counts = bin_widths[i] * f(params, bin_centers[i])
            # Avoid zero expected counts for numerical stability:
            logpdf(Poisson(expected_counts + eps(expected_counts)), observed_counts[i])
        end

        # Sum log-likelihood over bins:
        idxs = eachindex(observed_counts)
        ll_value = bin_log_likelihood(idxs[1])
        for i in idxs[2:end]
            ll_value += bin_log_likelihood(i)
        end

        return ll_value
    end)
end
LogFuncDensity(Main.var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(Main.fit_function)}(-1.95:0.1:3.95, StepRangeLen(0.1, 0.0, 60), [1, 8, 8, 20, 17, 32, 36, 42, 33, 38  …  8, 3, 0, 1, 0, 1, 0, 0, 0, 0], Main.fit_function))
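Each bin contributes the Poisson log-probability log P(k | λ) = k·log(λ) - λ - log(k!), where k is the observed count and λ the expected count from the midpoint rule. The sketch below writes this out by hand so it runs without Distributions.jl; logfactorial, poisson_logpdf and binned_loglik are hypothetical names, and passing the model as a function of x alone is a simplification of the tutorial's fit_function(params, x).

```julia
# Hypothetical helpers, written out by hand (not the Distributions.jl API).
logfactorial(k::Integer) = sum(log, 1:k; init = 0.0)   # log(k!), 0.0 for k == 0
poisson_logpdf(λ::Real, k::Integer) = k * log(λ) - λ - logfactorial(k)

# Hypothetical binned Poisson log-likelihood with midpoint-rule expected counts.
function binned_loglik(model, edges, observed)
    ll = 0.0
    for i in eachindex(observed)
        width = edges[i+1] - edges[i]
        center = (edges[i] + edges[i+1]) / 2
        expected = width * model(center)
        # eps(expected) keeps log(expected) finite when expected == 0
        ll += poisson_logpdf(expected + eps(expected), observed[i])
    end
    return ll
end

edges = [0.0, 1.0, 2.0, 3.0]
observed = [2, 0, 5]
flat_model(x) = 2.0    # constant density: 2 expected counts per unit x
zero_model(x) = 0.0    # degenerate model with zero expected counts
println(binned_loglik(flat_model, edges, observed))
println(binned_loglik(zero_model, edges, observed))  # finite thanks to eps
```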

BAT makes use of Julia's parallel programming facilities if possible, e.g. to run multiple Markov chains in parallel. Therefore, log-likelihood (and other) code must be thread-safe. Mark non-thread-safe code with @critical (provided by Julia package ParallelProcessingTools).

Support for automatic parallelization across multiple (local and remote) Julia processes is planned, but not implemented yet.

Note that Julia currently starts only a single thread by default. Set the environment variable JULIA_NUM_THREADS, or start Julia with the --threads command-line option, to specify the desired number of Julia threads.
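To illustrate what thread safety means in practice: code running on several threads may read shared data freely, but updates to shared mutable state must be synchronized. This sketch uses Base's ReentrantLock rather than @critical from ParallelProcessingTools, so it shows the general idea, not BAT's mechanism; locked_sum_of_squares is a hypothetical example. The result is the same whether Julia runs with one thread or many.

```julia
using Base.Threads

# Hypothetical example: accumulate into shared state from multiple threads.
function locked_sum_of_squares(xs)
    total = Ref(0.0)          # shared mutable state
    lk = ReentrantLock()
    @threads for i in eachindex(xs)
        v = xs[i]^2           # only reads shared input: no lock needed
        lock(lk) do
            total[] += v      # read-modify-write on shared state: guarded by the lock
        end
    end
    return total[]
end

println("threads: ", nthreads())
println(locked_sum_of_squares(collect(1.0:100.0)))
```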

We can evaluate likelihood, e.g. at the true parameter values:

logdensityof(likelihood, true_par_values)
-160.34429524301976

Prior Definition

Next, we need to choose a sensible prior for the fit:

prior = distprod(
    a = [Weibull(1.1, 5000), Weibull(1.1, 5000)],
    mu = [-2.0..0.0, 1.0..3.0],
    sigma = Weibull(1.2, 2)
)

BAT supports most Distributions.Distribution types, and combinations of them, as priors.
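Since distprod combines independent components, the prior log-density at a point is the sum of the components' log-densities. A stdlib-only sketch, assuming that an interval such as -2.0..0.0 acts as a uniform distribution over that interval and that Weibull(α, θ) follows the shape/scale parameterization of Distributions.jl; weibull_logpdf, uniform_logpdf and log_prior_sketch are hypothetical helpers.

```julia
# Hypothetical helper, assuming shape/scale parameterization:
# pdf(x) = (α/θ) (x/θ)^(α-1) exp(-(x/θ)^α) for x >= 0.
function weibull_logpdf(α, θ, x)
    x < 0 && return -Inf
    return log(α / θ) + (α - 1) * log(x / θ) - (x / θ)^α
end

# Hypothetical helper, assuming an interval lo..hi acts as Uniform(lo, hi).
uniform_logpdf(lo, hi, x) = lo <= x <= hi ? -log(hi - lo) : -Inf

# Independent components: the log-densities simply add up.
function log_prior_sketch(p)
    weibull_logpdf(1.1, 5000.0, p.a[1]) + weibull_logpdf(1.1, 5000.0, p.a[2]) +
    uniform_logpdf(-2.0, 0.0, p.mu[1]) + uniform_logpdf(1.0, 3.0, p.mu[2]) +
    weibull_logpdf(1.2, 2.0, p.sigma)
end

p_true = (a = [500, 1000], mu = [-1.0, 2.0], sigma = 0.5)
p_out  = (a = [500, 1000], mu = [0.5, 2.0], sigma = 0.5)  # mu[1] outside -2.0..0.0
println(log_prior_sketch(p_true))  # finite
println(log_prior_sketch(p_out))   # -Inf
```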

Bayesian Model Definition

Given the likelihood and prior definition, a BAT.PosteriorMeasure is simply defined via

posterior = PosteriorMeasure(likelihood, prior)
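Conceptually, the posterior density is proportional to likelihood times prior, so on the log scale it is the sum of both; the normalizing constant (the evidence) is not needed for sampling or mode finding. A one-dimensional toy sketch, not BAT's PosteriorMeasure: a single Poisson count k = 7 with a uniform prior on (0, 20], whose posterior mode is at λ = 7. All names here are hypothetical.

```julia
# Hypothetical helper: unnormalized log posterior = log-likelihood + log-prior.
# The likelihood is only evaluated where the prior is nonzero.
function log_posterior(loglik, logprior, θ)
    lp = logprior(θ)
    return isfinite(lp) ? loglik(θ) + lp : -Inf
end

k = 7
toy_loglik(λ) = k * log(λ) - λ                   # log Poisson(k | λ) up to a constant
toy_logprior(λ) = 0 < λ <= 20 ? -log(20) : -Inf  # uniform prior on (0, 20]

# Crude grid search for the posterior mode:
grid = 0.01:0.01:25.0
logpost_values = [log_posterior(toy_loglik, toy_logprior, λ) for λ in grid]
λ_mode = grid[argmax(logpost_values)]
println(λ_mode)  # close to 7, the maximum of λ^7 * exp(-λ)
```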

Parameter Space Exploration via MCMC

We can now use Markov chain Monte Carlo (MCMC) to explore the space of possible parameter values for the histogram fit.
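The core of the Metropolis-Hastings algorithm used below fits in a few lines: propose a random step, accept it with probability min(1, p(new)/p(old)), and otherwise repeat the current point. This stdlib-only sketch targets a standard normal density with a fixed proposal width; it has none of BAT's proposal tuning, burn-in, multi-chain handling or convergence testing, and metropolis_sketch is a hypothetical name.

```julia
using Random, Statistics

# Hypothetical minimal random-walk Metropolis sampler (not BAT's implementation).
function metropolis_sketch(logdensity, x0, nsteps; step = 1.0, rng = MersenneTwister(0))
    x = x0
    lp = logdensity(x)
    chain = Vector{Float64}(undef, nsteps)
    for i in 1:nsteps
        proposal = x + step * randn(rng)   # symmetric Gaussian proposal
        lp_new = logdensity(proposal)
        # Accept with probability min(1, exp(lp_new - lp)):
        if log(rand(rng)) < lp_new - lp
            x, lp = proposal, lp_new
        end
        chain[i] = x                       # a rejected step repeats the current point
    end
    return chain
end

chain = metropolis_sketch(x -> -x^2 / 2, 0.0, 100_000)
println(mean(chain), " ", std(chain))  # roughly 0 and 1
```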

To increase the verbosity level of BAT logging output, you may want to set the Julia logging level for BAT to debug via ENV["JULIA_DEBUG"] = "BAT".

Now we can generate a set of MCMC samples via bat_sample. We'll use 4 MCMC chains with 10^5 MC steps in each chain (after tuning/burn-in):

samples = bat_sample(posterior, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5, nchains = 4)).result
[ Info: Setting new default BAT context BATContext{Float64}(Random123.Philox4x{UInt64, 10}(0x935c36a55c9cd405, 0x1241442c3c8ad7eb, 0xcb97d95cd73aebb6, 0xb664ea9a1ac85db8, 0xb829a7c827b7ee6c, 0x94abb1747112445f, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0), HeterogeneousComputing.CPUnit(), BAT._NoADSelected())
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
[ Info: Selected 4 MCMC chain(s).
[ Info: Begin tuning of 4 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 4 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 5 finished, 4 chains, 0 tuned, 4 converged.
[ Info: MCMC Tuning cycle 6 finished, 4 chains, 0 tuned, 4 converged.
[ Info: MCMC Tuning cycle 7 finished, 4 chains, 0 tuned, 4 converged.
[ Info: MCMC Tuning cycle 8 finished, 4 chains, 1 tuned, 4 converged.
[ Info: MCMC Tuning cycle 9 finished, 4 chains, 2 tuned, 4 converged.
[ Info: MCMC Tuning cycle 10 finished, 4 chains, 4 tuned, 4 converged.
[ Info: MCMC tuning of 4 chains successful after 10 cycle(s).
[ Info: Running post-tuning stabilization steps for 4 MCMC chain(s).

Let's calculate some statistics on the posterior samples:

println("Truth: $true_par_values")
println("Mode: $(mode(samples))")
println("Mean: $(mean(samples))")
println("Stddev: $(std(samples))")
Truth: (a = [500, 1000], mu = [-1.0, 2.0], sigma = 0.5)
Mode: (a = [494.61869647378444, 999.8506929416998], mu = [-0.9635596122369079, 1.974893706915037], sigma = 0.49316547136847655)
Mean: (a = [496.5826891075374, 1000.4979655647167], mu = [-0.966971334324237, 1.9767947891508277], sigma = 0.4941963073240294)
Stddev: (a = [22.570165499538053, 31.762679315680273], mu = [0.02446515914846366, 0.015737958737598694], sigma = 0.009637938174520615)
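MCMC samples in BAT carry weights; with the default RepetitionWeighting, as we understand it, a weight counts how many consecutive steps the chain stayed at a point, so statistics are weighted averages. A minimal sketch of a weighted mean and standard deviation that treats integer weights as repetition counts; weighted_mean_std is a hypothetical helper, and whether BAT applies exactly this variance correction is an assumption.

```julia
using Statistics

# Hypothetical helper: integer weights treated as repetition (frequency) counts.
function weighted_mean_std(xs, ws)
    total = sum(ws)
    m = sum(w * x for (x, w) in zip(xs, ws)) / total
    # Frequency-weight variance, identical to var() on the expanded sample:
    v = sum(w * (x - m)^2 for (x, w) in zip(xs, ws)) / (total - 1)
    return m, sqrt(v)
end

xs = [0.9, 1.1, 1.0, 1.3]
ws = [3, 1, 2, 4]
m, s = weighted_mean_std(xs, ws)

# Cross-check against plain statistics of the expanded (repeated) sample:
expanded = reduce(vcat, [fill(x, w) for (x, w) in zip(xs, ws)])
println(m, " ", mean(expanded))
println(s, " ", std(expanded))
```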

Internally, BAT often needs to represent variates as flat real-valued vectors:

unshaped_samples, f_flatten = bat_transform(Vector, samples)
(result = DensitySampleVector(length = 109188, varshape = ValueShapes.ArrayShape{Float64, 1}((5,))), trafo = Base.Fix2{typeof(ValueShapes.unshaped), ValueShapes.NamedTupleShape{(:a, :mu, :sigma), Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}(ValueShapes.unshaped, NamedTupleShape((a = ValueShapes.ArrayShape{Real, 1}((2,)), mu = ValueShapes.ArrayShape{Real, 1}((2,)), sigma = ValueShapes.ScalarShape{Real}()))), optargs = (algorithm = BAT.UnshapeTransformation(), context = BATContext{Float64}(Random123.Philox4x{UInt64, 10}(0x8f0c648c79c75a7d, 0xd82945f5b32dff4d, 0xc0223866bde46b6b, 0xf310effee32cef95, 0xb829a7c827b7ee6c, 0x94abb1747112445f, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x8000020100000000, 0), HeterogeneousComputing.CPUnit(), BAT._NoADSelected())))
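For this model, the flat vector has five entries in the order of the NamedTuple fields: a[1], a[2], mu[1], mu[2], sigma. That ordering matches the diagonal of the covariance matrix printed below and is why mu[1] and sigma are parameters 3 and 5 in the plotting section. A hand-written sketch of the round trip; flatten_params and unflatten_params are hypothetical stand-ins for the ValueShapes machinery BAT actually uses.

```julia
# Hypothetical flatten/unflatten pair for this tutorial's parameter layout:
# (a = [a1, a2], mu = [mu1, mu2], sigma = s)  <->  [a1, a2, mu1, mu2, s]
flatten_params(p) = vcat(p.a, p.mu, p.sigma)
unflatten_params(v) = (a = v[1:2], mu = v[3:4], sigma = v[5])

p = (a = [500.0, 1000.0], mu = [-1.0, 2.0], sigma = 0.5)
v = flatten_params(p)
println(v)                         # [500.0, 1000.0, -1.0, 2.0, 0.5]
println(unflatten_params(v) == p)  # true
```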

The statistics above (mode, mean and std-dev) are presented in shaped form. However, statistics with matrix shape, such as the parameter covariance matrix, cannot be represented this way, so the covariance has to be accessed in unshaped form:

par_cov = cov(unshaped_samples)
println("Covariance: $par_cov")
Covariance: [509.41237067654663 -2.9766319844569193 -0.022037329313728236 0.0017539557217750492 0.009336362056852222; -2.9766319844569193 1008.8677973107552 -0.004566474904881762 -0.004141033234595744 0.0001302359425004363; -0.022037329313728236 -0.004566474904881762 0.0005985440121596674 9.917299930893005e-6 -2.9676094237243733e-5; 0.0017539557217750492 -0.004141033234595744 9.917299930893005e-6 0.0002476833452263561 8.569265618043097e-8; 0.009336362056852222 0.0001302359425004363 -2.9676094237243733e-5 8.569265618043097e-8 9.288985225588003e-5]
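Covariances are easier to read once normalized to correlation coefficients, ρ_ij = C_ij / (σ_i σ_j). The sketch below applies this to the covariance values printed above, copied by hand in the order a[1], a[2], mu[1], mu[2], sigma; in a live session the same two lines would work directly on par_cov. All off-diagonal correlations come out below 0.2 in magnitude, the largest being between mu[1] and sigma (about -0.13).

```julia
using LinearAlgebra

# Covariance values copied from the output above (order: a[1], a[2], mu[1], mu[2], sigma).
C = [ 509.41237067654663    -2.9766319844569193    -0.022037329313728236   0.0017539557217750492   0.009336362056852222;
     -2.9766319844569193    1008.8677973107552     -0.004566474904881762  -0.004141033234595744    0.0001302359425004363;
     -0.022037329313728236  -0.004566474904881762   0.0005985440121596674  9.917299930893005e-6   -2.9676094237243733e-5;
      0.0017539557217750492 -0.004141033234595744   9.917299930893005e-6   0.0002476833452263561   8.569265618043097e-8;
      0.009336362056852222   0.0001302359425004363 -2.9676094237243733e-5  8.569265618043097e-8    9.288985225588003e-5]

# Correlation matrix: divide each entry by the product of the standard deviations.
σ = sqrt.(diag(C))
par_cor = C ./ (σ * σ')
println(round.(par_cor; digits = 3))
```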

Use bat_report to generate an overview of the sampling result and parameter estimates (based on the marginal distributions):

bat_report(samples)

Sampling result

  • Total number of samples: 109188

  • Total weight of samples: 399996

  • Effective sample size: between 2507 and 11290

Marginals

Parameter  Mean       Std. dev.   Global mode  Marg. mode  Cred. interval          Histogram
a[1]       496.583    22.5702     494.619      490.0       472.77 .. 517.986       405 ▁▁▂▃▄▅▆▇█████▆▅▄▃▂▂▁ 595
a[2]       1000.5     31.7627     999.851      1010.0      966.863 .. 1030.79      884 ▁▂▂▃▅▅▆▇████▇▆▅▄▃▂▁▁ 1.14e+03
mu[1]      -0.966971  0.0244652   -0.96356     -0.97       -0.98998 .. -0.941155   -1.07 ▁▁▂▃▄▅▆▇████▇▆▅▄▃▂▁▁ -0.869
mu[2]      1.97679    0.015738    1.97489      1.975       1.96021 .. 1.99173      1.91 ▁▁▂▃▄▆▇█████▇▆▄▄▂▂▁ 2.04
sigma      0.494196   0.00963794  0.493165     0.4925      0.484168 .. 0.503357    0.452 ▁▂▂▄▅▆▇████▇▆▅▄▃▂▁ 0.54
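The report does not spell out here how its credible intervals are constructed, so the following only illustrates one common choice: a central interval from empirical quantiles. It is a sketch, not BAT's method; central_interval is a hypothetical helper, the 68.3% level is an assumption, and the toy draws are Gaussian with the mean and standard deviation reported for mu[1]. It also ignores sample weights, which real BAT samples carry.

```julia
using Random, Statistics

# Hypothetical helper: central interval containing a fraction `level` of the
# (unweighted) samples, taken from empirical quantiles.
function central_interval(xs, level = 0.683)
    α = (1 - level) / 2
    return quantile(xs, α), quantile(xs, 1 - α)
end

rng = MersenneTwister(7)
xs = -0.967 .+ 0.0245 .* randn(rng, 10_000)  # toy draws shaped like mu[1]'s marginal
lo, hi = central_interval(xs)
println(lo, " .. ", hi)
```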

Visualization of Results

BAT.jl comes with an extensive set of plotting recipes for "Plots.jl". We can plot the marginalized distribution for a single parameter (e.g. parameter 3, i.e. μ[1]):

plot(
    samples, :(mu[1]),
    mean = true, std = true, globalmode = true, marginalmode = true,
    nbins = 50, title = "Marginalized Distribution for mu[1]"
)
savefig("tutorial-single-par.pdf")

Marginalized Distribution for mu_1

or plot the marginalized distribution for a pair of parameters (e.g. parameters 3 and 5, i.e. μ[1] and σ), including information from the parameter stats:

plot(
    samples, (:(mu[1]), :sigma),
    mean = true, std = true, globalmode = true, marginalmode = true,
    nbins = 50, title = "Marginalized Distribution for mu[1] and sigma"
)
plot!(BAT.MCMCBasicStats(samples), (3, 5))
savefig("tutorial-param-pair.png")

Marginalized Distribution for mu_1 and sigma

We can also create an overview plot of the marginalized distribution for all pairs of parameters:

plot(
    samples,
    mean = false, std = false, globalmode = true, marginalmode = false,
    nbins = 50
)
savefig("tutorial-all-params.png")

Pairwise Correlation between Parameters

Integration with Tables.jl

DensitySampleVector supports the Tables.jl interface, so it is a table itself. We can also convert it to other table types, e.g. a TypedTables.Table:

using TypedTables

tbl = Table(samples)
Table with 5 columns and 109188 rows:
      v                       logd      weight  info                    aux
    ┌──────────────────────────────────────────────────────────────────────────
 1  │ (a = [481.558, 1065.1…  -181.562  10      MCMCSampleID(4, 13, 0…  nothing
 2  │ (a = [477.711, 1050.2…  -180.798  3       MCMCSampleID(4, 13, 1…  nothing
 3  │ (a = [503.682, 1019.8…  -179.147  2       MCMCSampleID(4, 13, 1…  nothing
 4  │ (a = [503.828, 1030.0…  -180.914  1       MCMCSampleID(4, 13, 1…  nothing
 5  │ (a = [488.682, 1009.4…  -180.678  3       MCMCSampleID(4, 13, 1…  nothing
 6  │ (a = [487.468, 1029.7…  -180.556  2       MCMCSampleID(4, 13, 1…  nothing
 7  │ (a = [492.917, 1016.5…  -179.965  1       MCMCSampleID(4, 13, 2…  nothing
 8  │ (a = [492.049, 1011.7…  -181.037  6       MCMCSampleID(4, 13, 2…  nothing
 9  │ (a = [483.661, 1012.7…  -181.862  6       MCMCSampleID(4, 13, 2…  nothing
 10 │ (a = [494.953, 1016.6…  -181.153  1       MCMCSampleID(4, 13, 3…  nothing
 11 │ (a = [500.204, 1035.0…  -179.406  9       MCMCSampleID(4, 13, 3…  nothing
 12 │ (a = [503.086, 1043.8…  -181.26   3       MCMCSampleID(4, 13, 4…  nothing
 13 │ (a = [498.634, 1038.5…  -179.953  6       MCMCSampleID(4, 13, 4…  nothing
 14 │ (a = [496.343, 1039.9…  -179.597  7       MCMCSampleID(4, 13, 5…  nothing
 15 │ (a = [491.895, 1042.7…  -179.953  4       MCMCSampleID(4, 13, 6…  nothing
 16 │ (a = [498.185, 1045.5…  -179.427  4       MCMCSampleID(4, 13, 6…  nothing
 17 │ (a = [493.392, 1037.5…  -178.841  14      MCMCSampleID(4, 13, 6…  nothing
 ⋮  │           ⋮                ⋮        ⋮               ⋮                ⋮

or a DataFrames.DataFrame, etc.

Comparison of Truth and Best Fit

As a final step, we retrieve the parameter values at the mode, representing the best-fit parameters:

samples_mode = mode(samples)
(a = [494.61869647378444, 999.8506929416998], mu = [-0.9635596122369079, 1.974893706915037], sigma = 0.49316547136847655)

Like the samples themselves, the result can be viewed in both shaped and unshaped form. samples_mode is a NamedTuple, a representation that preserves the shape information:

samples_mode isa NamedTuple
true

samples_mode is only an estimate of the mode of the posterior distribution. It can be further refined using bat_findmode:

using Optim

findmode_result = bat_findmode(
    posterior,
    OptimAlg(optalg = Optim.NelderMead(), init = ExplicitInit([samples_mode]))
)

fit_par_values = findmode_result.result
(a = [495.521898032129, 999.2969156044651], mu = [-0.9663909592427264, 1.9766902173125014], sigma = 0.4932665099577212)
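bat_findmode with Optim.NelderMead refines the MCMC mode estimate by local optimization. The same idea, polishing a rough starting point with a deterministic optimizer, can be sketched in one dimension with golden-section search, a swapped-in technique and not Nelder-Mead; golden_section_max and the toy log-density are hypothetical, and the method assumes the function is unimodal on the bracket.

```julia
# Hypothetical helper: maximize a unimodal function f on [lo, hi]
# by golden-section search (a 1D stand-in for Nelder-Mead).
function golden_section_max(f, lo, hi; tol = 1e-8)
    invphi = (sqrt(5) - 1) / 2        # 1/φ ≈ 0.618
    while hi - lo > tol
        c = hi - invphi * (hi - lo)
        d = lo + invphi * (hi - lo)
        if f(c) > f(d)
            hi = d                    # maximum lies in [lo, d]
        else
            lo = c                    # maximum lies in [c, hi]
        end
    end
    return (lo + hi) / 2
end

toy_logd(x) = -(x - 1.5)^2 / 2        # toy log-density with its mode at 1.5
rough_mode = 1.47                     # e.g. the best point among a set of samples
refined = golden_section_max(toy_logd, rough_mode - 0.5, rough_mode + 0.5)
println(refined)                      # ≈ 1.5
```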

Let's plot the data together with the fit function, evaluated both for the true parameters and for the MCMC samples:

plot(-4:0.01:4, fit_function, samples)

plot!(
    normalize(hist, mode=:density),
    color=1, linewidth=2, fillalpha=0.0,
    st = :steps, fill=false, label = "Data",
    title = "Data, True Model and Best Fit"
)

plot!(-4:0.01:4, x -> fit_function(true_par_values, x), color=4, label = "Truth")
savefig("tutorial-data-truth-bestfit.pdf")

Data, True Model and Best Fit

Fine-grained control

BAT provides fine-grained control over the MCMC algorithm options, the MCMC chain initialization, the tuning/burn-in strategy and convergence testing. All option values used in the following are the default values; any or all of them may be omitted.

We'll sample using the Metropolis-Hastings MCMC algorithm:

mcmcalgo = MetropolisHastings(
    weighting = RepetitionWeighting(),
    tuning = AdaptiveMHTuning()
)
MetropolisHastings{Distributions.TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}
  proposal: Distributions.TDist{Float64}
  weighting: RepetitionWeighting{Int64} RepetitionWeighting{Int64}()
  tuning: AdaptiveMHTuning

BAT requires a counter-based random number generator (RNG), since it partitions the RNG space over the MCMC chains. This way, a single RNG seed is sufficient for all chains and results are reproducible even under parallel execution. By default, BAT uses a Philox4x RNG initialized with a random seed drawn from the system entropy pool:

using Random123
rng = Philox4x()
context = BATContext(rng = rng)
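Philox4x is a proper counter-based generator. The toy sketch below only illustrates the principle with a SplitMix64-style mixing function (a swapped-in, non-cryptographic stand-in, not Philox): each random number is a pure function of (seed, counter), so giving each chain its own range of counters makes every chain's stream reproducible regardless of execution order. mix64, counter_rand and chain_uniform are hypothetical names, and the 2^32-counters-per-chain split is an arbitrary choice.

```julia
# Toy counter-based RNG (NOT Philox): SplitMix64-style finalizer.
# UInt64 arithmetic wraps around in Julia, as the mixing function requires.
function mix64(z::UInt64)
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

const GOLDEN_GAMMA = 0x9e3779b97f4a7c15

# Output depends only on (seed, counter): no hidden mutable state.
counter_rand(seed::UInt64, counter::UInt64) = mix64(seed + counter * GOLDEN_GAMMA)

# Hypothetical partitioning: chain c owns counters c * 2^32 .. c * 2^32 + 2^32 - 1.
function chain_uniform(seed::UInt64, chain::Integer, step::Integer)
    counter = (UInt64(chain) << 32) + UInt64(step)
    return (counter_rand(seed, counter) >> 11) * 2.0^-53   # uniform in [0, 1)
end

seed = 0x0000000000001234
forward  = [chain_uniform(seed, 2, i) for i in 0:9]
backward = reverse([chain_uniform(seed, 2, i) for i in 9:-1:0])
println(forward == backward)  # true: evaluation order does not matter
```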

By default, MCMCSampling uses the following options.

For Markov chain initialization:

init = MCMCChainPoolInit()
MCMCChainPoolInit
  init_tries_per_chain: IntervalSets.ClosedInterval{Int64}
  nsteps_init: Int64 1000
  initval_alg: InitFromTarget InitFromTarget()

For the MCMC burn-in procedure:

burnin = MCMCMultiCycleBurnin()
MCMCMultiCycleBurnin
  nsteps_per_cycle: Int64 10000
  max_ncycles: Int64 30
  nsteps_final: Int64 1000

For convergence testing:

convergence = BrooksGelmanConvergence()
BrooksGelmanConvergence
  threshold: Float64 1.1
  corrected: Bool false

To generate MCMC samples with explicit control over all options, use something like

samples = bat_sample(
    posterior,
    MCMCSampling(
        mcalg = mcmcalgo,
        nchains = 4,
        nsteps = 10^5,
        init = init,
        burnin = burnin,
        convergence = convergence,
        strict = true,
        store_burnin = false,
        nonzero_weights = true,
        callback = (x...) -> nothing
    ),
    context
).result
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
[ Info: Selected 4 MCMC chain(s).
[ Info: Begin tuning of 4 MCMC chain(s).
[ Info: MCMC Tuning cycle 1 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 2 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 3 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 4 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 5 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 6 finished, 4 chains, 0 tuned, 0 converged.
[ Info: MCMC Tuning cycle 7 finished, 4 chains, 1 tuned, 4 converged.
[ Info: MCMC Tuning cycle 8 finished, 4 chains, 1 tuned, 4 converged.
[ Info: MCMC Tuning cycle 9 finished, 4 chains, 1 tuned, 4 converged.
[ Info: MCMC Tuning cycle 10 finished, 4 chains, 2 tuned, 4 converged.
[ Info: MCMC Tuning cycle 11 finished, 4 chains, 2 tuned, 4 converged.
[ Info: MCMC Tuning cycle 12 finished, 4 chains, 4 tuned, 4 converged.
[ Info: MCMC tuning of 4 chains successful after 12 cycle(s).
[ Info: Running post-tuning stabilization steps for 4 MCMC chain(s).

Saving result data to files

The package FileIO.jl (in conjunction with JLD2.jl) offers a convenient way to store results like posterior samples to file:

using FileIO
import JLD2
FileIO.save("results.jld2", Dict("samples" => samples))

JLD2 persists the full information (including value shapes), so you can reload exactly the same data into memory in a new Julia session via

using FileIO
import JLD2
samples = FileIO.load("results.jld2", "samples")

provided you use compatible versions of BAT and its dependencies. Note that JLD2 is not a long-term stable file format. Also note that this functionality is provided by FileIO.jl and JLD2.jl and is not part of the BAT API itself.

BAT.jl itself can write samples to standard HDF5 files in a form suitable for long-term storage (via HDF5.jl):

import HDF5
bat_write("results.h5", samples)

The resulting files have an intuitive HDF5 layout and can be read with the standard HDF5 libraries, so they are easily accessible from other programming languages as well. Not all value shape information can be preserved, though. To read BAT.jl HDF5 sample data, use

using BAT
import HDF5
samples = bat_read("results.h5").result

BAT.jl's HDF5 file format may evolve over time, but future versions of BAT.jl will be able to read HDF5 sample data written by this version of BAT.jl.


This page was generated using Literate.jl.