"""Radial velocity fitting using MCMC and MAP estimation.
This module provides the main Fitter class for fitting radial velocity data
to planetary models using various parameterisations.
"""
# fit.py
import logging
import multiprocessing as mp
import os
import warnings
from typing import Callable, Dict, Optional
# Many builds of NumPy are linked against OpenBLAS or MKL, which can use multiple threads
# This can cause problems with multiprocessing (that we use to speed up emcee)
# So we set these to only use one thread
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
import corner
import emcee
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from matplotlib.ticker import AutoLocator, AutoMinorLocator, MultipleLocator
from scipy.optimize import minimize
from tinygp import GaussianProcess, kernels
from tqdm import tqdm
import ravest.model
from ravest.gp import GPKernel
from ravest.model import _njit_kepler_rv
from ravest.param import Parameter, Parameterisation, param_key_to_latex
# Enable 64-bit precision for better numerical accuracy
jax.config.update("jax_enable_x64", True)
logging.basicConfig(level=logging.INFO)
[docs]
class Fitter:
"""Main class for fitting radial velocity data to planetary models.
Supports MCMC sampling, MAP estimation, and various parameterisations.
Handles multiple planets, trends, and jitter parameters.
"""
def __init__(self, planet_letters: list[str], parameterisation: Parameterisation) -> None:
    """Create a Fitter for a set of planets and an orbital parameterisation.

    Parameters
    ----------
    planet_letters : list[str]
        Single-character identifiers, one per planet (e.g., ['b', 'c', 'd']).
        They are appended to parameter names to tell the planets apart.
    parameterisation : Parameterisation
        Orbital parameterisation used for the fit. It determines which
        orbital elements appear as parameters of each planet.
    """
    # Parameter and prior storage starts out empty; it is filled in later
    # through the ``params`` and ``priors`` setters.
    self._priors: Dict[str, Callable[[float], float]] = {}
    self._params: Dict[str, Parameter] = {}

    self.parameterisation = parameterisation
    self.planet_letters = planet_letters

    # Call the numba Kepler routine once on throwaway inputs so that its JIT
    # compilation happens here rather than during MCMC.
    warmup_anomalies = np.linspace(0, 2 * np.pi, 10)
    _njit_kepler_rv(warmup_anomalies, 0.3, 10.0, 0.5)
[docs]
def add_data(
self,
time: np.ndarray,
vel: np.ndarray,
velerr: np.ndarray,
instrument: np.ndarray,
t0: float,
) -> None:
"""Add the data to the Fitter object.
Parameters
----------
time : array-like
Time of each observation [days]
vel : array-like
Radial velocity at each time [m/s]
velerr : array-like
Uncertainty on the radial velocity at each time [m/s]
instrument : array-like
Instrument name for each observation (e.g., "HARPS", "HIRES")
t0 : float
Reference time for the trend [days].
Recommended to set this as mean or median of input `time` array.
"""
if not (len(time) == len(vel) == len(velerr) == len(instrument)):
raise ValueError(
"Time, velocity, uncertainty, and instrument arrays must be the same length."
)
self.time = np.ascontiguousarray(time)
self.vel = np.ascontiguousarray(vel)
self.velerr = np.ascontiguousarray(velerr)
self.instrument = np.asarray(instrument)
self.unique_instruments = np.unique(self.instrument)
self.t0 = t0
@property
def params(self) -> Dict[str, Parameter]:
"""Parameters dictionary. Set via: fitter.params = param_dict."""
return self._params
@params.setter
def params(self, new_params: Dict[str, Parameter]) -> None:
"""Set parameters with a dict, checking all required params are present.
You can update all or some of the parameters at once, example:
>>> fitter.params = {"g": Parameter(1.0, "m/s"), "gd": Parameter(0.1, "m/s/d")} # only update trend parameters
>>> fitter.params = {"P_c": Parameter(5.0, "d"), "K_c": Parameter(3.5, "m/s")} # only update some of planet C parameters
Parameters
----------
new_params : dict
Dictionary of new parameter values to set.
The keys of this dictionary should match the parameter names expected
by the Fitter object: all required parameters for the
chosen parameterisation, with planet letters (not required for
Trend or jitter parameters.)
Raises
------
ValueError
If any of the required parameters are missing or invalid.
"""
# Update the current _params dict with the new entries
merged_params = dict(self._params)
merged_params.update(new_params)
# Validate the complete parameter set
self._validate_complete_params(merged_params)
# If validation passes, update the actual params
self._params.update(new_params)
# Update ndim based on new free parameters
self.ndim = len(self.free_params_values)
if self.ndim == 0:
warnings.warn(
"All parameters are fixed. MCMC methods (find_map_estimate, "
"generate_initial_walker_positions_*, run_mcmc) require at least one "
"free parameter (fixed=False).",
UserWarning,
stacklevel=2
)
@property
def priors(self) -> dict:
"""Priors dictionary. Set via: fitter.priors = prior_dict."""
return self._priors
@priors.setter
def priors(self, new_priors: dict[str, Callable[[float], float]]) -> None:
"""Set prior functions using a dict, checking all required priors are present.
Priors must be provided for all free parameters. You can set all priors
at once or update individual priors.
Parameters
----------
new_priors : dict
Dictionary of prior functions to set. Keys should be parameter names
that match free parameters, values should be callable prior functions.
Examples
--------
>>> from ravest.prior import Uniform
>>> fitter.priors = {"K_b": Uniform(0, 100), "P_b": Uniform(1, 30)}
Raises
------
ValueError
If any required priors are missing, unexpected priors are provided,
or initial parameter values are outside prior bounds.
"""
self._set_priors_with_validation(new_priors)
[docs]
def _validate_complete_params(self, params: Dict[str, Parameter]) -> None:
"""Validate that params dict has required parameters, astrophysically valid values."""
# Require add_data() to have been called first (need unique_instruments)
if not hasattr(self, "unique_instruments"):
raise RuntimeError(
"add_data() must be called before setting params "
"(need instrument list for per-instrument parameters)"
)
# Build complete set of expected parameters
expected_params = set()
# Add planetary parameters
for planet_letter in self.planet_letters:
for par_name in self.parameterisation.pars:
expected_params.add(f"{par_name}_{planet_letter}")
# Add trend parameters (system-wide, no gamma offset here)
expected_params.update(["gd", "gdd"])
# Add per-instrument gamma offset and jitter parameters
for inst in self.unique_instruments:
expected_params.add(f"g_{inst}")
expected_params.add(f"jit_{inst}")
# Convert to sets for easy comparison
provided_params = set(params.keys())
# Check for unexpected parameters
unexpected_params = provided_params - expected_params
if unexpected_params:
# Give a specific hint if user is passing legacy single-instrument parameters
legacy_params = unexpected_params & {"g", "jit"}
if legacy_params:
raise ValueError(
f"Unexpected parameters: {unexpected_params}. "
f"Single-instrument 'g' and 'jit' parameters are no longer supported. "
f"Use per-instrument names instead, e.g. "
f"{[f'g_{inst}' for inst in self.unique_instruments]} and "
f"{[f'jit_{inst}' for inst in self.unique_instruments]}, "
f"matching the instrument names passed to add_data()."
)
raise ValueError(
f"Unexpected parameters: {unexpected_params}. "
f"Expected {len(expected_params)} parameters, got {len(provided_params)}"
)
# Check for missing parameters
missing_params = expected_params - provided_params
if missing_params:
raise ValueError(
f"Missing required parameters: {missing_params}. "
f"Expected {len(expected_params)} parameters, got {len(provided_params)}"
)
# Validate astrophysical validity of all parameters
params_values = {name: param.value for name, param in params.items()}
self._validate_astrophysical_validity(params_values)
# Validate parameter coupling constraints
# i.e. if two parameters both need to be fixed or free together
self._validate_parameter_coupling(params)
[docs]
def _validate_astrophysical_validity(self, params_values: Dict[str, float]) -> None:
"""Validate that all parameter values are astrophysically valid."""
# First, check that ALL parameters are finite (not NaN or infinite)
invalid_params = { name: value for name, value in params_values.items() if not np.isfinite(value) }
if invalid_params:
raise ValueError( "Invalid parameters detected: " + ", ".join(f"{k}={v}" for k, v in invalid_params.items()) )
# Validate planetary parameters for each planet
for planet_letter in self.planet_letters:
planet_params = {}
for par_name in self.parameterisation.pars:
key = f"{par_name}_{planet_letter}"
planet_params[par_name] = params_values[key]
# Validate this planet's parameters in current parameterisation
self.parameterisation.validate_planetary_params(planet_params)
# Validate trend parameters are finite real numbers (already checked above, but kept for clarity)
for trend_param in ["gd", "gdd"]:
trend_value = params_values[trend_param]
if not np.isfinite(trend_value):
raise ValueError(f"Invalid trend parameter {trend_param}: {trend_value} is not a finite real number")
# Validate per-instrument parameters
for inst in self.unique_instruments:
# Gamma offset must be finite
g_key = f"g_{inst}"
if not np.isfinite(params_values[g_key]):
raise ValueError(f"Invalid gamma offset {g_key}: {params_values[g_key]} is not finite")
# Jitter must be >= 0
jit_key = f"jit_{inst}"
if params_values[jit_key] < 0:
raise ValueError(f"Invalid jitter {jit_key}: {params_values[jit_key]} < 0")
[docs]
def _validate_parameter_coupling(self, params: Dict[str, Parameter]) -> None:
"""Validate parameter coupling constraints (e.g., secosw/sesinw must both be free or both fixed)."""
for planet_letter in self.planet_letters:
# Check secosw/sesinw coupling
secosw_key = f"secosw_{planet_letter}"
sesinw_key = f"sesinw_{planet_letter}"
if secosw_key in params and sesinw_key in params:
secosw_fixed = params[secosw_key].fixed
sesinw_fixed = params[sesinw_key].fixed
if secosw_fixed != sesinw_fixed:
raise ValueError(f"Parameters {secosw_key} and {sesinw_key} must both be fixed or both be free")
# Check ecosw/esinw coupling
ecosw_key = f"ecosw_{planet_letter}"
esinw_key = f"esinw_{planet_letter}"
if ecosw_key in params and esinw_key in params:
ecosw_fixed = params[ecosw_key].fixed
esinw_fixed = params[esinw_key].fixed
if ecosw_fixed != esinw_fixed:
raise ValueError(f"Parameters {ecosw_key} and {esinw_key} must both be fixed or both be free")
[docs]
def _set_priors_with_validation(self, new_priors: dict[str, Callable[[float], float]]) -> None:
"""Set priors with validation. Supports partial updates. Can be current or default parameterisation."""
# Create merged priors dict (in case user is only updating some priors, not all)
merged_priors_dict = dict(self._priors) # get existing priors
merged_priors_dict.update(new_priors) # overwrite with newer functions, if supplied
provided_prior_param_names = set(merged_priors_dict.keys())
# There are two possibilities for priors:
# 1. The prior has been given for the parameter, in the current parameterisation
# (this can also include if the user is fitting in the default parameterisation)
# 2. The prior has been given for the Default parameterisation's equivalent parameter instead
# (e.g. e & w instead of secosw & sesinw, or Tp instead of Tc)
# If not, then prior isn't given for either the Current or Default parameterisation, raise an Exception
validated_priors = {}
missing_priors = []
conflicts = []
# in the current parameterisation, which (free) parameters do we expect priors for?
current_parameterisation_free_param_names = set(self.free_params_names)
for free_param_name in current_parameterisation_free_param_names:
if free_param_name in provided_prior_param_names:
# Prior was provided for the param in the current parameterisation
validated_priors[free_param_name] = merged_priors_dict[free_param_name]
# Check if user ALSO provided equivalent default priors (conflict!)
default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)
if default_parameterisation_equivalent_free_param_names:
for equiv_param in default_parameterisation_equivalent_free_param_names:
if equiv_param in provided_prior_param_names:
conflicts.append((free_param_name, equiv_param))
else:
# We haven't been provided the prior for the free parameter in the current parameterisation
# So let's check if we were given the prior for the equivalent parameter in the default parameterisation instead
default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)
# remember that one parameter in current parameterisation (e.g. secosw) might map to more than one equivalent in default parameterisation (e.g. both e & w)
if default_parameterisation_equivalent_free_param_names and all(eq in provided_prior_param_names for eq in default_parameterisation_equivalent_free_param_names):
# Found all required default equivalents
for equiv in default_parameterisation_equivalent_free_param_names:
validated_priors[equiv] = merged_priors_dict[equiv]
else:
# Missing prior for a free parameter in both the current parameterisation, and its equivalent in the default parameterisation
if default_parameterisation_equivalent_free_param_names:
missing_priors.append(f"{free_param_name} (or equivalent {default_parameterisation_equivalent_free_param_names})")
else:
missing_priors.append(free_param_name)
# Check for conflicts after processing all parameters
if conflicts:
conflict_strs = [f"{current} vs {default}" for current, default in conflicts]
raise ValueError(f"Conflicting priors provided for both current and default parameterisations: {', '.join(conflict_strs)}. Please provide priors for either the current parameterisation OR the equivalent default parameterisation, but not both.")
if missing_priors:
raise ValueError(f"Missing priors for parameters: {missing_priors}")
# Check for unexpected priors - only allow priors that were validated above
expected_prior_param_names = set(validated_priors.keys())
unexpected_prior_param_names = provided_prior_param_names - expected_prior_param_names
if unexpected_prior_param_names:
raise ValueError(
f"Unexpected priors supplied for parameters: {unexpected_prior_param_names}. "
f"Priors expected only for parameters: {expected_prior_param_names}"
)
# Check parameter values work with priors
self._check_params_values_against_priors(validated_priors, current_parameterisation_free_param_names)
# Update the priors with the new values
self._priors.update(new_priors)
[docs]
def _get_default_parameterisation_equivalent_free_param_name(self, free_param: str) -> Optional[list[str]]:
"""Get the names of the default parameterisation equivalent parameter(s), for a single free parameter from the current parameterisation.
Note this can be more than one: e.g. if you have secosw, this affects both e & w in the default parameterisation
Whereas Tc just maps to Tp alone.
Returns
-------
list[str] | None
- list[str]: equivalent parameter(s) in the default parameterisation
- None: no mapping needed / no alternative priors to look for
Raises
------
ValueError
If `free_param` is not a recognised planet, instrument, or trend parameter.
"""
# No underscore (expected to be a system trend parameter)
if '_' not in free_param:
if free_param in ['gd', 'gdd']:
# These are the same in all parameterisations
return None
else:
raise ValueError(f"Unknown free parameter: {free_param}")
# Contains underscore: Planetary or instrument parameters (with underscore before either planet letter or instrument name)
# e.g. P_b, or Tc_c, or jit_HARPS
else:
base_param, suffix = free_param.split('_', 1) # split only on first underscore (some instrument names may have underscores too)
# Planetary parameters: suffix is a planet letter
if suffix in self.planet_letters:
planet_letter = suffix
if base_param in ['secosw', 'sesinw']:
# Both secosw and sesinw map to e,w equivalents
partner_param = 'sesinw' if base_param == 'secosw' else 'secosw'
partner_key = f"{partner_param}_{planet_letter}"
if partner_key in self.free_params_names:
return [f"e_{planet_letter}", f"w_{planet_letter}"]
elif base_param in ['ecosw', 'esinw']:
# Both ecosw and esinw map to e,w equivalents
partner_param = 'esinw' if base_param == 'ecosw' else 'ecosw'
partner_key = f"{partner_param}_{planet_letter}"
if partner_key in self.free_params_names:
return [f"e_{planet_letter}", f"w_{planet_letter}"]
elif base_param == 'Tc':
# Tc can use Tp equivalent
return [f"Tp_{planet_letter}"]
elif base_param in ['P', 'K', 'e', 'w', 'Tp']:
# These are default parameterisation parameters anyway
return None
else:
# Suffix is a valid planet letter, but base parameter is unrecognised, so raise an error
raise ValueError(f"Free parameter {free_param} has known planet letter {planet_letter} but unrecognised base parameter {base_param}.")
# Instrument parameters: suffix is an instrument name
elif suffix in self.unique_instruments:
# The only instrument parameters are g and jit
if base_param in ['g', 'jit']:
# Per-instrument parameter (e.g., g_HARPS, jit_HIRES)
# These are the same in all parameterisations
return None
else:
raise ValueError(f"Free parameter {free_param} has known instrument name {suffix} but unrecognised base parameter {base_param} (expected 'g' or 'jit' only)")
# Unknown: Suffix is present, but not a planet letter or instrument, so raise an error
else:
raise ValueError(f"Free parameter {free_param} has unrecognised suffix {suffix}, expected one of planet letters {self.planet_letters} or instrument names {self.unique_instruments}.")
[docs]
def _check_params_values_against_priors(self, validated_priors: dict[str, Callable[[float], float]], current_free_param_names: list[str]) -> None:
"""Check parameter values against priors (including if Prior is for the Default parameterisation equivalent parameter)."""
for prior_param_name, prior_function in validated_priors.items():
if prior_param_name in current_free_param_names:
# This prior is in current parameterisation - check directly
param_value = self.params[prior_param_name].value
log_prior_probability = prior_function(param_value)
if not np.isfinite(log_prior_probability):
raise ValueError(f"Initial value {param_value} of parameter {prior_param_name} is invalid for prior {prior_function}.")
else:
# This prior is in default parameterisation - need to convert parameter value
# Get the current parameter value and convert to default
default_param_value = self._convert_single_param_to_default(prior_param_name)
log_prior_probability = prior_function(default_param_value)
if not np.isfinite(log_prior_probability):
raise ValueError(f"Initial value {default_param_value} of parameter {prior_param_name} (in default parameterisation) is invalid for prior {prior_function}.")
[docs]
def _convert_single_param_to_default(self, default_param_name: str) -> float:
"""Convert a single parameter from current to default parameterisation."""
# Extract planet letter if this is a planetary parameter
if '_' in default_param_name:
base_param, planet_letter = default_param_name.rsplit('_', 1)
if planet_letter in self.planet_letters:
# Get all current parameters for this planet (we need all five parameters to do a conversion)
planet_params_dict = {}
for par_name in self.parameterisation.pars:
param_key = f"{par_name}_{planet_letter}"
planet_params_dict[par_name] = self.params[param_key].value
# Convert all the planetary parameters to the default parameterisation
default_planet_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params_dict)
# Return just the requested parameter in the default parameterisation
return default_planet_params[base_param]
# For non-planetary parameters (g, gd, gdd, jit), they're the same in all parameterisations
if default_param_name in self.params:
return self.params[default_param_name].value
raise ValueError(f"Cannot convert parameter {default_param_name} to default parameterisation")
@property
def free_params_dict(self) -> Dict[str, Parameter]:
"""Free parameters as dict."""
free_pars = {}
for par in self.params:
if self.params[par].fixed is False:
free_pars[par] = self.params[par]
return free_pars
@property
def free_params_values(self) -> list[float]:
"""Values of free parameters as list."""
return [param.value for param in self.free_params_dict.values()]
@property
def free_params_names(self) -> list[str]:
"""Names of free parameters as list."""
return list(self.free_params_dict.keys())
@property
def fixed_params_dict(self) -> Dict[str, Parameter]:
"""Fixed parameters as dict, mapping names to Parameter objects."""
fixed_pars = {}
for par in self.params:
if self.params[par].fixed is True:
fixed_pars[par] = self.params[par]
return fixed_pars
@property
def fixed_params_values(self) -> list[float]:
"""Values of fixed parameters, as list."""
return [param.value for param in self.fixed_params_dict.values()]
@property
def fixed_params_names(self) -> list[str]:
"""Names of fixed parameters, as list."""
return list(self.fixed_params_dict.keys())
@property
def fixed_params_values_dict(self) -> Dict[str, float]:
"""Fixed parameters as dict mapping names to just the values."""
return dict(zip(self.fixed_params_names, self.fixed_params_values))
[docs]
def find_map_estimate(self, method: str = "Powell") -> scipy.optimize.OptimizeResult:
"""Find Maximum A Posteriori (MAP) estimate of parameters.
Parameters
----------
method : str, optional
Optimization method to use (default: "Powell")
Returns
-------
scipy.optimize.OptimizeResult
The optimization result containing the MAP estimate
Raises
------
Warning
If MAP optimization fails to converge
"""
# Initialize log-posterior object
lp = LogPosterior(
self.planet_letters,
self.parameterisation,
self.priors,
self.fixed_params_values_dict,
self.free_params_names,
self.time,
self.vel,
self.velerr,
self.instrument,
self.unique_instruments,
self.t0,
)
initial_guess = self.free_params_values
if len(initial_guess) == 0:
raise ValueError(
"Cannot run MAP optimisation: no free parameters to optimise. "
"At least one parameter must be set as free (fixed=False) before calling find_map_estimate()."
)
# Perform MAP optimization
def negative_log_posterior(*args: float) -> float:
return lp._negative_log_probability_for_MAP(*args)
map_results = minimize(negative_log_posterior, initial_guess, method=method)
if map_results.success is False:
print(map_results)
warnings.warn("MAP did not succeed. Check the initial values of the parameters, and the prior functions.")
# Print results as dictionary (to show param names too)
map_results_dict = dict(zip(self.free_params_names, map_results.x))
print("MAP parameter results:", map_results_dict)
# Return the scipy OptimizeResult object so that user can inspect fully if needed
return map_results
[docs]
def generate_initial_walker_positions_random(self, nwalkers: int, verbose: bool = False, max_attempts: int = 1000) -> np.ndarray:
"""Generate random initial walker positions that satisfy priors and are astrophysically valid.
Creates random starting positions for MCMC walkers by sampling from
appropriate distributions based on each parameter's prior type. Ensures
that parameter combinations are astrophysically valid (e.g., eccentricity < 1).
Parameters
----------
nwalkers : int
Number of MCMC walkers to generate positions for
verbose : bool, default False
If True, print walker positions during generation
max_attempts : int, default 1000
Maximum attempts to generate a valid walker position
Returns
-------
np.ndarray
Array of shape (nwalkers, ndim) where ndim is the number of free parameters.
Each row represents the starting position for one walker in the order of
free_params_names.
Raises
------
ValueError
If a prior type is not supported for walker generation or if unable
to generate valid positions after max_attempts
Examples
--------
>>> # Generate positions for 40 walkers
>>> nwalkers = 10 * len(fitter.free_params_names)
>>> initial_positions = fitter.generate_initial_walker_positions_random(nwalkers)
>>> fitter.run_mcmc(initial_positions, nwalkers, max_steps=2000)
"""
if len(self.free_params_values) == 0:
raise ValueError(
"Cannot generate walker positions: no free parameters to sample. "
"At least one parameter must be set as free (fixed=False)."
)
if verbose:
print("Free parameters:", self.free_params_names)
mcmc_init = []
for walker_idx in range(nwalkers):
attempts = 0
while attempts < max_attempts:
walker_position = []
for param_name in self.free_params_names:
# Check if we have a direct prior for this parameter
# (because user may be fitting in a transformed parameterisation, but gave priors in the default parameterisation instead)
if param_name in self.priors:
prior = self.priors[param_name]
if isinstance(prior, ravest.prior.Normal):
walker_position.append(np.random.normal(loc=prior.mean, scale=2*prior.std))
elif isinstance(prior, ravest.prior.HalfNormal):
walker_position.append(np.abs(np.random.normal(loc=0, scale=2*prior.std)))
elif isinstance(prior, ravest.prior.Uniform):
walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper))
elif isinstance(prior, ravest.prior.TruncatedNormal):
walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper))
elif isinstance(prior, ravest.prior.Beta):
walker_position.append(np.random.uniform(low=0, high=1))
elif isinstance(prior, ravest.prior.EccentricityUniform):
walker_position.append(np.random.uniform(low=0, high=prior.upper))
else:
raise ValueError(f"Unsupported prior type for walker generation: {type(prior)}")
else:
# No direct prior for this parameter (this happens if fitting in a transformed parameterisation, but prior is in Default)
# Instead use current value + small perturbation
centre_val = self.params[param_name].value
# Add small random perturbation (10% of current value + small fixed amount for near-zero values)
perturbation = np.random.normal(0, abs(centre_val) * 0.1 + 0.01)
walker_position.append(centre_val + perturbation)
# Check astrophysical validity and prior compliance
try:
# Convert walker position to full parameter dict (free + fixed)
free_params_dict = dict(zip(self.free_params_names, walker_position))
all_params_dict = self.fixed_params_values_dict | free_params_dict
# Check astrophysical validity
self._validate_astrophysical_validity(all_params_dict)
# Check prior compliance using LogPosterior, rather than calling priors direct
# (because it handles Transformed->Default parameter transformations already, if needed)
lp = LogPosterior(
self.planet_letters,
self.parameterisation,
self.priors,
self.fixed_params_values_dict,
self.free_params_names,
self.time,
self.vel,
self.velerr,
self.instrument,
self.unique_instruments,
self.t0,
)
# Check the log-prior probability is finite (i.e. proposed initial values are within prior bounds)
params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict)
log_prior = lp.log_prior(params_for_prior)
if not np.isfinite(log_prior):
raise ValueError(f"Outside prior bounds (log_prior = {log_prior})")
# If both astrophysical and priors validations pass, we have a valid walker position
break
except ValueError:
# Validation failed. Generate a new set of values and try again.
attempts += 1
continue
if attempts >= max_attempts:
raise ValueError(f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. "
f"Consider relaxing priors or checking parameter constraints.")
if verbose:
print(f"Walker {walker_idx} position: {walker_position} (valid after {attempts + 1} attempts)")
mcmc_init.append(walker_position)
mcmc_init = np.array(mcmc_init)
if verbose:
print(f"Generated MCMC initial positions with shape: {mcmc_init.shape}")
return mcmc_init
[docs]
def generate_initial_walker_positions_around_point(
self,
centre: np.ndarray | list,
nwalkers: int,
scale: float = 1e-4,
relative: bool = True,
verbose: bool = False,
max_attempts: int = 1000
) -> np.ndarray:
"""Generate initial walker positions in a ball around a supplied centre point.
Creates starting positions for MCMC walkers clustered around a centre point
(e.g., MAP estimate). Each walker is generated by adding small random perturbations
to the centre values. Validates that both the centre point and all generated
walker positions satisfy priors and are astrophysically valid.
Parameters
----------
centre : np.ndarray or list
Centre point for walker positions. Must have length equal to the number
of free parameters and be in the order of free_params_names.
nwalkers : int
Number of MCMC walkers to generate positions for
scale : float, default 1e-4
Scale of perturbations around centre point
relative : bool, default True
If True, perturbations scale with parameter values (scale * centre * random).
If False, perturbations are absolute (scale * random).
verbose : bool, default False
If True, print walker positions during generation
max_attempts : int, default 1000
Maximum attempts to generate a valid walker position
Returns
-------
np.ndarray
Array of shape (nwalkers, ndim) where ndim is the number of free parameters.
Each row represents the starting position for one walker in the order of
free_params_names.
Raises
------
ValueError
If centre has wrong length, if centre point is invalid, or if unable
to generate valid positions after max_attempts
Examples
--------
>>> # Generate walkers around MAP estimate
>>> map_result = fitter.find_map_estimate()
>>> initial_positions = fitter.generate_initial_walker_positions_around_point(
... centre=map_result.x, nwalkers=40, scale=1e-4
... )
>>> fitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
"""
if len(self.free_params_values) == 0:
raise ValueError(
"Cannot generate walker positions: no free parameters to sample. "
"At least one parameter must be set as free (fixed=False)."
)
centre = np.asarray(centre)
if len(centre) != len(self.free_params_names):
raise ValueError(
f"Centre must have length {len(self.free_params_names)} "
f"(number of free parameters), got {len(centre)}"
)
if verbose:
print("Free parameters:", self.free_params_names)
print(f"Centre values: {centre}")
# Validate centre point first
try:
free_params_dict = dict(zip(self.free_params_names, centre))
all_params_dict = self.fixed_params_values_dict | free_params_dict
# Check astrophysical validity
self._validate_astrophysical_validity(all_params_dict)
# Check prior compliance
lp = LogPosterior(
self.planet_letters,
self.parameterisation,
self.priors,
self.fixed_params_values_dict,
self.free_params_names,
self.time,
self.vel,
self.velerr,
self.instrument,
self.unique_instruments,
self.t0,
)
params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict)
log_prior = lp.log_prior(params_for_prior)
if not np.isfinite(log_prior):
raise ValueError(f"Centre point outside prior bounds (log_prior = {log_prior})")
if verbose:
print(f"Centre point validated (log_prior = {log_prior})")
except ValueError as e:
raise ValueError(f"Supplied centre point is not valid: {e}")
# Generate walker positions around centre
mcmc_init = []
if verbose and relative and np.any(centre == 0.0):
zero_names = [self.free_params_names[i] for i in range(len(centre)) if centre[i] == 0.0]
print(f"Note: centre value is exactly 0.0 for {zero_names}; "
f"using absolute perturbation (scale={scale}) for these parameters.")
for walker_idx in range(nwalkers):
attempts = 0
while attempts < max_attempts:
# Generate perturbation
random_vals = np.random.randn(len(centre))
if relative:
# Relative perturbation: scales with parameter values.
# When a centre value is exactly 0.0, the relative
# perturbation (scale * randn * |0|) is always zero,
# producing identical walker values in that dimension.
# This causes emcee to reject the walkers as linearly
# dependent (condition number check). Fall back to
# absolute perturbation for those parameters.
perturbation = np.empty(len(centre))
for i in range(len(centre)):
if centre[i] == 0.0:
perturbation[i] = scale * random_vals[i]
else:
perturbation[i] = scale * random_vals[i] * np.abs(centre[i])
else:
# Absolute perturbation: same scale for all parameters
perturbation = scale * random_vals
walker_position = centre + perturbation
# Validate this walker position
try:
free_params_dict = dict(zip(self.free_params_names, walker_position))
all_params_dict = self.fixed_params_values_dict | free_params_dict
# Check astrophysical validity
self._validate_astrophysical_validity(all_params_dict)
# Check prior compliance
params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict)
log_prior = lp.log_prior(params_for_prior)
if not np.isfinite(log_prior):
raise ValueError(f"Outside prior bounds (log_prior = {log_prior})")
# If validation passes, we have a valid walker position
break
except ValueError:
# Validation failed, try again
attempts += 1
continue
if attempts >= max_attempts:
raise ValueError(
f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. "
f"Consider using a larger scale parameter or checking that the centre point is not too close to prior/physical boundaries."
)
if verbose:
print(f"Walker {walker_idx} position: {walker_position} (valid after {attempts + 1} attempts)")
mcmc_init.append(walker_position)
mcmc_init = np.array(mcmc_init)
if verbose:
print(f"Generated MCMC initial positions with shape: {mcmc_init.shape}")
return mcmc_init
[docs]
def generate_initial_walker_positions_from_map(
    self,
    map_result: scipy.optimize.OptimizeResult,
    nwalkers: int,
    scale: float = 1e-4,
    relative: bool = True,
    verbose: bool = False,
    max_attempts: int = 1000
) -> np.ndarray:
    """Scatter initial walker positions around a MAP estimate.

    Thin convenience wrapper: the MAP solution vector (``map_result.x``) is
    used as the centre point and everything else is forwarded unchanged to
    ``generate_initial_walker_positions_around_point``.

    Parameters
    ----------
    map_result : scipy.optimize.OptimizeResult
        Result from find_map_estimate(); only its ``x`` attribute is read here.
    nwalkers : int
        Number of MCMC walkers to generate positions for.
    scale : float, default 1e-4
        Scale of perturbations around the MAP values.
    relative : bool, default True
        If True, perturbations scale with parameter values.
        If False, perturbations are absolute.
    verbose : bool, default False
        If True, print walker positions during generation.
    max_attempts : int, default 1000
        Maximum attempts to generate a valid walker position.

    Returns
    -------
    np.ndarray
        Whatever ``generate_initial_walker_positions_around_point`` returns:
        an array of shape (nwalkers, ndim), one row per walker.

    Raises
    ------
    ValueError
        If there are no free parameters, or if valid positions cannot be
        generated by the underlying generator.

    Examples
    --------
    >>> # Find MAP then generate walkers around it
    >>> map_result = fitter.find_map_estimate()
    >>> initial_positions = fitter.generate_initial_walker_positions_from_map(
    ...     map_result=map_result, nwalkers=40
    ... )
    >>> fitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
    """
    n_free = len(self.free_params_values)
    if n_free == 0:
        raise ValueError(
            "Cannot generate walker positions: no free parameters to sample. "
            "At least one parameter must be set as free (fixed=False)."
        )

    # Everything except the centre point is passed straight through.
    perturbation_options = {
        "scale": scale,
        "relative": relative,
        "verbose": verbose,
        "max_attempts": max_attempts,
    }
    return self.generate_initial_walker_positions_around_point(
        centre=map_result.x,
        nwalkers=nwalkers,
        **perturbation_options,
    )
[docs]
def run_mcmc(self, initial_positions: np.ndarray, nwalkers: int, max_steps: int = 5000, progress: bool = True, multiprocessing: bool = False, check_convergence: bool = False, convergence_check_interval: int = 1000, convergence_check_start: int = 0) -> None:
    """Run MCMC sampling from given initial walker positions.

    On success the emcee sampler is stored on ``self.sampler`` and the walker
    count on ``self.nwalkers``. When ``check_convergence=True`` the
    autocorrelation estimates computed at each check are stored in
    ``self.autocorr_history`` (keyed by iteration number).

    Parameters
    ----------
    initial_positions : np.ndarray
        Starting positions for all MCMC walkers. Shape must be (nwalkers, ndim)
        where ndim is the number of free parameters. Each row represents the
        starting position for one walker in the order of free_params_names.
    nwalkers : int
        Number of MCMC walkers. Must match the first dimension of
        initial_positions and be at least 2 * ndim.
    max_steps : int, optional
        Maximum number of MCMC steps to run. If check_convergence=False, runs for
        exactly this many steps. If check_convergence=True, runs up to this many
        steps, stopping early when convergence criteria are met (default: 5000)
    progress : bool, optional
        Whether to show progress bar during MCMC (default: True)
    multiprocessing : bool, optional
        Whether to use multiprocessing for MCMC (default: False)
    check_convergence : bool, optional
        If True, check for convergence and stop early when criteria met.
        Convergence requires: chain length > 50 times max autocorrelation time,
        and autocorrelation time estimate stable to 1 percent.
        If False, run for exactly max_steps (default: False)
    convergence_check_interval : int, optional
        Steps between convergence checks (only used if check_convergence=True) (default: 1000)
    convergence_check_start : int, optional
        Minimum iteration before starting convergence checks. Set this sensibly
        (e.g. 2x burn-in) to avoid inaccurate tau estimates on short chains (default: 0)

    Raises
    ------
    ValueError
        If there are no free parameters; if ``nwalkers < 2 * ndim``; if
        ``initial_positions`` does not have shape (nwalkers, ndim); if
        ``convergence_check_interval`` is not positive while
        ``check_convergence=True``; or if any walker starts at an
        astrophysically invalid position or outside the prior bounds.
    """
    if len(self.free_params_values) == 0:
        raise ValueError(
            "Cannot run MCMC: no free parameters to sample. "
            "At least one parameter must be set as free (fixed=False)."
        )

    # The walker count cannot be silently increased here: the caller supplies
    # one starting position per walker, so a sampler built with more walkers
    # than rows in initial_positions can never be started. Fail early with a
    # clear message instead of an opaque shape error from emcee.
    min_walkers = 2 * self.ndim
    if nwalkers < min_walkers:
        raise ValueError(
            f"nwalkers must be at least 2 * ndim. You have {nwalkers} walkers and {self.ndim} dimensions. "
            f"Generate initial positions for at least {min_walkers} walkers."
        )
    self.nwalkers = nwalkers

    # Validate walker positions shape
    positions_shape = np.shape(initial_positions)
    if positions_shape != (nwalkers, self.ndim):
        raise ValueError(f"initial_positions must have shape ({nwalkers}, {self.ndim}), got {positions_shape}")

    # The interval is used as a modulus below, so zero would raise ZeroDivisionError
    if check_convergence and convergence_check_interval < 1:
        raise ValueError(
            f"convergence_check_interval must be a positive integer, got {convergence_check_interval}."
        )

    # Initialize log-posterior object for MCMC sampling
    lp = LogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.priors,
        self.fixed_params_values_dict,
        self.free_params_names,
        self.time,
        self.vel,
        self.velerr,
        self.instrument,
        self.unique_instruments,
        self.t0,
    )

    # Validate every walker position for astrophysical validity and prior compliance
    # (we don't want to start any chains in invalid parameter space)
    for i, walker_position in enumerate(initial_positions):
        walker_params_dict = dict(zip(self.free_params_names, walker_position))
        all_params_dict = self.fixed_params_values_dict | walker_params_dict

        # Check astrophysical validity
        try:
            self._validate_astrophysical_validity(all_params_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid astrophysical parameters: {e}") from e

        # Check prior compliance
        params_for_prior = lp._convert_params_for_prior_evaluation(walker_params_dict)
        log_prior = lp.log_prior(params_for_prior)
        if not np.isfinite(log_prior):
            raise ValueError(f"Walker {i} is outside prior bounds (log_prior = {log_prior})")

    # Warn if convergence arguments provided but convergence checking disabled
    if not check_convergence and (convergence_check_interval != 1000 or convergence_check_start != 0):
        logging.warning(
            "Convergence checking arguments provided but check_convergence=False. "
            "These arguments will be ignored. Did you forget to set check_convergence=True?"
        )

    # Use 'spawn' instead of 'fork' to avoid issues on some Linux platforms
    pool = mp.get_context("spawn").Pool() if multiprocessing else None
    try:
        # TODO: parameter_names argument does slightly impact performance - but not sure if it can be avoided, we do need the names
        # and I'm not sure constructing the dictionary later ourselves manually is any quicker than passing parameter_names argument
        sampler = emcee.EnsembleSampler(
            self.nwalkers, self.ndim, lp.log_probability,
            parameter_names=self.free_params_names,
            pool=pool,  # pool=None is emcee's default (serial evaluation)
        )

        if not check_convergence:
            # Fixed-length mode - run for exactly max_steps
            logging.info(f"Starting MCMC for {max_steps} steps...")
            sampler.run_mcmc(initial_state=initial_positions, nsteps=max_steps, progress=progress)
            logging.info("...MCMC done.")
        else:
            # Convergence checking - run up to max_steps, stopping early if converged
            logging.info(f"Starting MCMC with convergence checks. (Maximum {max_steps} steps, checking convergence every {convergence_check_interval} steps after iteration {convergence_check_start})...")

            # Initialize autocorrelation history storage
            self.autocorr_history = {}
            old_tau = np.inf  # first stability comparison is always "not stable"

            for _ in sampler.sample(initial_state=initial_positions, iterations=max_steps, progress=progress):
                # Only check at specified intervals
                if sampler.iteration % convergence_check_interval != 0:
                    continue
                # Don't check before we have reached convergence_check_start
                if sampler.iteration < convergence_check_start:
                    continue

                # Get autocorrelation time estimate (tol=0: never raise on short chains)
                tau = sampler.get_autocorr_time(tol=0)

                # Store autocorrelation history for plotting/diagnostics later
                self.autocorr_history[sampler.iteration] = tau.copy()

                logging.info(f"Convergence check: Step {sampler.iteration}: mean(tau)={np.mean(tau):.1f}, max(tau)={np.max(tau):.1f}")

                # Check convergence criteria
                check_chain_length = np.all(sampler.iteration > 50 * tau)  # Chain length > 50 * tau
                check_stable_tau = np.all(np.abs(old_tau - tau) / tau < 0.01)  # Tau stable to 1 percent
                converged = check_chain_length and check_stable_tau

                if converged:
                    logging.info(f"Converged at iteration {sampler.iteration}")
                    break

                logging.info(f"Not yet converged (N/50>tau check: {check_chain_length}, tau stability check: {check_stable_tau})")
                # Warn if approaching max steps without convergence
                if sampler.iteration > 0.8 * max_steps:
                    logging.warning(f"Approaching max iterations ({max_steps}) without convergence! (max tau={np.max(tau):.1f}, tau stability change={np.abs(old_tau - tau) / tau})")

                # Update old tau for next check
                old_tau = tau

            # Final log
            final_steps = sampler.iteration
            logging.info(f"MCMC complete: {final_steps} steps total")
    finally:
        # Always release the worker processes, even if sampling raised
        if pool is not None:
            pool.close()
            pool.join()

    self.sampler = sampler
[docs]
def get_samples_np(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return a contiguous numpy array of MCMC samples.

    Samples can be discarded from the start and/or the end of the array. You can
    also thin (take only every n-th sample), and you can flatten the array
    so that each walker's chain is merged into one chain.

    This is the foundational method for accessing MCMC samples. All the other
    get_samples methods build on this.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    flat : bool, optional
        Whether to flatten each walker's chain into one chain. (default: False)
        If True, return flattened array with shape (nsteps_after_discard_thin * nwalkers, ndim)
        If False, return unflattened array with shape (nsteps_after_discard_thin, nwalkers, ndim)

    Returns
    -------
    np.ndarray
        Contiguous array of MCMC samples. Shape depends on `flat` parameter:
        - flat=False: (nsteps_after_discard_thin, nwalkers, ndim)
        - flat=True: (nsteps_after_discard_thin * nwalkers, ndim)

    Raises
    ------
    ValueError
        If `thin` is less than 1, if `discard_start` or `discard_end` is
        negative, or if the discards/thinning leave no samples.

    Notes
    -----
    We enforce np.ascontiguousarray() on the return, because np.reshape() does
    not guarantee a contiguous array in memory.
    """
    # Reject values that would otherwise be silently reinterpreted by
    # Python slicing (negative indices count from the end; a zero step is illegal).
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}.")
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )

    # Get the full chain from emcee without any processing
    full_samples = self.sampler.get_chain(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end
    start_idx = discard_start + thin - 1
    end_idx = full_samples.shape[0] - discard_end

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    samples = full_samples[start_idx:end_idx:thin]

    # Flatten if requested (after discarding): (steps, walkers, ndim) -> (steps*walkers, ndim)
    if flat:
        nsteps, nwalkers, ndim = samples.shape
        samples = samples.reshape(nsteps * nwalkers, ndim)

    return np.ascontiguousarray(samples)
[docs]
def get_samples_df(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> pd.DataFrame:
    """Return the flattened MCMC samples as a pandas DataFrame.

    One row per sample, one column per free parameter. The samples come from
    get_samples_np() called with ``flat=True``.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    pd.DataFrame
        DataFrame with shape (nsteps_after_discard_thin * nwalkers, ndim).
        Columns are the names in ``self.free_params_names``.
    """
    slicing = {"discard_start": discard_start, "discard_end": discard_end, "thin": thin}
    table = pd.DataFrame(
        self.get_samples_np(flat=True, **slicing),
        columns=self.free_params_names,
    )
    return table
[docs]
def get_samples_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, np.ndarray]:
    """Return a dict of flattened MCMC samples.

    Each parameter gets a 1D (flattened) contiguous array of all its samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary mapping parameter names to 1D arrays of samples.
        Each array has shape (nsteps_after_discard_thin * nwalkers,)

    Examples
    --------
    >>> samples_dict = fitter.get_samples_dict(discard_start=1000)
    >>> K_b_samples = samples_dict['K_b']  # All samples for parameter K for planet b
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)

    # A column slice flat_samples[:, i] of a row-major array is a strided,
    # non-contiguous view. Transposing into a fresh contiguous array (one copy)
    # makes every row - i.e. every parameter's samples - contiguous in memory,
    # as the docstring promises.
    per_param = np.ascontiguousarray(flat_samples.T)
    return {name: per_param[i] for i, name in enumerate(self.free_params_names)}
[docs]
def get_sampler_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return the log probability at each step of the sampler.

    Uses the same discard/thin slicing as get_samples_np(), so the returned
    values line up index-for-index with the samples from that method when
    called with the same arguments.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    flat : bool, optional
        If True, return flattened array shape (nsteps_after_discard_thin * nwalkers)
        If False, return unflattened array shape (nsteps_after_discard_thin, nwalkers) (default: False)

    Returns
    -------
    np.ndarray
        Contiguous array of log probabilities of the function at each sample.

    Raises
    ------
    ValueError
        If `thin` is less than 1, if `discard_start` or `discard_end` is
        negative, or if the discards/thinning leave no samples.
    """
    # Reject values that would otherwise be silently reinterpreted by
    # Python slicing (negative indices count from the end; a zero step is illegal).
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}.")
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )

    # Get the full log prob chain from emcee without any processing
    full_lnprob = self.sampler.get_log_prob(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end
    start_idx = discard_start + thin - 1
    end_idx = full_lnprob.shape[0] - discard_end

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    lnprob = full_lnprob[start_idx:end_idx:thin]

    # Flatten if requested (after discarding): (steps, walkers) -> (steps*walkers,)
    if flat:
        nsteps, nwalkers = lnprob.shape
        lnprob = lnprob.reshape(nsteps * nwalkers)

    return np.ascontiguousarray(lnprob)
[docs]
def get_mcmc_posterior_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> dict:
    """Merge fixed parameter values with the MCMC samples of the free parameters.

    The result holds every model parameter in one dictionary: fixed parameters
    keep their stored values from ``self.fixed_params_values_dict`` and free
    parameters map to the sample arrays returned by get_samples_dict(). This
    is the format wanted by functions such as calculate_mpsini that need all
    parameters (free or fixed) and should propagate the uncertainty carried
    by the free-parameter samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary of all parameters:
        - Fixed parameters: their stored fixed values
        - Free parameters: 1D arrays of MCMC samples with shape (nsteps_after_discard_thin * nwalkers,)
    """
    fixed_values = self.fixed_params_values_dict
    sampled_values = self.get_samples_dict(
        discard_start=discard_start,
        discard_end=discard_end,
        thin=thin,
    )
    # On a key clash the sampled (free) entry wins, as it is merged last.
    merged = fixed_values | sampled_values
    return merged
[docs]
def calculate_log_likelihood(self, params_dict: Dict[str, float]) -> float:
    """Evaluate the log-likelihood at the given parameter values.

    Prior probabilities are NOT included: this is the bare (log-)
    *likelihood*, primarily for use in AICc & BIC calculation.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        The log-likelihood value
    """
    # Gather the data and model configuration the LogLikelihood needs
    # (same inputs as used in find_map_estimate and run_mcmc).
    likelihood_inputs = {
        "time": self.time,
        "vel": self.vel,
        "velerr": self.velerr,
        "instrument": self.instrument,
        "unique_instruments": self.unique_instruments,
        "t0": self.t0,
        "planet_letters": self.planet_letters,
        "parameterisation": self.parameterisation,
    }
    evaluator = LogLikelihood(**likelihood_inputs)
    return evaluator(params_dict)
[docs]
def build_params_dict(self, free_params: np.ndarray | list | Dict[str, float]) -> Dict[str, float]:
"""Build a params dict by providing free param vals, combine with fixed param vals.
Takes free parameter float values (which can be from any source e.g. MAP results, MCMC posteriors,
or any custom values) and combines them with the fixed parameter values to create
a complete parameter dictionary. This dict is ideal for calculating chi2, log-likelihood,
AICc, and BIC.
This is designed for a single value per parameter. For combining the MCMC posterior
chains for free parameters and the fixed values for fixed parameters, use
`get_mcmc_posterior_dict` method.
Parameters
----------
free_params : list, np.ndarray, or dict
Free parameter values from any source:
- list/array: values in order of self.free_params_names
- dict: mapping of free param names to values
Returns
-------
Dict[str, float]
Complete parameters dict with both free and fixed parameter values
Examples
--------
>>> # From MAP optimization result
>>> map_result = fitter.find_map_estimate()
>>> params = fitter.build_params_dict(map_result.x)
>>> aicc = fitter.calculate_aicc(params)
>>>
>>> # From best MCMC sample
>>> best_sample = fitter.get_sample_with_best_lnprob(discard_start=1000)
>>> params = fitter.build_params_dict(best_sample)
>>> bic = fitter.calculate_bic(params)
>>>
>>> # From custom array of values (in order of free_params_names)
>>> custom_values = [5.0, 50.0, 0.1, 0.0, 2450000.0] # example values
>>> params = fitter.build_params_dict(custom_values)
>>> log_like = fitter.calculate_log_likelihood(params)
"""
if isinstance(free_params, dict):
# Validate that all expected free parameters are present
expected_names = set(self.free_params_names)
provided_names = set(free_params.keys())
missing = expected_names - provided_names
if missing:
raise ValueError(f"Missing required free parameters: {missing}")
extra = provided_names - expected_names
if extra:
raise ValueError(f"Unexpected parameters provided: {extra}")
return self.fixed_params_values_dict | free_params
else:
# Validate that array/list has correct length
if len(free_params) != len(self.free_params_names):
raise ValueError(
f"Expected {len(self.free_params_names)} free parameter values "
f"but got {len(free_params)} "
f"(expecting {len(self.free_params_names)} values for {self.free_params_names})"
)
free_dict = dict(zip(self.free_params_names, free_params))
return self.fixed_params_values_dict | free_dict
[docs]
def calculate_chi2(self, params_dict: Dict[str, float]) -> float:
    r"""Calculate chi-squared for given parameter values.

    Reuses the log-likelihood (via ``calculate_log_likelihood``) to avoid
    duplicating the RV model evaluation, and works backwards from it:

    .. math::
        \ell = -\frac{1}{2} \left( \chi^2 + \text{penalty} \right)

    where the penalty is :math:`\sum_i \ln\left(2\pi(\sigma_i^2 + \sigma_{\text{jit}}^2)\right)`.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters).
        Must contain a ``jit_<instrument>`` entry for every instrument in
        ``self.unique_instruments``.

    Returns
    -------
    float
        Chi-squared value:

        .. math::
            \chi^2 = \sum_i \frac{(d_i - m_i)^2}{\sigma_i^2 + \sigma_{\text{jit}}^2}
    """
    log_like = self.calculate_log_likelihood(params_dict)

    # Per-observation variance: measurement error and the jitter of the
    # observing instrument added in quadrature. Work in float explicitly so
    # integer-typed uncertainties cannot truncate the variances.
    velerr = np.asarray(self.velerr, dtype=float)
    velerr_jitter_squared = np.zeros_like(velerr)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        jit = params_dict[f"jit_{inst}"]
        velerr_jitter_squared[mask] = velerr[mask]**2 + jit**2

    # ll = -0.5 * (chi2 + penalty_term)  =>  chi2 = -2 * ll - penalty_term
    penalty_term = np.sum(np.log(2 * np.pi * velerr_jitter_squared))
    chi2 = -2 * log_like - penalty_term
    return chi2
[docs]
def calculate_aicc(self, params_dict: Dict[str, float]) -> float:
    r"""Calculate corrected Akaike Information Criterion (AICc).

    .. math::
        \text{AICc} = 2k - 2\ln\mathcal{L} + \frac{2k^2 + 2k}{n - k - 1}

    where :math:`k` is the number of free parameters, :math:`n` is the
    number of observations, and :math:`\mathcal{L}` is the likelihood.
    Converges to AIC for large :math:`n`.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        AICc value

    Raises
    ------
    ValueError
        If :math:`n - k - 1 \le 0`: the small-sample correction is then
        undefined (division by zero) or negative and meaningless.
    """
    k = self.ndim
    n = len(self.time)

    # The correction term divides by (n - k - 1); guard before doing any work.
    dof = n - k - 1
    if dof <= 0:
        raise ValueError(
            f"AICc is undefined when n - k - 1 <= 0 "
            f"(n={n} observations, k={k} free parameters)."
        )

    log_like = self.calculate_log_likelihood(params_dict)
    aic = 2 * k - 2 * log_like  # traditional AIC
    correction = (2 * k**2 + 2 * k) / dof  # small-sample correction
    return aic + correction
[docs]
def calculate_bic(self, params_dict: Dict[str, float]) -> float:
    r"""Compute the Bayesian Information Criterion (BIC) at the given parameters.

    .. math::
        \text{BIC} = k \ln n - 2 \ln \mathcal{L}

    Here :math:`k` is the number of free parameters (``self.ndim``),
    :math:`n` is the number of observations (``len(self.time)``), and
    :math:`\mathcal{L}` is the likelihood.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        BIC value
    """
    ln_likelihood = self.calculate_log_likelihood(params_dict)
    n_free = self.ndim
    n_obs = len(self.time)
    complexity_penalty = n_free * np.log(n_obs)
    return complexity_penalty - 2 * ln_likelihood
[docs]
def get_sample_with_best_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, float]:
    """Return the MCMC sample that has the highest log probability.

    The flattened samples and flattened log probabilities are fetched with
    identical discard/thin settings so that they share the same indexing.
    The best log probability and its index are printed.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    Dict[str, float]
        Dictionary of parameter names to values from the best sample
    """
    slicing = {"discard_start": discard_start, "discard_end": discard_end, "thin": thin}
    flat_chain = self.get_samples_np(flat=True, **slicing)
    flat_lnprob = self.get_sampler_lnprob(flat=True, **slicing)

    # Locate the single most probable sample (first occurrence on ties)
    best_idx = np.argmax(flat_lnprob)
    best_lnprob = flat_lnprob[best_idx]
    print(f"Best sample found with log probability {best_lnprob:.6f} at index {best_idx} of samples (with discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    return {
        name: value
        for name, value in zip(self.free_params_names, flat_chain[best_idx])
    }
[docs]
def plot_autocorr_estimates(
self,
params: list[str] | None = None,
plot_mean: bool = False,
show_legend: bool = True,
title: str | None = "Autocorrelation Time Estimates",
xlabel: str | None = "Step number",
ylabel: str | None = r"Autocorrelation time $\tau$",
save: bool = False,
fname: str = "autocorr_plot.png",
dpi: int = 100
) -> None:
r"""Plot autocorrelation time estimates from adaptive MCMC run.
Shows how autocorrelation time evolved during the MCMC run and
the convergence threshold line (N / 50).
Only available if run_mcmc was called with check_convergence=True.
Parameters
----------
params : list[str] or None, optional
List of parameter names to plot. If None, plots all free parameters (default: None)
plot_mean : bool, optional
If True, plot mean tau instead of individual parameter taus.
Overrides params argument (default: False)
show_legend : bool, optional
Whether to show legend (default: True)
title : str or None, optional
Plot title (default: "Autocorrelation Time Estimates"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label (default: "Step number"). Set to None or "" to skip.
ylabel : str or None, optional
Y-axis label (default: r"Autocorrelation time $\tau$"). Set to None or "" to skip.
save : bool, optional
Save the plot to path `fname` (default: False)
fname : str, optional
The path to save the plot to (default: "autocorr_plot.png")
dpi : int, optional
The dpi to save the image at (default: 100)
Raises
------
ValueError
If no autocorrelation history is available (run_mcmc was not called
with check_convergence=True, or has not been called yet)
"""
# Check if data available
if not hasattr(self, 'autocorr_history') or len(self.autocorr_history) == 0:
raise ValueError(
"No autocorrelation history available. "
"Please run run_mcmc() with check_convergence=True first."
)
iterations = np.array(list(self.autocorr_history.keys()))
max_iteration = np.max(iterations)
tau_history = np.array(list(self.autocorr_history.values())) # Shape: (n_checks, n_params)
# Create plot
fig, ax = plt.subplots(1, figsize=(10, 6))
if title:
fig.suptitle(title)
# Plot convergence threshold (N/50)
ax.plot([0, max_iteration], [0, max_iteration / 50], "--k", linewidth=2,
label="N/50")
if plot_mean:
# Plot mean tau
mean_tau = np.mean(tau_history, axis=1)
ax.plot(iterations, mean_tau, linewidth=2, label="Mean τ")
else:
# Determine which parameters to plot
if params is None:
params_to_plot = self.free_params_names
indices_to_plot = range(len(self.free_params_names))
else:
params_to_plot = []
indices_to_plot = []
for param in params:
if param in self.free_params_names:
idx = self.free_params_names.index(param)
params_to_plot.append(param)
indices_to_plot.append(idx)
else:
logging.warning(f"Parameter '{param}' not found in free parameters, skipping")
# Plot individual parameter taus
for i, param_name in zip(indices_to_plot, params_to_plot):
ax.plot(iterations, tau_history[:, i], alpha=0.7, label=param_key_to_latex(param_name))
ax.set_xlim(0, iterations.max())
ax.set_ylim(bottom=0)
if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)
if show_legend:
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def plot_chains(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, truths: list = None, title: str | None = "Chains plot", xlabel: str | None = "Step number", save: bool = False, fname: str = "chains_plot.png", dpi: int = 100) -> None:
"""Plot MCMC chains for all free parameters.
Displays the evolution of each free parameter across MCMC steps for all walkers.
Useful for diagnosing convergence, burn-in, and mixing of the MCMC chains.
Each parameter gets its own subplot showing all walker traces.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
truths : list, optional
List of true parameter values to overplot as horizontal lines.
Must match the number of free parameters. Use None for parameters
without known truth values (default: None)
title : str or None, optional
Plot title (default: "Chains plot"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label (default: "Step number"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "chains_plot.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Scale figure height to maintain consistent subplot size
subplot_height_inches = 1.25
fig, axes = plt.subplots(self.ndim, figsize=(10, self.ndim * subplot_height_inches),
sharex=True, constrained_layout=True)
if title:
fig.suptitle(title)
if self.ndim == 1:
axes = [axes]
if truths is not None:
if not len(truths) == self.ndim:
raise ValueError(f"Length of truths ({len(truths)}) must match number of free parameters ({self.ndim})")
samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
for i in range(self.ndim):
ax = axes[i]
ax.set_xlim(0, len(samples))
ax.set_ylabel(param_key_to_latex(self.free_params_names[i]))
to_plot = samples[:, :, i]
ax.plot(to_plot, "k", alpha=0.3)
if truths is not None and truths[i] is not None:
ax.axhline(truths[i], color="tab:blue")
fig.align_ylabels(axes)
if xlabel:
axes[-1].set_xlabel(xlabel)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def plot_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Log Probability Traces", xlabel: str | None = "Step number", ylabel: str | None = "Log probability", save: bool = False, fname: str = "lnprob_plot.png", dpi: int = 100) -> None:
"""Plot log probability traces for all walkers.
Useful for diagnosing MCMC convergence and identifying problematic
walkers/parameters. You can use `discard_start` and `discard_end` to
focus in on specific steps in the chains.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
title : str or None, optional
Plot title (default: "Log Probability Traces"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label (default: "Step number"). Set to None or "" to skip.
ylabel : str or None, optional
Y-axis label (default: "Log probability"). Set to None or "" to skip.
save : bool, optional
Save the plot to path `fname` (default: False)
fname : str, optional
The path to save the plot to (default: "lnprob_plot.png")
dpi : int, optional
The dpi to save the image at (default: 100)
"""
fig, ax = plt.subplots(1, figsize=(10, 6))
if title:
fig.suptitle(title)
lnprobs = self.get_sampler_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
nsteps, nwalkers = lnprobs.shape
for i in range(nwalkers):
to_plot = lnprobs[:, i]
ax.plot(to_plot, "k", alpha=0.3)
ax.set_xlim(0, nsteps)
if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def plot_corner(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, plot_datapoints: bool = False, truths: list[float] = None, title: str | None = "Corner plots", save: bool = False, fname: str = "corner_plot.png", dpi: int = 100) -> None:
"""Create a corner plot of MCMC samples.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
plot_datapoints : bool, optional
Show individual data points in addition to contours (default: False)
truths : list of float, optional
True parameter values to overplot as vertical/horizontal lines (default: None).
Must match the order of free parameters if provided.
title : str or None, optional
Plot title (default: "Corner plots"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "corner_plot.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
param_labels = [param_key_to_latex(n) for n in self.free_params_names]
fig = corner.corner(
flat_samples, labels=param_labels, show_titles=True,
plot_datapoints=plot_datapoints, quantiles=[0.1585, 0.5, 0.8415],
truths=truths,
)
if title:
fig.suptitle(title)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def _plot_rv(
    self,
    params: Dict[str, float],
    title: str | None = "RV Model",
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Time [days]",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "rv_plot.png",
    dpi: int = 100,
) -> None:
    """Helper function to plot RV model with given parameters.

    Draws two stacked panels: the gamma-corrected data with the model
    (planets + trend) on top, and the residuals underneath. The figure is
    shown, and optionally saved; nothing is returned.

    Parameters
    ----------
    params : dict
        Dictionary of parameter values (both free and fixed)
    title : str or None, optional
        Plot title (default: "RV Model"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "rv_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Sum the planetary contributions on both the smooth grid and the observation times
    rv_all_planets_smooth = np.zeros(len(tsmooth))
    rv_all_planets_obs = np.zeros(len(self.time))
    for letter in self.planet_letters:
        planet_params = {par_name: params[f"{par_name}_{letter}"] for par_name in self.parameterisation.pars}
        planet = ravest.model.Planet(letter, self.parameterisation, planet_params)
        rv_all_planets_smooth += planet.radial_velocity(tsmooth)
        rv_all_planets_obs += planet.radial_velocity(self.time)

    # Add trend contribution (no gamma offset - that's per-instrument)
    trend_params = {"gd": params["gd"], "gdd": params["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    rv_total_smooth = rv_all_planets_smooth + trend.radial_velocity(tsmooth)
    rv_total_obs = rv_all_planets_obs + trend.radial_velocity(self.time)

    # One boolean mask per instrument, built once and reused for every per-instrument step
    inst_masks = {inst: (self.instrument == inst) for inst in self.unique_instruments}

    # Subtract per-instrument gamma offsets from the data (so gamma-corrected data is
    # compared to the physical model) and add per-instrument jitter in quadrature
    vel_corrected = self.vel.copy()
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst, mask in inst_masks.items():
        vel_corrected[mask] -= params[f"g_{inst}"]
        jit = params[f"jit_{inst}"]
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit**2)

    # Residuals: gamma-corrected data minus model
    residuals = vel_corrected - rv_total_obs

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5),
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Main RV plot - plot data by instrument with different colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}
    for inst, mask in inst_masks.items():
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)
    ax1.plot(tsmooth, rv_total_smooth, label="Model", color="black", zorder=2)
    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')  # Remove x-axis labels from top plot
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals plot
    for inst, mask in inst_masks.items():
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric y-limits so zero is centred. Error bars extend both ways from each
    # point, so the furthest reach from zero is |residual| + error; taking
    # |residual + error| would under-estimate it for negative residuals.
    max_abs_residual = np.max(np.abs(residuals) + velerr_with_jit)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)  # Add ticks on shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def _plot_phase(
    self,
    planet_letter: str,
    params: Dict[str, float],
    title: str | None = None,
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Orbital phase",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "phase_plot.png",
    dpi: int = 100,
) -> None:
    """Helper function to plot phase-folded RV model for a single planet with given parameters.

    The contributions of the other planets, the trend and the per-instrument
    gamma offsets are subtracted from the data, and what is left is plotted
    against the folded time together with this planet's model. A second
    panel shows the residuals. The figure is shown, and optionally saved.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    params : dict
        Dictionary of parameter values (both free and fixed)
    title : str or None, optional
        Plot title. None (the default) uses f"Planet {planet_letter} Phase Plot".
        Set to "" to skip the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "phase_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    if title is None:
        title = f"Planet {planet_letter} Phase Plot"

    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Add per-instrument jitter in quadrature to the measurement errors
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        jit = params[f"jit_{inst}"]
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit**2)

    # This planet's parameters, keyed by bare parameter name
    planet_params = {par: params[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}

    # Period and time of conjunction to fold around
    P = params[f"P_{planet_letter}"]
    if "Tc" in self.parameterisation.pars:
        Tc = params[f"Tc_{planet_letter}"]
    elif "e" in self.parameterisation.pars and "w" in self.parameterisation.pars:
        _e = params[f"e_{planet_letter}"]
        _w = params[f"w_{planet_letter}"]
        _Tp = params[f"Tp_{planet_letter}"]
        Tc = self.parameterisation.convert_tp_to_tc(_Tp, P, _e, _w)
    else:
        # Fall back to default parameterisation conversion
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], P, default_params["e"], default_params["w"])

    # Fold both time arrays; the second return value is used below to put
    # per-observation arrays into the same order as the folded times
    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # RV contribution from this planet only
    planet = ravest.model.Planet(planet_letter, self.parameterisation, planet_params)
    planet_rv_tsmooth = planet.radial_velocity(tsmooth)
    planet_rv_data = planet.radial_velocity(self.time)
    planet_rv_sorted = planet_rv_tsmooth[smooth_inds]

    # All other contributions (other planets + trend + gamma) at data times
    other_rv = np.zeros(len(self.time))
    for other_letter in self.planet_letters:
        if other_letter != planet_letter:
            other_params = {par: params[f"{par}_{other_letter}"] for par in self.parameterisation.pars}
            other_planet = ravest.model.Planet(other_letter, self.parameterisation, other_params)
            other_rv += other_planet.radial_velocity(self.time)
    # Add trend (no gamma offset - that's per-instrument)
    trend_params = {"gd": params["gd"], "gdd": params["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    other_rv += trend.radial_velocity(self.time)
    # Add per-instrument gamma offsets
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        other_rv[mask] += params[f"g_{inst}"]

    # Data with the other components removed, and residuals against this planet's model
    data_minus_others = self.vel - other_rv
    residuals = data_minus_others - planet_rv_data

    # Reorder every per-observation array to match the folded times
    data_minus_others_sorted = data_minus_others[inds]
    verr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]
    residuals_sorted = residuals[inds]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5),
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main phase plot - plot each instrument separately
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)
    # Phase-folded model for this planet
    ax1.plot(tsmooth_fold_sorted, planet_rv_sorted, label="Model", color="black", zorder=2)
    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    ax1.legend(loc="upper right")
    if title:
        ax1.set_title(title)
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')  # Remove x-axis labels from top plot
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Annotate with planet info
    K_value = params[f"K_{planet_letter}"]
    P_label = param_key_to_latex(f"P_{planet_letter}")
    K_label = param_key_to_latex(f"K_{planet_letter}")
    s = f"Planet {planet_letter}\n{P_label}={P:.2f} d\n{K_label}={K_value:.2f} m/s"
    ax1.annotate(s, xy=(0, 1), xycoords="axes fraction",
                 xytext=(+0.5, -0.5), textcoords="offset fontsize", va="top")

    # Residuals plot (phase-folded) - plot each instrument separately
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25

    # Symmetric y-limits so zero is centred. Error bars extend both ways from each
    # point, so the furthest reach from zero is |residual| + error; taking
    # |residual + error| would under-estimate it for negative residuals.
    max_abs_residual = np.max(np.abs(residuals_sorted) + velerr_with_jit_sorted)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)  # Add ticks on shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_posterior_rv(
    self,
    discard_start: int = 0,
    discard_end: int = 0,
    thin: int = 1,
    show_CI: bool = True,
    title: str | None = "Posterior RV",
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Time [days]",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "posterior_rv.png",
    dpi: int = 100,
) -> None:
    """Plot the posterior RV model with uncertainty bands from MCMC samples.

    Calculates RV model predictions for each MCMC sample, then plots the median
    with an optional 68.3% credible interval band (15.85th-84.15th percentile).
    Shows both the full model and the residuals vs data. Per-instrument gamma
    offsets and jitter use the median of their samples, or their fixed value
    when they were not sampled.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main RV plot (default: "Posterior RV"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_rv.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Posterior RV predictions (planets + trend, no gamma), one row per sample
    rv_all_planets_trend_matrix_smooth = self.calculate_rv_total_from_samples(times=tsmooth, discard_start=discard_start, discard_end=discard_end, thin=thin)
    rv_all_planets_trend_matrix_obs = self.calculate_rv_total_from_samples(times=self.time, discard_start=discard_start, discard_end=discard_end, thin=thin)

    # Percentiles across samples: lower bound, median, upper bound
    rv_percentiles_smooth = np.percentile(rv_all_planets_trend_matrix_smooth, [15.85, 50, 84.15], axis=0)
    rv_percentiles_obs = np.percentile(rv_all_planets_trend_matrix_obs, [15.85, 50, 84.15], axis=0)

    # Samples dict supplies the per-instrument gamma and jitter when they are free
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)

    def _median_or_fixed(key):
        # Median of the samples for a free parameter, otherwise its fixed value
        if key in samples_dict:
            return np.median(samples_dict[key])
        return self.fixed_params_values_dict[key]

    # One boolean mask per instrument, built once and reused for every per-instrument step
    inst_masks = {inst: (self.instrument == inst) for inst in self.unique_instruments}

    # Subtract per-instrument gamma offsets from data (compare gamma-corrected data to physical model)
    vel_corrected = self.vel.copy()
    for inst, mask in inst_masks.items():
        vel_corrected[mask] -= _median_or_fixed(f"g_{inst}")

    # Residuals against the median model at the data times
    residuals = vel_corrected - rv_percentiles_obs[1]

    # Add per-instrument jitter in quadrature to the measurement errors
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst, mask in inst_masks.items():
        jit_median = _median_or_fixed(f"jit_{inst}")
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit_median**2)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Main RV plot - plot data by instrument with different colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}
    for inst, mask in inst_masks.items():
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)

    # Median model and uncertainty band
    ax1.plot(tsmooth, rv_percentiles_smooth[1], label="Model", color="black", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth, rv_percentiles_smooth[0], rv_percentiles_smooth[2], color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)
    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals plot
    for inst, mask in inst_masks.items():
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric y-limits so zero is centred. Error bars extend both ways from each
    # point, so the furthest reach from zero is |residual| + error; taking
    # |residual + error| would under-estimate it for negative residuals.
    max_abs_residual = np.max(np.abs(residuals) + velerr_with_jit)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_posterior_phase(
    self,
    planet_letter: str,
    discard_start: int = 0,
    discard_end: int = 0,
    thin: int = 1,
    show_CI: bool = True,
    title: str | None = None,
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Orbital phase",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "posterior_phase.png",
    dpi: int = 100,
) -> None:
    """Plot phase-folded RV model with uncertainty bands from MCMC samples.

    Shows the phase-folded planetary signal with uncertainty bands calculated
    from MCMC samples. The median contribution of the trend, the other planets
    and the per-instrument gamma offsets is removed from the data first.
    Folding uses the median period and median time of conjunction.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main phase plot. None (the default) uses
        f"Posterior Phase Plot - Planet {planet_letter}". Set to "" to skip the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_phase.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    # Sampled parameters, merged with the fixed ones so every key can be looked up in one place
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    params = samples_dict | self.fixed_params_values_dict

    def _median_or_fixed(key):
        # Median of the samples for a free parameter, otherwise its fixed value
        if key in samples_dict:
            return np.median(samples_dict[key])
        return self.fixed_params_values_dict[key]

    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Add per-instrument jitter in quadrature to the measurement errors
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        jit_med = _median_or_fixed(f"jit_{inst}")
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit_med**2)

    # Period (samples if free, a single value if fixed)
    _P = params[f'P_{planet_letter}']

    # Get (or calculate) Tc for this planet for folding around
    if "Tc" in self.parameterisation.pars:
        _Tc = params[f"Tc_{planet_letter}"]
    else:
        # Fall back to default parameterisation conversion (as it has P, e, w and Tp, so we can definitely get Tc)
        planet_params = {par: params[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        _Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], _P, default_params["e"], default_params["w"])

    # Just for the folding, take the median value of the P and Tc samples
    Tc_med = np.median(_Tc)
    P_med = np.median(_P)

    # Fold both data and smooth times; the second return value is used below to
    # put per-observation arrays into the same order as the folded times
    t_fold, inds = ravest.model.fold_time_series(self.time, P_med, Tc_med)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P_med, Tc_med)

    # RV components from MCMC samples (matrix of n_samples x n_obs, or n_samples x n_smooth)
    rv_planet_data = self.calculate_rv_planet_from_samples(planet_letter, self.time, discard_start, discard_end, thin)
    rv_planet_smooth = self.calculate_rv_planet_from_samples(planet_letter, tsmooth, discard_start, discard_end, thin)
    rv_trend_data = self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)

    # RV contributions from all OTHER planets (not the target planet)
    rv_other_planets_data = np.zeros_like(rv_trend_data)
    for other_letter in self.planet_letters:
        if other_letter != planet_letter:
            rv_other_planet = self.calculate_rv_planet_from_samples(other_letter, self.time, discard_start, discard_end, thin)
            rv_other_planets_data += rv_other_planet

    # Combine all non-target contributions (other planets + trend)
    rv_others_total_data = rv_trend_data + rv_other_planets_data

    # Add per-instrument gamma offsets (using median values)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        rv_others_total_data[:, mask] += _median_or_fixed(f"g_{inst}")

    # Percentiles across samples: lower bound, median, upper bound
    rv_planet_data_percs = np.percentile(rv_planet_data, [15.85, 50, 84.15], axis=0)
    rv_planet_smooth_percs = np.percentile(rv_planet_smooth, [15.85, 50, 84.15], axis=0)
    rv_others_total_percs = np.percentile(rv_others_total_data, [15.85, 50, 84.15], axis=0)

    # Remove all other contributions from observed data (using median of combined other contributions)
    data_minus_others = self.vel - rv_others_total_percs[1]
    # That leaves the component of the data due to this planet, to compare with its median model
    residuals = data_minus_others - rv_planet_data_percs[1]

    # Reorder every per-observation array once to match the folded times
    data_minus_others_sorted = data_minus_others[inds]
    residuals_sorted = residuals[inds]
    velerr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), sharex=True,
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main phase plot - plot data with other contributions removed, sorted by phase
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold[mask], data_minus_others_sorted[mask], yerr=velerr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(t_fold[mask], data_minus_others_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    # Planet model with uncertainty, sorted by phase
    ax1.plot(tsmooth_fold_sorted, rv_planet_smooth_percs[1][smooth_inds],
             linestyle="-", color="black", zorder=2, label="Model")
    if show_CI:
        ax1.fill_between(tsmooth_fold_sorted, rv_planet_smooth_percs[0][smooth_inds],
                         rv_planet_smooth_percs[2][smooth_inds], color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)
    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title is None:
        # None means "use the default title"; only an empty string suppresses it
        ax1.set_title(f"Posterior Phase Plot - Planet {planet_letter}")
    elif title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals plot (phase-folded) - plot each instrument separately
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold[mask], residuals_sorted[mask], yerr=velerr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25

    # Symmetric y-limits so zero is centred. Error bars extend both ways from each
    # point, so the furthest reach from zero is |residual| + error; taking
    # |residual + error| would under-estimate it for negative residuals.
    max_abs_residual = np.max(np.abs(residuals_sorted) + velerr_with_jit_sorted)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def calculate_rv_planet_from_samples(self, planet_letter: str, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = False) -> np.ndarray:
    """Evaluate one planet's RV curve for every MCMC sample.

    Each row of the flattened chain is expanded into a full parameter
    dictionary and the planet's RV is evaluated with it, so correlations
    between parameters are kept. This is not the same as evaluating the
    model once at the median parameters, which need not correspond to any
    actual posterior sample.

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - RV for each sample
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    n_samples = len(flat_chain)

    # One output row per sample, filled in as the chain is walked
    rv_matrix = np.zeros((n_samples, len(times)))
    progress_bar = tqdm(
        enumerate(flat_chain),
        total=n_samples,
        disable=not progress,
        desc=f"Calculating planet {planet_letter} RV from samples",
    )
    for idx, theta in progress_bar:
        # Expand the sample row into a complete params dict, then evaluate the planet
        sample_params = self.build_params_dict(theta)
        rv_matrix[idx, :] = self.calculate_rv_planet_custom(planet_letter, times, sample_params)
    return rv_matrix
[docs]
def calculate_rv_trend_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = False) -> np.ndarray:
    """Evaluate the trend RV for every MCMC sample.

    Each row of the flattened chain is expanded into a full parameter
    dictionary and the trend is evaluated with it, so correlations between
    parameters are kept. This is not the same as evaluating the trend once
    at the median parameters, which need not correspond to any actual
    posterior sample.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - Trend RV for each sample
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    n_samples = len(flat_chain)

    # One output row per sample, filled in as the chain is walked
    rv_matrix = np.zeros((n_samples, len(times)))
    progress_bar = tqdm(
        enumerate(flat_chain),
        total=n_samples,
        disable=not progress,
        desc="Calculating trend RV from samples",
    )
    for idx, theta in progress_bar:
        # Expand the sample row into a complete params dict, then evaluate the trend
        sample_params = self.build_params_dict(theta)
        rv_matrix[idx, :] = self.calculate_rv_trend_custom(times, sample_params)
    return rv_matrix
[docs]
def calculate_rv_total_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = False) -> np.ndarray:
    """Evaluate the total RV (trend plus every planet) for every MCMC sample.

    The per-sample trend matrix is used as the starting point and each
    planet's per-sample matrix is added to it in turn, in the order of
    `self.planet_letters`. Working per sample keeps parameter correlations,
    unlike evaluating the model once at the median parameters.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - Total RV for each sample
    """
    # The same chain-selection options are forwarded to every component calculation
    chain_opts = dict(discard_start=discard_start, discard_end=discard_end, thin=thin, progress=progress)

    # Start from the trend, then accumulate the planets into the same array
    rv_sum = self.calculate_rv_trend_from_samples(times, **chain_opts)
    for letter in self.planet_letters:
        rv_sum += self.calculate_rv_planet_from_samples(letter, times, **chain_opts)
    return rv_sum
[docs]
def calculate_rv_planet_custom(self, planet_letter: str, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute one planet's RV curve from an explicit parameter dictionary.

    Handy for evaluating the model at a hand-picked point in parameter
    space (e.g. the best lnprob sample, median parameters, or trial values).

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters).
        Can be created using build_params_dict() or manually constructed.
        Must contain a ``"<par>_<planet_letter>"`` entry for every name in
        ``self.parameterisation.pars``.

    Returns
    -------
    np.ndarray
        RV values at the requested times

    Raises
    ------
    KeyError
        If a required ``"<par>_<planet_letter>"`` key is absent from `params`.

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> rv = fitter.calculate_rv_planet_custom('b', times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = fitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = fitter.build_params_dict(best_params)
    >>> rv = fitter.calculate_rv_planet_custom('b', times, params)
    """
    # Pull out this planet's entries, dropping the "_<letter>" suffix so the
    # keys match the bare names the Planet model is given.
    per_planet = {
        name: params[f"{name}_{planet_letter}"]
        for name in self.parameterisation.pars
    }
    model = ravest.model.Planet(planet_letter, self.parameterisation, per_planet)
    return model.radial_velocity(times)
[docs]
def calculate_rv_trend_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute the system trend RV from an explicit parameter dictionary.

    Handy for evaluating the trend at a hand-picked point in parameter
    space (e.g. the best lnprob sample, median parameters, or trial values).
    Only the ``"gd"`` and ``"gdd"`` entries are read; per-instrument gamma
    offsets are not part of the trend computed here.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters).
        Can be created using build_params_dict() or manually constructed.

    Returns
    -------
    np.ndarray
        Trend RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> trend_rv = fitter.calculate_rv_trend_custom(times, params)
    """
    # No gamma offset here - that is applied per instrument elsewhere.
    trend_coeffs = {key: params[key] for key in ("gd", "gdd")}
    model = ravest.model.Trend(params=trend_coeffs, t0=self.t0)
    return model.radial_velocity(times)
[docs]
def calculate_rv_total_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute the full RV model (trend + all planets) from explicit parameters.

    Handy for evaluating the model at a hand-picked point in parameter
    space (e.g. the best lnprob sample, median parameters, or trial values).

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters).
        Can be created using build_params_dict() or manually constructed.

    Returns
    -------
    np.ndarray
        Total RV values (trend + all planets) at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> total_rv = fitter.calculate_rv_total_custom(times, params)
    >>>
    >>> # Using median parameters
    >>> samples_df = fitter.get_samples_df(discard_start=1000)
    >>> median_values = samples_df.median().to_dict()
    >>> params = fitter.build_params_dict(median_values)
    >>> total_rv = fitter.calculate_rv_total_custom(times, params)
    """
    # Seed the running sum with the trend, then accumulate each planet's
    # contribution in place.
    combined_rv = self.calculate_rv_trend_custom(times, params)
    for letter in self.planet_letters:
        combined_rv += self.calculate_rv_planet_custom(letter, times, params)
    return combined_rv
[docs]
def plot_MAP_rv(self, map_result: scipy.optimize.OptimizeResult, title: str | None = "MAP RV", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_rv.png", dpi: int = 100) -> None:
"""Plot radial velocity data and model using MAP parameter estimates.
Parameters
----------
map_result : scipy.optimize.OptimizeResult
Result from find_map_estimate() containing the MAP parameters
title : str or None, optional
Plot title (default: "MAP RV"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "MAP_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get MAP parameter values from the optimization result
map_params = dict(zip(self.free_params_names, map_result.x))
# Combine with fixed parameters
all_params = self.fixed_params_values_dict | map_params
# Use helper function to create the plot
self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_MAP_phase(self, planet_letter: str, map_result: scipy.optimize.OptimizeResult, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_phase.png", dpi: int = 100) -> None:
"""Plot phase-folded radial velocity data and model using MAP parameter estimates.
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
map_result : scipy.optimize.OptimizeResult
Result from find_map_estimate() containing the MAP parameters
title : str or None, optional
Plot title (default: f"MAP Phase Plot - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "MAP_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get MAP parameter values from the optimization result
map_params = dict(zip(self.free_params_names, map_result.x))
# Combine with fixed parameters
all_params = self.fixed_params_values_dict | map_params
# Set default title if not provided
if title is None:
title = f"MAP Phase Plot - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_custom_rv(self, params: dict, title: str | None = "Custom RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_rv.png", dpi: int = 100) -> None:
"""Plot radial velocity data and model using custom parameter values.
Allows plotting with arbitrary parameter values for exploring parameter space
or comparing theoretical models.
Parameters
----------
params : dict
Dictionary of parameter values to use for plotting. Keys should match
parameter names, values should be floats. Must include all required
parameters for the current parameterisation.
title : str or None, optional
Plot title (default: "Custom RV Plot"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "custom_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
Examples
--------
>>> # Plot with custom values (must include all required parameters)
>>> fitter.plot_custom_rv({"P_b": 4.25, "K_b": 55.0, "e_b": 0.1,
... "w_b": 1.57, "Tc_b": 2456325.5,
... "g_HARPS": -10.2, "jit_HARPS": 2.0,
... "gd": 0.0, "gdd": 0.0})
"""
# Validate that all required parameters are present
expected_params = set(self.free_params_names + list(self.fixed_params_names))
provided_params = set(params.keys())
missing_params = expected_params - provided_params
if missing_params:
raise ValueError(f"Missing required parameters: {missing_params}")
# Use helper function to create the plot
self._plot_rv(params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_custom_phase(self, planet_letter: str, params: dict, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_phase.png", dpi: int = 100) -> None:
"""Plot phase-folded radial velocity data and model using custom parameter values.
Allows plotting phase-folded data with arbitrary parameter values for exploring
parameter space or comparing theoretical models.
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
params : dict
Dictionary of parameter values to use for plotting. Keys should match
parameter names, values should be floats. Must include all required
parameters for the current parameterisation.
title : str or None, optional
Plot title (default: f"Custom Phase Plot - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "custom_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
Examples
--------
>>> # Plot phase curve with custom values
>>> fitter.plot_custom_phase("b", {"P_b": 4.25, "K_b": 55.0, "e_b": 0.1,
... "w_b": 1.57, "Tc_b": 2456325.5,
... "g_HARPS": -10.2, "jit_HARPS": 2.0,
... "gd": 0.0, "gdd": 0.0})
"""
# Validate that all required parameters are present
expected_params = set(self.free_params_names + list(self.fixed_params_names))
provided_params = set(params.keys())
missing_params = expected_params - provided_params
if missing_params:
raise ValueError(f"Missing required parameters: {missing_params}")
# Set default title if not provided
if title is None:
title = f"Custom Phase Plot - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_best_sample_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Best Sample RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_rv.png", dpi: int = 100) -> None:
"""Plot radial velocity data and model using parameter values from the MCMC sample with highest log probability.
This is useful for comparing with plot_MAP_rv() to diagnose potential issues with
MAP convergence or MCMC mixing. The two plots should be very similar if both
MAP and MCMC are working correctly.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
title : str or None, optional
Title for the main RV plot (default: "Best Sample RV Plot"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "best_sample_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get parameter values from best sample
best_sample_params = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin)
# Combine with fixed parameters
all_params = self.fixed_params_values_dict | best_sample_params
# Use helper function to create the plot
self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_best_sample_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_phase.png", dpi: int = 100) -> None:
"""Plot phase-folded radial velocity data and model using parameter values from the MCMC sample with highest log probability.
This is useful for comparing with plot_MAP_phase() to diagnose potential issues with
MAP convergence or MCMC mixing. The two plots should be very similar if both
MAP and MCMC are working correctly.
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
title : str or None, optional
Title for the main phase plot (default: "Best Sample Phase Plot - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "best_sample_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get parameter values from best sample
best_sample_params = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin)
# Combine with fixed parameters
all_params = self.fixed_params_values_dict | best_sample_params
# Set default title if not provided
if title is None:
title = f"Best Sample Phase Plot - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals,
save=save, fname=fname, dpi=dpi)
[docs]
class LogPosterior:
    """Log posterior probability for MCMC sampling and MAP optimisation.

    Combines a ``LogLikelihood`` and a ``LogPrior`` into a single
    log-posterior (log likelihood + log prior) for Bayesian parameter
    estimation.
    """

    def __init__(
        self,
        planet_letters: list[str],
        parameterisation: Parameterisation,
        priors: dict[str, Callable[[float], float]],
        fixed_params: dict[str, float],
        free_params_names: list[str],
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        instrument: np.ndarray,
        unique_instruments: np.ndarray,
        t0: float,
    ) -> None:
        """Store the model configuration and data, and build the likelihood/prior.

        Parameters
        ----------
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.
        priors : dict[str, Callable[[float], float]]
            Dictionary mapping parameter names to their prior probability functions.
        fixed_params : dict[str, float]
            Dictionary of fixed parameter values.
        free_params_names : list[str]
            List of free parameter names to sample.
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        instrument : np.ndarray
            Instrument name for each observation.
        unique_instruments : np.ndarray
            Unique instrument names in the data.
        t0 : float
            Reference time for the trend [days].
        """
        # Model configuration
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.priors = priors
        self.fixed_params = fixed_params
        self.free_params_names = free_params_names

        # Observed data
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.t0 = t0

        # Build the two callables once; they are reused on every evaluation.
        self.log_likelihood = LogLikelihood(
            time=time,
            vel=vel,
            velerr=velerr,
            instrument=instrument,
            unique_instruments=unique_instruments,
            t0=t0,
            planet_letters=planet_letters,
            parameterisation=parameterisation,
        )
        self.log_prior = LogPrior(priors)

    def _convert_params_for_prior_evaluation(self, free_params_dict: dict[str, float]) -> Dict[str, float]:
        """Return the parameters keyed the way the priors expect them.

        Three situations are possible:

        1. fitting in a transformed parameterisation with priors defined in
           that same transformed parameterisation;
        2. fitting in the default parameterisation with priors defined in
           the default parameterisation;
        3. fitting in a transformed parameterisation with priors defined in
           the default parameterisation.

        Cases 1 and 2 are detected by the prior keys being exactly the free
        parameter names, and need no conversion. Otherwise (case 3) each
        planet's parameters are converted to the default parameterisation
        and only the entries that have a prior are kept.

        Parameters
        ----------
        free_params_dict : dict
            Free parameters in current parameterisation

        Returns
        -------
        dict
            Parameters with names/values converted for prior evaluation.
            In cases 1 and 2 this is the very same `free_params_dict` object.
        """
        prior_names = set(self.priors.keys())

        # Cases 1 & 2: priors are keyed exactly like the free parameters.
        if prior_names == set(self.free_params_names):
            return free_params_dict

        # Case 3: keep the free parameters that already have a matching
        # prior (e.g. the non-planetary ones) ...
        converted = {
            name: value
            for name, value in free_params_dict.items()
            if name in prior_names
        }

        # ... and convert every planet to the default parameterisation.
        # Fixed values are needed too, as a conversion may mix both kinds.
        everything = self.fixed_params | free_params_dict
        for letter in self.planet_letters:
            current = {
                par: everything[f"{par}_{letter}"]
                for par in self.parameterisation.pars
            }
            as_default = self.parameterisation.convert_pars_to_default_parameterisation(current)
            for par, value in as_default.items():
                full_name = f"{par}_{letter}"
                if full_name in prior_names:  # only parameters that have a prior
                    converted[full_name] = value
        return converted

    def log_probability(self, free_params_dict: Dict[str, float]) -> float:
        """Calculate log posterior probability for given free parameters.

        Cheap rejections are attempted first (negative jitter, failed
        parameter conversion, non-finite prior) so that the more expensive
        likelihood is only evaluated for viable parameter sets.

        Parameters
        ----------
        free_params_dict : Dict[str, float]
            Dictionary of free parameter values

        Returns
        -------
        float
            Log posterior probability (log likelihood + log prior), or
            ``-np.inf`` when the parameters are rejected early.
        """
        merged = self.fixed_params | free_params_dict

        # Jitter never enters the RV model itself, so nothing downstream
        # would reject a negative value - check it explicitly here.
        if any(merged[f"jit_{inst}"] < 0 for inst in self.unique_instruments):
            return -np.inf

        # Priors may be defined in the default parameterisation while the fit
        # runs in a transformed one, so convert before evaluating them.
        try:
            lp = self.log_prior(self._convert_params_for_prior_evaluation(free_params_dict))
        except ValueError:
            # Invalid parameter conversion (e.g., unphysical eccentricity)
            return -np.inf

        # Outside the prior support: skip the likelihood entirely.
        if not np.isfinite(lp):
            return -np.inf

        return self.log_likelihood(merged) + lp

    def _negative_log_probability_for_MAP(self, free_params_vals: list[float]) -> float:
        """Objective for MAP optimisation: the negated log posterior.

        The scipy optimiser supplies a flat sequence of values rather than a
        dict, so the values are paired with ``self.free_params_names`` and
        passed to ``log_probability``. The pairing is purely positional: the
        values are assumed to be in the same order as the names, and this is
        not checked.

        Parameters
        ----------
        free_params_vals : list
            float values of the free parameters

        Returns
        -------
        float
            ``-log_probability``, or ``1e30`` when that value is not finite.
        """
        named = dict(zip(self.free_params_names, free_params_vals))

        # Negate: the optimiser minimises, whereas we want the maximum.
        objective = -self.log_probability(named)

        # The optimiser cannot do arithmetic with infinities, so substitute a
        # very large finite number. (This leaves a non-zero chance of ending
        # on a solution that does not satisfy the priors.)
        return objective if np.isfinite(objective) else 1e30
[docs]
class LogLikelihood:
    """Log likelihood calculation for radial velocity data.

    Evaluates the Gaussian log likelihood of the observed velocities given
    a model made of planets, a system trend, per-instrument gamma offsets
    and per-instrument jitter added in quadrature to the uncertainties.
    """

    def __init__(
        self,
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        instrument: np.ndarray,
        unique_instruments: np.ndarray,
        t0: float,
        planet_letters: list[str],
        parameterisation: Parameterisation,
    ) -> None:
        """Store the data and precompute everything that never changes.

        Parameters
        ----------
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        instrument : np.ndarray
            Instrument name for each observation.
        unique_instruments : np.ndarray
            Unique instrument names in the data.
        t0 : float
            Reference time for the trend [days].
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.
        """
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.t0 = t0
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation

        # Map every observation to the integer position of its instrument:
        #   unique_instruments  = ["HARPS", "ESPRESSO"]
        #   instrument          = ["HARPS", "HARPS", "ESPRESSO", "HARPS", ...]
        #   _instrument_indices = [0, 0, 1, 0, ...]
        # Indexing a length-K per-instrument array with this length-N index
        # array yields the per-observation values in one vectorised step,
        # with no per-instrument boolean masks.
        position_of = {name: pos for pos, name in enumerate(self.unique_instruments)}
        self._instrument_indices = np.array([position_of[name] for name in self.instrument])

        # Parameter keys (e.g. "g_HARPS", "jit_ESPRESSO") are fixed for the
        # lifetime of the object, so build the strings once.
        self._gamma_keys = [f"g_{name}" for name in self.unique_instruments]
        self._jitter_keys = [f"jit_{name}" for name in self.unique_instruments]

        # Constants of the likelihood: log(2*pi) and the squared
        # measurement uncertainties (the observed data never changes).
        self._log_2pi = np.log(2 * np.pi)
        self._velerr_sq = self.velerr ** 2

    def __call__(self, params: Dict[str, float]) -> float:
        """Calculate log likelihood for given parameters.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values

        Returns
        -------
        float
            Log likelihood value, or ``-np.inf`` if building or evaluating
            a planet raises ``ValueError``.
        """
        model_rv = np.zeros(len(self.time))

        # 1) Planets. Each planet receives only its own parameters, with the
        #    "_<letter>" suffix stripped from the keys.
        for letter in self.planet_letters:
            planet_params = {
                par: params[f"{par}_{letter}"]
                for par in self.parameterisation.pars
            }
            try:
                planet_rv = ravest.model.Planet(
                    letter, self.parameterisation, planet_params
                ).radial_velocity(self.time)
            except ValueError:
                # Invalid planet parameters: fail fast with -inf.
                return -np.inf
            model_rv += planet_rv

        # 2) System trend (no gamma offset here).
        trend = ravest.model.Trend(
            params={"gd": params["gd"], "gdd": params["gdd"]},
            t0=self.t0,
        )
        model_rv += trend.radial_velocity(self.time)

        # 3) Per-instrument gamma offsets: a length-K array expanded to
        #    length N through the precomputed instrument indices.
        gammas = np.array([params[key] for key in self._gamma_keys])
        model_rv += gammas[self._instrument_indices]

        # 4) Per-instrument jitter, expanded the same way and added in
        #    quadrature to the measurement uncertainties.
        jitters = np.array([params[key] for key in self._jitter_keys])
        variance = self._velerr_sq + jitters[self._instrument_indices] ** 2

        # Gaussian log likelihood:
        #   -0.5 * sum( r^2 / var + log(2*pi) + log(var) )
        normalisation = self._log_2pi + np.log(variance)
        scaled_sq_resid = (model_rv - self.vel) ** 2 / variance
        return -0.5 * np.sum(scaled_sq_resid + normalisation)
[docs]
class LogPrior:
    """Log prior probability calculation.

    Sums the log prior of every supplied parameter, using one prior
    callable per parameter name.
    """

    def __init__(self, priors: dict[str, Callable[[float], float]]) -> None:
        """Store the mapping from parameter name to prior callable."""
        self.priors = priors

    def __call__(self, params: Dict[str, float]) -> float:
        """Calculate log prior probability for given parameters.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of parameter values

        Returns
        -------
        float
            Log prior probability: the sum over `params` of each parameter's
            prior evaluated at its value (0 for an empty dictionary).

        Raises
        ------
        KeyError
            If a name in `params` has no entry in ``self.priors``.
        """
        # For each parameter, look up its prior in `self.priors` and call it
        # with the parameter's value; the results are summed.
        return sum(self.priors[name](value) for name, value in params.items())
[docs]
class GPFitter:
"""Gaussian Process fitter for exoplanet radial velocity data.
Similar interface to Fitter class, but uses Gaussian Processes to model
correlated noise in the data. The planetary RV model serves as the GP mean function.
Supports MCMC sampling, MAP estimation, and various parameterisations.
Handles multiple planets, trends, jitter parameters, and GP hyperparameters.
"""
def __init__(self, planet_letters: list[str], parameterisation: Parameterisation, gp_kernel: GPKernel) -> None:
    """Initialize the GPFitter object.

    Parameters
    ----------
    planet_letters : list[str]
        List of single-character planet identifiers (e.g., ['b', 'c', 'd']).
        Used to distinguish parameters for different planets in the system.
    parameterisation : Parameterisation
        The orbital parameterisation to use for fitting. Defines which orbital
        elements are used as free/fixed parameters (e.g., 'Default', 'EccentricityWind').
    gp_kernel : GPKernel
        The Gaussian Process kernel to use for modelling correlated noise in the data.
    """
    self.planet_letters = planet_letters
    self.parameterisation = parameterisation
    self.gp_kernel = gp_kernel

    # Call the numba Kepler routine once on throwaway inputs so that its JIT
    # compilation happens here rather than during MCMC. The result is unused.
    warmup_anomalies = np.linspace(0, 2 * np.pi, 10)
    _njit_kepler_rv(warmup_anomalies, 0.3, 10.0, 0.5)

    # Empty storage for (hyper)parameters and their priors; these are filled
    # in afterwards through the corresponding property setters.
    self._params: Dict[str, Parameter] = {}
    self._priors: Dict[str, Callable[[float], float]] = {}
    self._hyperparams: Dict[str, Parameter] = {}
    self._hyperpriors: Dict[str, Callable[[float], float]] = {}
[docs]
def add_data(
    self,
    time: np.ndarray,
    vel: np.ndarray,
    velerr: np.ndarray,
    instrument: np.ndarray,
    t0: float,
) -> None:
    """Add the data to the GPFitter object.

    The numeric inputs are stored as contiguous arrays, the instrument
    labels as an array, and the sorted unique instrument labels are derived
    from them.

    Parameters
    ----------
    time : array-like
        Time of each observation [days]
    vel : array-like
        Radial velocity at each time [m/s]
    velerr : array-like
        Uncertainty on the radial velocity at each time [m/s]
    instrument : array-like
        Instrument name for each observation (e.g., "HARPS", "HIRES")
    t0 : float
        Reference time for the trend [days].
        Recommended to set this as mean or median of input `time` array.

    Raises
    ------
    ValueError
        If the four per-observation inputs do not all have the same length.
    """
    # Every observation needs a time, a velocity, an uncertainty and an
    # instrument label, so all four inputs must be equally long.
    n_obs = len(time)
    if len(vel) != n_obs or len(velerr) != n_obs or len(instrument) != n_obs:
        raise ValueError(
            "Time, velocity, uncertainty, and instrument arrays must be the same length."
        )

    self.time = np.ascontiguousarray(time)
    self.vel = np.ascontiguousarray(vel)
    self.velerr = np.ascontiguousarray(velerr)
    self.instrument = np.asarray(instrument)
    self.unique_instruments = np.unique(self.instrument)
    self.t0 = t0
@property
def params(self) -> Dict[str, Parameter]:
    """Parameters dictionary. Set via: gpfitter.params = param_dict."""
    return self._params

@params.setter
def params(self, new_params: Dict[str, Parameter]) -> None:
    """Set parameters with a dict, checking all required params are present.

    You can update all or some of the parameters at once, example:

    >>> gpfitter.params = {"g": Parameter(1.0, "m/s"), "gd": Parameter(0.1, "m/s/d")}  # only update trend parameters
    >>> gpfitter.params = {"P_c": Parameter(5.0, "d"), "K_c": Parameter(3.5, "m/s")}  # only update some of planet C parameters

    The new entries are merged into the existing parameters and the merged
    set is validated before anything is stored, so a failed validation
    leaves the current parameters unchanged. ``ndim`` is then refreshed to
    the total number of free parameters plus free hyperparameters.

    Parameters
    ----------
    new_params : dict
        Dictionary of new parameter values to set.
        The keys of this dictionary should match the parameter names expected
        by the GPFitter object: all required parameters for the
        chosen parameterisation, with planet letters (not required for
        Trend or jitter parameters.)

    Raises
    ------
    ValueError
        If any of the required parameters are missing or invalid.

    Warns
    -----
    UserWarning
        If, after the update, there is no free parameter or hyperparameter.
    """
    # Validate a scratch copy of the merged dictionary first, so that the
    # stored parameters are only touched once validation has passed.
    merged_params = dict(self._params)
    merged_params.update(new_params)
    self._validate_complete_params(merged_params)
    self._params.update(new_params)

    # Count hyperparameters as well. Counting only the free parameters here
    # would silently drop the hyperparameter dimensions whenever params are
    # (re)assigned after hyperparams, leaving ndim inconsistent with the
    # value computed by the hyperparams setter. Before any hyperparameters
    # are set the extra term is presumably zero (empty `_hyperparams`).
    self.ndim = len(self.free_params_values) + len(self.free_hyperparams_values)
    if self.ndim == 0:
        warnings.warn(
            "All parameters are fixed. MCMC methods require at least one "
            "free parameter or hyperparameter (fixed=False).",
            UserWarning,
            stacklevel=2
        )
@property
def hyperparams(self) -> Dict[str, Parameter]:
    """Return the dictionary of hyperparameters.

    Assign a dict to update it, e.g. ``gpfitter.hyperparams = hyperparam_dict``.
    """
    return self._hyperparams


@hyperparams.setter
def hyperparams(self, new_hyperparams: Dict[str, Parameter]) -> None:
    """Merge ``new_hyperparams`` into the stored hyperparameters after validation.

    The new entries are laid over a copy of the existing ones and the combined
    dict is handed to ``self.gp_kernel.validate_hyperparams`` (which
    hyperparameters are required depends on the kernel). Only if that call
    returns are the new entries stored. ``ndim`` is then recomputed as the
    number of free parameters plus free hyperparameters, and a ``UserWarning``
    is issued when that total is zero.
    """
    candidate = dict(self._hyperparams)
    candidate.update(new_hyperparams)

    # Raises before anything is stored if the combined set is not acceptable
    self.gp_kernel.validate_hyperparams(candidate)
    self._hyperparams.update(new_hyperparams)

    n_free_total = len(self.free_params_values) + len(self.free_hyperparams_values)
    self.ndim = n_free_total
    if n_free_total != 0:
        return

    warnings.warn(
        "All parameters and hyperparameters are fixed. MCMC methods require "
        "at least one free parameter or hyperparameter (fixed=False).",
        UserWarning,
        stacklevel=2
    )
@property
def priors(self) -> dict:
    """Return the dictionary of prior functions.

    Assign a dict to update it, e.g. ``gpfitter.priors = prior_dict``.
    """
    return self._priors


@priors.setter
def priors(self, new_priors: dict[str, Callable[[float], float]]) -> None:
    """Set prior functions from a dict, validating the resulting set of priors.

    Every free parameter needs a prior. All priors can be supplied at once, or
    a subset can be supplied to update individual priors; the work is
    delegated to ``_set_priors_with_validation``.

    Parameters
    ----------
    new_priors : dict
        Maps parameter names to callable prior functions. Names should match
        free parameters.

    Examples
    --------
    >>> from ravest.prior import Uniform
    >>> gpfitter.priors = {"K_b": Uniform(0, 100), "P_b": Uniform(1, 30)}

    Raises
    ------
    ValueError
        If any required priors are missing, unexpected priors are provided,
        or initial parameter values are outside prior bounds.
    """
    self._set_priors_with_validation(new_priors)
@property
def hyperpriors(self) -> dict:
    """Return the dictionary of hyperprior functions.

    Assign a dict to update it, e.g. ``gpfitter.hyperpriors = hyperprior_dict``.
    """
    return self._hyperpriors


@hyperpriors.setter
def hyperpriors(self, new_hyperpriors: dict[str, Callable[[float], float]]) -> None:
    """Set hyperprior functions from a dict.

    All checking and storing is delegated to
    ``_set_hyperpriors_with_validation``.
    """
    self._set_hyperpriors_with_validation(new_hyperpriors)
[docs]
def _validate_complete_params(self, params: Dict[str, Parameter]) -> None:
"""Validate that params dict has required parameters, astrophysically valid values."""
# Build complete set of expected parameters
expected_params = set()
# Add planetary parameters
for planet_letter in self.planet_letters:
for par_name in self.parameterisation.pars:
expected_params.add(f"{par_name}_{planet_letter}")
# Add trend parameters (no gamma - that's per-instrument)
expected_params.update(["gd", "gdd"])
# Add per-instrument gamma offset and jitter parameters
for inst in self.unique_instruments:
expected_params.add(f"g_{inst}")
expected_params.add(f"jit_{inst}")
# Validate same as Fitter
provided_params = set(params.keys())
# Check for unexpected parameters
unexpected_params = provided_params - expected_params
if unexpected_params:
# Give a specific hint if user is passing legacy single-instrument parameters
legacy_params = unexpected_params & {"g", "jit"}
if legacy_params:
raise ValueError(
f"Unexpected parameters: {unexpected_params}. "
f"Single-instrument 'g' and 'jit' parameters are no longer supported. "
f"Use per-instrument names instead, e.g. "
f"{[f'g_{inst}' for inst in self.unique_instruments]} and "
f"{[f'jit_{inst}' for inst in self.unique_instruments]}, "
f"matching the instrument names passed to add_data()."
)
raise ValueError(
f"Unexpected parameters: {unexpected_params}. "
f"Expected {len(expected_params)} parameters, got {len(provided_params)}"
)
# Check for missing parameters
missing_params = expected_params - provided_params
if missing_params:
raise ValueError(
f"Missing required parameters: {missing_params}. "
f"Expected {len(expected_params)} parameters, got {len(provided_params)}"
)
# Validate astrophysical validity of all parameters
params_values = {name: param.value for name, param in params.items()}
self._validate_astrophysical_validity(params_values)
# Validate parameter coupling constraints
# i.e. if two parameters both need to be fixed or free together
self._validate_parameter_coupling(params)
[docs]
def _validate_astrophysical_validity(self, params_values: Dict[str, float]) -> None:
"""Validate that all parameter values are astrophysically valid."""
# First, check that ALL parameters are finite (not NaN or infinite)
invalid_params = { name: value for name, value in params_values.items() if not np.isfinite(value) }
if invalid_params:
raise ValueError( "Invalid parameters detected: " + ", ".join(f"{k}={v}" for k, v in invalid_params.items()) )
# Validate planetary parameters for each planet
for planet_letter in self.planet_letters:
planet_params = {}
for par_name in self.parameterisation.pars:
key = f"{par_name}_{planet_letter}"
planet_params[par_name] = params_values[key]
# Validate this planet's parameters in current parameterisation
self.parameterisation.validate_planetary_params(planet_params)
# Validate trend parameters are finite real numbers (already checked above, but kept for clarity)
for trend_param in ["gd", "gdd"]:
trend_value = params_values[trend_param]
if not np.isfinite(trend_value):
raise ValueError(f"Invalid trend parameter {trend_param}: {trend_value} is not a finite real number")
# Validate per-instrument parameters
for inst in self.unique_instruments:
# Gamma offset must be finite
g_key = f"g_{inst}"
if not np.isfinite(params_values[g_key]):
raise ValueError(f"Invalid gamma offset {g_key}: {params_values[g_key]} is not finite")
# Jitter must be >= 0
jit_key = f"jit_{inst}"
if params_values[jit_key] < 0:
raise ValueError(f"Invalid jitter {jit_key}: {params_values[jit_key]} < 0")
[docs]
def _validate_parameter_coupling(self, params: Dict[str, Parameter]) -> None:
"""Validate parameter coupling constraints (e.g., secosw/sesinw must both be free or both fixed)."""
for planet_letter in self.planet_letters:
# Check secosw/sesinw coupling
secosw_key = f"secosw_{planet_letter}"
sesinw_key = f"sesinw_{planet_letter}"
if secosw_key in params and sesinw_key in params:
secosw_fixed = params[secosw_key].fixed
sesinw_fixed = params[sesinw_key].fixed
if secosw_fixed != sesinw_fixed:
raise ValueError(f"Parameters {secosw_key} and {sesinw_key} must both be fixed or both be free")
# Check ecosw/esinw coupling
ecosw_key = f"ecosw_{planet_letter}"
esinw_key = f"esinw_{planet_letter}"
if ecosw_key in params and esinw_key in params:
ecosw_fixed = params[ecosw_key].fixed
esinw_fixed = params[esinw_key].fixed
if ecosw_fixed != esinw_fixed:
raise ValueError(f"Parameters {ecosw_key} and {esinw_key} must both be fixed or both be free")
[docs]
def _set_priors_with_validation(self, new_priors: dict[str, Callable[[float], float]]) -> None:
"""Set priors with validation. Supports partial updates. Can be current or default parameterisation."""
# Create merged priors dict (in case user is only updating some priors, not all)
merged_priors_dict = dict(self._priors) # get existing priors
merged_priors_dict.update(new_priors) # overwrite with newer functions, if supplied
provided_prior_param_names = set(merged_priors_dict.keys())
# There are two possibilities for priors:
# 1. The prior has been given for the parameter, in the current parameterisation
# (this can also include if the user is fitting in the default parameterisation)
# 2. The prior has been given for the Default parameterisation's equivalent parameter instead
# (e.g. e & w instead of secosw & sesinw, or tp instead of tc)
# If not, then prior isn't given for either the Current or Default parameterisation, raise an Exception
validated_priors = {}
missing_priors = []
conflicts = []
# in the current parameterisation, which (free) parameters do we expect priors for?
current_parameterisation_free_param_names = set(self.free_params_names)
for free_param_name in current_parameterisation_free_param_names:
if free_param_name in provided_prior_param_names:
# Prior was provided for the param in the current parameterisation
validated_priors[free_param_name] = merged_priors_dict[free_param_name]
# Check if user ALSO provided equivalent default priors (conflict!)
default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)
if default_parameterisation_equivalent_free_param_names:
for equiv_param in default_parameterisation_equivalent_free_param_names:
if equiv_param in provided_prior_param_names:
conflicts.append((free_param_name, equiv_param))
else:
# We haven't been provided the prior for the free parameter in the current parameterisation
# So let's check if we were given the prior for the equivalent parameter in the default parameterisation instead
default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)
# remember that one parameter in current parameterisation (e.g. secosw) might map to more than one equivalent in default parameterisation (e.g. both e & w)
if default_parameterisation_equivalent_free_param_names and all(eq in provided_prior_param_names for eq in default_parameterisation_equivalent_free_param_names):
# Found all required default equivalents
for equiv in default_parameterisation_equivalent_free_param_names:
validated_priors[equiv] = merged_priors_dict[equiv]
else:
# Missing prior for a free parameter in both the current parameterisation, and its equivalent in the default parameterisation
if default_parameterisation_equivalent_free_param_names:
missing_priors.append(f"{free_param_name} (or equivalent {default_parameterisation_equivalent_free_param_names})")
else:
missing_priors.append(free_param_name)
# Check for conflicts after processing all parameters
if conflicts:
conflict_strs = [f"{current} vs {default}" for current, default in conflicts]
raise ValueError(f"Conflicting priors provided for both current and default parameterisations: {', '.join(conflict_strs)}. Please provide priors for either the current parameterisation OR the equivalent default parameterisation, but not both.")
if missing_priors:
raise ValueError(f"Missing priors for parameters: {missing_priors}")
# Check for unexpected priors - only allow priors that were validated above
expected_prior_param_names = set(validated_priors.keys())
unexpected_prior_param_names = provided_prior_param_names - expected_prior_param_names
if unexpected_prior_param_names:
raise ValueError(
f"Unexpected priors supplied for parameters: {unexpected_prior_param_names}. "
f"Priors expected only for parameters: {expected_prior_param_names}"
)
# Check parameter values work with priors
self._check_params_values_against_priors(validated_priors, current_parameterisation_free_param_names)
# Update the priors with the new values
self._priors.update(new_priors)
[docs]
def _get_default_parameterisation_equivalent_free_param_name(self, free_param: str) -> Optional[list[str]]:
"""Get the names of the default parameterisation equivalent parameter(s), for a single free parameter from the current parameterisation.
Note this can be more than one: e.g. if you have secosw, this affects both e & w in the default parameterisation
Whereas Tc just maps to Tp alone.
Returns
-------
list[str] | None
- list[str]: equivalent parameter(s) in the default parameterisation
- None: no mapping needed / no alternative priors to look for
Raises
------
ValueError
If `free_param` is not a recognised planet, instrument, or trend parameter.
"""
# No underscore (expected to be a system trend parameter)
if '_' not in free_param:
if free_param in ['gd', 'gdd']:
# These are the same in all parameterisations
return None
else:
raise ValueError(f"Unknown free parameter: {free_param}")
# Contains underscore: Planetary or instrument parameters (with underscore before either planet letter or instrument name)
# e.g. P_b, or Tc_c, or jit_HARPS
else:
base_param, suffix = free_param.split('_', 1) # split only on first underscore (some instrument names may have underscores too)
# Planetary parameters: suffix is a planet letter
if suffix in self.planet_letters:
planet_letter = suffix
if base_param in ['secosw', 'sesinw']:
# Both secosw and sesinw map to e,w equivalents
partner_param = 'sesinw' if base_param == 'secosw' else 'secosw'
partner_key = f"{partner_param}_{planet_letter}"
if partner_key in self.free_params_names:
return [f"e_{planet_letter}", f"w_{planet_letter}"]
elif base_param in ['ecosw', 'esinw']:
# Both ecosw and esinw map to e,w equivalents
partner_param = 'esinw' if base_param == 'ecosw' else 'ecosw'
partner_key = f"{partner_param}_{planet_letter}"
if partner_key in self.free_params_names:
return [f"e_{planet_letter}", f"w_{planet_letter}"]
elif base_param == 'Tc':
# Tc can use Tp equivalent
return [f"Tp_{planet_letter}"]
elif base_param in ['P', 'K', 'e', 'w', 'Tp']:
# These are default parameterisation parameters anyway
return None
else:
# Suffix is a valid planet letter, but base parameter is unrecognised, so raise an error
raise ValueError(f"Free parameter {free_param} has known planet letter {planet_letter} but unrecognised base parameter {base_param}.")
# Instrument parameters: suffix is an instrument name
elif suffix in self.unique_instruments:
# The only instrument parameters are g and jit
if base_param in ['g', 'jit']:
# Per-instrument parameter (e.g., g_HARPS, jit_HIRES)
# These are the same in all parameterisations
return None
else:
raise ValueError(f"Free parameter {free_param} has known instrument name {suffix} but unrecognised base parameter {base_param} (expected 'g' or 'jit' only)")
# Unknown: Suffix is present, but not a planet letter or instrument, so raise an error
else:
raise ValueError(f"Free parameter {free_param} has unrecognised suffix {suffix}, expected one of planet letters {self.planet_letters} or instrument names {self.unique_instruments}.")
[docs]
def _check_params_values_against_priors(self, validated_priors: dict[str, Callable[[float], float]], current_free_param_names: list[str]) -> None:
"""Check parameter values against priors (including if Prior is for the Default parameterisation equivalent parameter)."""
for prior_param_name, prior_function in validated_priors.items():
if prior_param_name in current_free_param_names:
# This prior is in current parameterisation - check directly
param_value = self.params[prior_param_name].value
log_prior_probability = prior_function(param_value)
if not np.isfinite(log_prior_probability):
raise ValueError(f"Initial value {param_value} of parameter {prior_param_name} is invalid for prior {prior_function}.")
else:
# This prior is in default parameterisation - need to convert parameter value
# Get the current parameter value and convert to default
default_param_value = self._convert_single_param_to_default(prior_param_name)
log_prior_probability = prior_function(default_param_value)
if not np.isfinite(log_prior_probability):
raise ValueError(f"Initial value {default_param_value} of parameter {prior_param_name} (in default parameterisation) is invalid for prior {prior_function}.")
[docs]
def _convert_single_param_to_default(self, default_param_name: str) -> float:
"""Convert a single parameter from current to default parameterisation."""
# Extract planet letter if this is a planetary parameter
if '_' in default_param_name:
base_param, planet_letter = default_param_name.rsplit('_', 1)
if planet_letter in self.planet_letters:
# Get all current parameters for this planet (we need all five parameters to do a conversion)
planet_params_dict = {}
for par_name in self.parameterisation.pars:
param_key = f"{par_name}_{planet_letter}"
planet_params_dict[par_name] = self.params[param_key].value
# Convert all the planetary parameters to the default parameterisation
default_planet_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params_dict)
# Return just the requested parameter in the default parameterisation
return default_planet_params[base_param]
# For non-planetary parameters (g, gd, gdd, jit), they're the same in all parameterisations
if default_param_name in self.params:
return self.params[default_param_name].value
raise ValueError(f"Cannot convert parameter {default_param_name} to default parameterisation")
[docs]
def _set_hyperpriors_with_validation(self, new_hyperpriors: dict[str, Callable[[float], float]]) -> None:
"""Set hyperpriors with validation."""
# Create merged hyperpriors dict
merged_hyperpriors_dict = dict(self._hyperpriors)
merged_hyperpriors_dict.update(new_hyperpriors)
provided_hyperprior_param_names = set(merged_hyperpriors_dict.keys())
# Check that hyperpriors are provided for all free hyperparameters
free_hyperparam_names = set(self.free_hyperparams_names)
missing_hyperpriors = free_hyperparam_names - provided_hyperprior_param_names
if missing_hyperpriors:
raise ValueError(f"Missing hyperpriors for hyperparameters: {missing_hyperpriors}")
# Check for unexpected hyperpriors
unexpected_hyperprior_param_names = provided_hyperprior_param_names - free_hyperparam_names
if unexpected_hyperprior_param_names:
raise ValueError(
f"Unexpected hyperpriors supplied for hyperparameters: {unexpected_hyperprior_param_names}. "
f"Hyperpriors expected only for hyperparameters: {free_hyperparam_names}"
)
# Check hyperparameter values work with hyperpriors
self._check_hyperparams_values_against_hyperpriors(merged_hyperpriors_dict, free_hyperparam_names)
# Update the hyperpriors with the new values
self._hyperpriors.update(new_hyperpriors)
[docs]
def _check_hyperparams_values_against_hyperpriors(self, validated_hyperpriors: dict[str, Callable[[float], float]], current_free_hyperparam_names: list[str]) -> None:
"""Check hyperparameter values against hyperpriors."""
for hyperprior_param_name, hyperprior_function in validated_hyperpriors.items():
hyperparam_value = self.hyperparams[hyperprior_param_name].value
log_hyperprior_probability = hyperprior_function(hyperparam_value)
if not np.isfinite(log_hyperprior_probability):
raise ValueError(f"Initial value {hyperparam_value} of hyperparameter {hyperprior_param_name} is invalid for hyperprior {hyperprior_function}.")
@property
def free_params_dict(self) -> Dict[str, Parameter]:
    """Return the free parameters, mapping names to Parameter objects.

    A parameter counts as free only when its ``fixed`` attribute is the
    boolean ``False``.
    """
    return {
        name: parameter
        for name, parameter in self.params.items()
        if parameter.fixed is False
    }
@property
def free_params_values(self) -> list[float]:
    """Return the values of the free parameters, in ``free_params_dict`` order."""
    free = self.free_params_dict
    return [free[name].value for name in free]
@property
def free_params_names(self) -> list[str]:
    """Return the names of the free parameters, in ``free_params_dict`` order."""
    return [*self.free_params_dict.keys()]
@property
def fixed_params_dict(self) -> Dict[str, Parameter]:
    """Return the fixed parameters, mapping names to Parameter objects.

    A parameter counts as fixed only when its ``fixed`` attribute is the
    boolean ``True``.
    """
    return {
        name: parameter
        for name, parameter in self.params.items()
        if parameter.fixed is True
    }
@property
def fixed_params_values(self) -> list[float]:
    """Return the values of the fixed parameters, in ``fixed_params_dict`` order."""
    fixed = self.fixed_params_dict
    return [fixed[name].value for name in fixed]
@property
def fixed_params_names(self) -> list[str]:
    """Return the names of the fixed parameters, in ``fixed_params_dict`` order."""
    return [*self.fixed_params_dict.keys()]
@property
def fixed_params_values_dict(self) -> Dict[str, float]:
    """Return the fixed parameters as a dict mapping names to bare values."""
    return {
        name: value
        for name, value in zip(self.fixed_params_names, self.fixed_params_values)
    }
@property
def free_hyperparams_dict(self) -> Dict[str, Parameter]:
    """Return the free hyperparameters, mapping names to Parameter objects.

    A hyperparameter counts as free only when its ``fixed`` attribute is the
    boolean ``False``.
    """
    return {
        name: hyperparameter
        for name, hyperparameter in self.hyperparams.items()
        if hyperparameter.fixed is False
    }
@property
def free_hyperparams_values(self) -> list[float]:
    """Return the values of the free hyperparameters, in ``free_hyperparams_dict`` order."""
    free = self.free_hyperparams_dict
    return [free[name].value for name in free]
@property
def free_hyperparams_names(self) -> list[str]:
    """Return the names of the free hyperparameters, in ``free_hyperparams_dict`` order."""
    return [*self.free_hyperparams_dict.keys()]
@property
def fixed_hyperparams_dict(self) -> Dict[str, Parameter]:
    """Return the fixed hyperparameters, mapping names to Parameter objects.

    A hyperparameter counts as fixed only when its ``fixed`` attribute is the
    boolean ``True``.
    """
    return {
        name: hyperparameter
        for name, hyperparameter in self.hyperparams.items()
        if hyperparameter.fixed is True
    }
@property
def fixed_hyperparams_values_dict(self) -> Dict[str, float]:
    """Return the fixed hyperparameters as a dict mapping names to bare values."""
    return {
        name: value
        for name, value in zip(self.fixed_hyperparams_names, self.fixed_hyperparams_values)
    }
@property
def fixed_hyperparams_names(self) -> list[str]:
    """Return the names of the fixed hyperparameters, in ``fixed_hyperparams_dict`` order."""
    return [*self.fixed_hyperparams_dict.keys()]
@property
def fixed_hyperparams_values(self) -> list[float]:
    """Return the values of the fixed hyperparameters, in ``fixed_hyperparams_dict`` order."""
    fixed = self.fixed_hyperparams_dict
    return [fixed[name].value for name in fixed]
[docs]
def find_map_estimate(self, method: str = "Powell") -> scipy.optimize.OptimizeResult:
    """Find the Maximum A Posteriori (MAP) estimate of parameters and hyperparameters.

    Minimises the negative log-posterior over the free parameters followed by
    the free hyperparameters, starting from their current values. The MAP
    values are printed by name, and the raw optimiser result is returned.

    Parameters
    ----------
    method : str, optional
        Optimisation method passed to ``scipy.optimize.minimize``
        (default: "Powell")

    Returns
    -------
    scipy.optimize.OptimizeResult
        The optimisation result. ``x`` holds the free parameter values followed
        by the free hyperparameter values, in the order of
        ``free_params_names + free_hyperparams_names``.

    Raises
    ------
    ValueError
        If there are no free parameters or hyperparameters to optimise.

    Warns
    -----
    UserWarning
        If the optimiser reports that it did not succeed.
    """
    # Free params first, then free hyperparams: the same order in which the
    # names are handed to the log-posterior below
    initial_guess = self.free_params_values + self.free_hyperparams_values
    # Fail fast, before building the log-posterior, when there is nothing to optimise
    if len(initial_guess) == 0:
        raise ValueError(
            "Cannot run MAP optimisation: no free parameters or hyperparameters to optimise. "
            "At least one parameter or hyperparameter must be set as free (fixed=False) before calling find_map_estimate()."
        )

    gp_lp = GPLogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.gp_kernel,
        self.priors,
        self.hyperpriors,
        self.fixed_params_values_dict,
        self.fixed_hyperparams_values_dict,
        self.free_params_names,
        self.free_hyperparams_names,
        self.time,
        self.vel,
        self.velerr,
        self.t0,
        self.instrument,
        self.unique_instruments,
    )

    map_results = minimize(gp_lp._negative_log_probability_for_MAP, initial_guess, method=method)

    # Truthiness rather than `is False`, so a non-`bool` falsy flag
    # (e.g. numpy.bool_) is caught as well
    if not map_results.success:
        print(map_results)
        warnings.warn(
            "MAP did not succeed. Check the initial values of the parameters and hyperparameters, and the prior/hyperprior functions.",
            UserWarning,
            stacklevel=2,
        )

    # Split the solution vector back into parameters and hyperparameters
    n_params = len(self.free_params_names)
    param_values = map_results.x[:n_params]
    hyperparam_values = map_results.x[n_params:]

    # Print results as dictionaries (to show params/hyperparams names too)
    map_results_dict = dict(zip(self.free_params_names, param_values))
    map_hyperresults_dict = dict(zip(self.free_hyperparams_names, hyperparam_values))
    print("MAP parameter results:", map_results_dict)
    print("MAP hyperparameter results:", map_hyperresults_dict)

    # Return the scipy OptimizeResult object so that user can inspect fully if needed
    return map_results
[docs]
def generate_initial_walker_positions_random(self, nwalkers: int, verbose: bool = False, max_attempts: int = 1000) -> np.ndarray:
    """Generate random, valid starting positions for the MCMC walkers.

    Each free parameter and hyperparameter is drawn from a broad proposal
    chosen from the type of its prior (see Notes). A candidate walker is kept
    only if it passes the astrophysical validity check and the kernel's
    hyperparameter check, and has a finite log-prior and log-hyperprior;
    otherwise it is redrawn, up to ``max_attempts`` times.

    Parameters
    ----------
    nwalkers : int
        Number of MCMC walkers to generate positions for
    verbose : bool, default False
        If True, print walker positions during generation
    max_attempts : int, default 1000
        Maximum attempts to generate a valid walker position

    Returns
    -------
    np.ndarray
        Array of shape (nwalkers, ndim) where ndim is the number of free parameters
        + hyperparameters. Each row represents the starting position for one walker
        in the order of free_params_names + free_hyperparams_names.

    Raises
    ------
    ValueError
        If there is nothing free to sample, if a free hyperparameter has no
        hyperprior, if a prior type is not supported for walker generation,
        or if no valid position is found after ``max_attempts``.

    Notes
    -----
    Proposal used for each prior type:

    - Normal: normal centred on the prior mean (width 2*std for parameters,
      std for hyperparameters)
    - HalfNormal: absolute value of a zero-centred normal (same widths)
    - Uniform, TruncatedNormal: uniform between the prior's lower and upper bounds
    - Beta: uniform on [0, 1]
    - EccentricityUniform: uniform between 0 and the prior's upper bound

    A free parameter with no prior under its own name (the prior was given in
    the default parameterisation instead) is drawn from a normal centred on
    its current value, with standard deviation 10% of the absolute value
    plus 0.01.

    Examples
    --------
    >>> # Use 10 walkers per sampled dimension
    >>> nwalkers = 10 * len(gpfitter.free_params_names + gpfitter.free_hyperparams_names)
    >>> initial_positions = gpfitter.generate_initial_walker_positions_random(nwalkers)
    >>> gpfitter.run_mcmc(initial_positions, nwalkers, max_steps=2000)
    """
    if len(self.free_params_values) + len(self.free_hyperparams_values) == 0:
        raise ValueError(
            "Cannot generate walker positions: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    def _draw_from_prior(prior, std_scale, label):
        """Draw one starting value from the broad proposal matching the prior's type."""
        if isinstance(prior, ravest.prior.Normal):
            return np.random.normal(loc=prior.mean, scale=std_scale * prior.std)
        if isinstance(prior, ravest.prior.HalfNormal):
            return np.abs(np.random.normal(loc=0, scale=std_scale * prior.std))
        if isinstance(prior, (ravest.prior.Uniform, ravest.prior.TruncatedNormal)):
            return np.random.uniform(low=prior.lower, high=prior.upper)
        if isinstance(prior, ravest.prior.Beta):
            # Sample across [0, 1], the same range the parameter branch has
            # always used for Beta priors. NOTE(review): assumes Beta's a/b
            # are shape parameters, not bounds - confirm against ravest.prior.
            return np.random.uniform(low=0, high=1)
        if isinstance(prior, ravest.prior.EccentricityUniform):
            return np.random.uniform(low=0, high=prior.upper)
        raise ValueError(f"Unsupported {label} type for walker generation: {type(prior)}")

    # These are recomputed on every property access, so read them once
    free_params_names = self.free_params_names
    free_hyperparams_names = self.free_hyperparams_names
    fixed_params_values = self.fixed_params_values_dict
    fixed_hyperparams_values = self.fixed_hyperparams_values_dict

    # Without a hyperprior there is nothing to draw from, and silently skipping
    # the hyperparameter would misalign the position vector
    missing_hyperpriors = [name for name in free_hyperparams_names if name not in self.hyperpriors]
    if missing_hyperpriors:
        raise ValueError(
            f"Cannot generate walker positions: no hyperprior set for free hyperparameters {missing_hyperpriors}."
        )

    if verbose:
        print("Free parameters:", free_params_names)
        print("Free hyperparameters:", free_hyperparams_names)

    # The log-posterior object does not depend on the candidate position, so
    # build it once instead of once per attempt
    lp = GPLogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.gp_kernel,
        self.priors,
        self.hyperpriors,
        fixed_params_values,
        fixed_hyperparams_values,
        free_params_names,
        free_hyperparams_names,
        self.time,
        self.vel,
        self.velerr,
        self.t0,
        self.instrument,
        self.unique_instruments,
    )

    param_init = []
    hyperparam_init = []
    for walker_idx in range(nwalkers):
        for attempt in range(1, max_attempts + 1):
            # Draw a candidate position for the free parameters
            param_walker_position = []
            for param_name in free_params_names:
                if param_name in self.priors:
                    param_walker_position.append(_draw_from_prior(self.priors[param_name], 2, "prior"))
                else:
                    # No direct prior (fitting in a transformed parameterisation
                    # with priors given in the default one): perturb the
                    # current value by ~10%, plus a small floor for
                    # near-zero values
                    centre_val = self.params[param_name].value
                    perturbation = np.random.normal(0, abs(centre_val) * 0.1 + 0.01)
                    param_walker_position.append(centre_val + perturbation)

            # Draw a candidate position for the free hyperparameters
            hyperparam_walker_position = [
                _draw_from_prior(self.hyperpriors[hyperparam_name], 1, "hyperprior")
                for hyperparam_name in free_hyperparams_names
            ]

            # Full dicts (fixed + free) are needed for the validity checks
            free_params_dict = dict(zip(free_params_names, param_walker_position))
            all_params_dict = fixed_params_values | free_params_dict
            free_hyperparams_dict = dict(zip(free_hyperparams_names, hyperparam_walker_position))
            all_hyperparams_dict = fixed_hyperparams_values | free_hyperparams_dict

            try:
                self._validate_astrophysical_validity(all_params_dict)
                # Internal GPKernel method that works on plain float values
                self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict)
                params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict)
                log_prior = lp.log_prior(params_for_prior)
                log_hyperprior = lp.log_hyperprior(free_hyperparams_dict)
            except ValueError:
                # Validation failed, try again
                continue

            # A finite log-(hyper)prior means the candidate lies within the prior bounds
            if np.isfinite(log_prior) and np.isfinite(log_hyperprior):
                break
        else:
            # Every attempt was rejected (or max_attempts allowed no attempt at all)
            raise ValueError(f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. "
                             f"Consider relaxing priors/hyperpriors or checking parameter constraints.")

        if verbose:
            print(f"Walker {walker_idx} param position: {param_walker_position}")
            print(f"Walker {walker_idx} hyperparam position: {hyperparam_walker_position}")
            print(f"(valid after {attempt} attempts)")
        param_init.append(param_walker_position)
        hyperparam_init.append(hyperparam_walker_position)

    param_init = np.array(param_init)
    hyperparam_init = np.array(hyperparam_init)

    # Combine parameter and hyperparameter positions into single array
    if hyperparam_init.size > 0:
        initial_positions = np.concatenate([param_init, hyperparam_init], axis=1)
    else:
        initial_positions = param_init

    if verbose:
        print(f"Generated MCMC initial param positions with shape: {param_init.shape}")
        print(f"Generated MCMC initial hyperparam positions with shape: {hyperparam_init.shape}")
        print(f"Combined initial positions shape: {initial_positions.shape}")
    return initial_positions
[docs]
def generate_initial_walker_positions_around_point(
    self,
    centre: np.ndarray | list,
    nwalkers: int,
    scale: float = 1e-4,
    relative: bool = True,
    verbose: bool = False,
    max_attempts: int = 1000
) -> np.ndarray:
    """Generate initial walker positions in a ball around a supplied centre point.

    Creates starting positions for MCMC walkers clustered around a centre point
    (e.g., MAP estimate). Each walker is generated by adding small random perturbations
    to the centre values. Validates that both the centre point and all generated
    walker positions satisfy priors/hyperpriors and are astrophysically valid.

    Parameters
    ----------
    centre : np.ndarray or list
        Centre point for walker positions. Must have length equal to the number
        of free parameters + free hyperparameters and be in the order of
        free_params_names + free_hyperparams_names.
    nwalkers : int
        Number of MCMC walkers to generate positions for
    scale : float, default 1e-4
        Scale of perturbations around centre point
    relative : bool, default True
        If True, perturbations scale with parameter values (scale * centre * random).
        Centre values that are exactly 0.0 fall back to an absolute perturbation.
        If False, perturbations are absolute (scale * random).
    verbose : bool, default False
        If True, print walker positions during generation
    max_attempts : int, default 1000
        Maximum attempts to generate a valid walker position

    Returns
    -------
    np.ndarray
        Array of shape (nwalkers, ndim) where ndim is the number of free parameters
        + free hyperparameters. Each row represents the starting position for one
        walker in the order of free_params_names + free_hyperparams_names.

    Raises
    ------
    ValueError
        If there are no free parameters/hyperparameters, if centre has wrong
        length, if centre point is invalid (the underlying validation error is
        chained as the cause), or if unable to generate valid positions after
        max_attempts

    Examples
    --------
    >>> # Generate walkers around MAP estimate
    >>> map_result = gpfitter.find_map_estimate()
    >>> initial_positions = gpfitter.generate_initial_walker_positions_around_point(
    ...     centre=map_result.x, nwalkers=40, scale=1e-4
    ... )
    >>> gpfitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
    """
    if len(self.free_params_values) + len(self.free_hyperparams_values) == 0:
        raise ValueError(
            "Cannot generate walker positions: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    centre = np.asarray(centre)
    n_params = len(self.free_params_names)
    n_hyperparams = len(self.free_hyperparams_names)
    expected_length = n_params + n_hyperparams
    if len(centre) != expected_length:
        raise ValueError(
            f"Centre must have length {expected_length} "
            f"({n_params} free params + {n_hyperparams} free hyperparams), "
            f"got {len(centre)}"
        )

    if verbose:
        print("Free parameters:", self.free_params_names)
        print("Free hyperparameters:", self.free_hyperparams_names)
        print(f"Centre values: {centre}")

    def _split_and_check(position):
        """Split a position into free param/hyperparam dicts, raising ValueError if unphysical."""
        free_params_dict = dict(zip(self.free_params_names, position[:n_params]))
        self._validate_astrophysical_validity(self.fixed_params_values_dict | free_params_dict)
        free_hyperparams_dict = dict(zip(self.free_hyperparams_names, position[n_params:]))
        self.gp_kernel._validate_hyperparams_values(
            self.fixed_hyperparams_values_dict | free_hyperparams_dict
        )
        return free_params_dict, free_hyperparams_dict

    def _log_priors(log_posterior, free_params_dict, free_hyperparams_dict):
        """Return (log_prior, log_hyperprior) evaluated at the given free values."""
        params_for_prior = log_posterior._convert_params_for_prior_evaluation(free_params_dict)
        return (
            log_posterior.log_prior(params_for_prior),
            log_posterior.log_hyperprior(free_hyperparams_dict),
        )

    # Validate the centre point first: there is no point perturbing an invalid centre
    try:
        centre_dicts = _split_and_check(centre)
        lp = GPLogPosterior(
            self.planet_letters,
            self.parameterisation,
            self.gp_kernel,
            self.priors,
            self.hyperpriors,
            self.fixed_params_values_dict,
            self.fixed_hyperparams_values_dict,
            self.free_params_names,
            self.free_hyperparams_names,
            self.time,
            self.vel,
            self.velerr,
            self.t0,
            self.instrument,
            self.unique_instruments,
        )
        log_prior, log_hyperprior = _log_priors(lp, *centre_dicts)
        if not np.isfinite(log_prior):
            raise ValueError(f"Centre point outside prior bounds (log_prior = {log_prior})")
        if not np.isfinite(log_hyperprior):
            raise ValueError(f"Centre point outside hyperprior bounds (log_hyperprior = {log_hyperprior})")
    except ValueError as e:
        # Chain the original error so the full traceback is preserved for debugging
        raise ValueError(f"Supplied centre point is not valid: {e}") from e
    if verbose:
        print(f"Centre point validated (log_prior = {log_prior}, log_hyperprior = {log_hyperprior})")

    zero_centre = centre == 0.0
    if verbose and relative and np.any(zero_centre):
        all_names = self.free_params_names + self.free_hyperparams_names
        zero_names = [name for name, is_zero in zip(all_names, zero_centre) if is_zero]
        print(f"Note: centre value is exactly 0.0 for {zero_names}; "
              f"using absolute perturbation (scale={scale}) for these parameters.")

    # Per-dimension perturbation magnitude. In relative mode a centre value of
    # exactly 0.0 would give a perturbation of exactly zero, so every walker
    # would share that coordinate and emcee would reject the ensemble as
    # linearly dependent; those dimensions use an absolute perturbation instead.
    if relative:
        magnitudes = np.where(zero_centre, 1.0, np.abs(centre))
    else:
        magnitudes = None

    param_init = []
    hyperparam_init = []
    for walker_idx in range(nwalkers):
        for attempts in range(max_attempts):
            perturbation = scale * np.random.randn(len(centre))
            if magnitudes is not None:
                perturbation = perturbation * magnitudes
            walker_position = centre + perturbation
            try:
                walker_dicts = _split_and_check(walker_position)
                log_prior, log_hyperprior = _log_priors(lp, *walker_dicts)
            except ValueError:
                # Unphysical draw: try again
                continue
            if np.isfinite(log_prior) and np.isfinite(log_hyperprior):
                # Inside all prior/hyperprior bounds: accept this walker
                break
        else:
            raise ValueError(
                f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. "
                f"Consider using a larger scale parameter or checking that the centre point is not too close to prior/physical boundaries."
            )

        walker_params = walker_position[:n_params]
        walker_hyperparams = walker_position[n_params:]
        if verbose:
            print(f"Walker {walker_idx} param position: {walker_params}")
            print(f"Walker {walker_idx} hyperparam position: {walker_hyperparams}")
            print(f"(valid after {attempts + 1} attempts)")
        param_init.append(walker_params)
        hyperparam_init.append(walker_hyperparams)

    param_init = np.array(param_init)
    hyperparam_init = np.array(hyperparam_init)

    # Combine parameter and hyperparameter positions into single array
    if hyperparam_init.size > 0:
        initial_positions = np.concatenate([param_init, hyperparam_init], axis=1)
    else:
        initial_positions = param_init

    if verbose:
        print(f"Generated MCMC initial param positions with shape: {param_init.shape}")
        print(f"Generated MCMC initial hyperparam positions with shape: {hyperparam_init.shape}")
        print(f"Combined initial positions shape: {initial_positions.shape}")
    return initial_positions
[docs]
def generate_initial_walker_positions_from_map(
    self,
    map_result: scipy.optimize.OptimizeResult,
    nwalkers: int,
    scale: float = 1e-4,
    relative: bool = True,
    verbose: bool = False,
    max_attempts: int = 1000
) -> np.ndarray:
    """Generate initial walker positions clustered around a MAP estimate.

    Thin convenience wrapper: takes the optimiser solution ``map_result.x`` as
    the centre point and delegates to
    ``generate_initial_walker_positions_around_point``.

    Parameters
    ----------
    map_result : scipy.optimize.OptimizeResult
        Result from find_map_estimate(); only its ``x`` attribute is used
    nwalkers : int
        Number of MCMC walkers to generate positions for
    scale : float, default 1e-4
        Scale of perturbations around MAP values
    relative : bool, default True
        If True, perturbations scale with parameter values.
        If False, perturbations are absolute.
    verbose : bool, default False
        If True, print walker positions during generation
    max_attempts : int, default 1000
        Maximum attempts to generate a valid walker position

    Returns
    -------
    np.ndarray
        Array of shape (nwalkers, ndim) where ndim is the number of free parameters
        + free hyperparameters. Each row represents the starting position for one walker.

    Raises
    ------
    ValueError
        If there is nothing free to sample, or if unable to generate valid positions

    Examples
    --------
    >>> # Find MAP then generate walkers around it
    >>> map_result = gpfitter.find_map_estimate()
    >>> initial_positions = gpfitter.generate_initial_walker_positions_from_map(
    ...     map_result=map_result, nwalkers=40
    ... )
    >>> gpfitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
    """
    n_free = len(self.free_params_values) + len(self.free_hyperparams_values)
    if n_free == 0:
        raise ValueError(
            "Cannot generate walker positions: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    # Everything is forwarded by keyword; the MAP solution vector is the centre
    forwarded = {
        "centre": map_result.x,
        "nwalkers": nwalkers,
        "scale": scale,
        "relative": relative,
        "verbose": verbose,
        "max_attempts": max_attempts,
    }
    return self.generate_initial_walker_positions_around_point(**forwarded)
[docs]
def run_mcmc(
    self,
    initial_positions: np.ndarray,
    nwalkers: int,
    max_steps: int = 5000,
    progress: bool = True,
    multiprocessing: bool = False,
    check_convergence: bool = False,
    convergence_check_interval: int = 1000,
    convergence_check_start: int = 0,
) -> None:
    """Run MCMC sampling from given initial walker positions.

    On success the emcee sampler is stored on ``self.sampler`` and the walker
    count actually used on ``self.nwalkers``. When convergence checking is
    enabled, the autocorrelation-time estimates are stored in
    ``self.autocorr_history`` keyed by iteration number.

    Parameters
    ----------
    initial_positions : np.ndarray
        Starting positions for all MCMC walkers. Shape must be (nwalkers, ndim)
        where ndim is the number of free parameters + hyperparameters. Each row
        represents the starting position for one walker in the order of
        free_params_names + free_hyperparams_names.
    nwalkers : int
        Number of MCMC walkers. If this is smaller than 2 * ndim, a warning is
        logged and 2 * ndim walkers are used instead; initial_positions must
        then provide 2 * ndim rows.
    max_steps : int, optional
        Maximum number of MCMC steps to run. If check_convergence=False, runs for
        exactly this many steps. If check_convergence=True, runs up to this many
        steps, stopping early when convergence criteria are met (default: 5000)
    progress : bool, optional
        Whether to show progress bar during MCMC (default: True)
    multiprocessing : bool, optional
        Whether to use multiprocessing for MCMC (default: False)
    check_convergence : bool, optional
        If True, check for convergence and stop early when criteria met.
        Convergence requires: chain length > 50 times max autocorrelation time,
        and autocorrelation time estimate stable to 1 percent.
        If False, run for exactly max_steps (default: False)
    convergence_check_interval : int, optional
        Steps between convergence checks (only used if check_convergence=True) (default: 1000)
    convergence_check_start : int, optional
        Minimum iteration before starting convergence checks. Set this sensibly
        (e.g. 2x burn-in) to avoid inaccurate tau estimates on short chains (default: 0)

    Raises
    ------
    ValueError
        If there is nothing free to sample, if initial_positions does not have
        shape (effective nwalkers, ndim), or if any walker starts at an
        astrophysically invalid position or outside the prior/hyperprior bounds.
    """
    if len(self.free_params_values) + len(self.free_hyperparams_values) == 0:
        raise ValueError(
            "Cannot run MCMC: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    # Initialize log-posterior object for MCMC sampling
    gp_lp = GPLogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.gp_kernel,
        self.priors,
        self.hyperpriors,
        self.fixed_params_values_dict,
        self.fixed_hyperparams_values_dict,
        self.free_params_names,
        self.free_hyperparams_names,
        self.time,
        self.vel,
        self.velerr,
        self.t0,
        self.instrument,
        self.unique_instruments,
    )

    # Enforce minimum number of walkers (though users ideally should have many more than this)
    if nwalkers < 2 * self.ndim:
        logging.warning(f"nwalkers should be at least 2 * ndim. You have {nwalkers} walkers and {self.ndim} dimensions. Setting nwalkers to {2 * self.ndim}.")
        self.nwalkers = 2 * self.ndim
    else:
        self.nwalkers = nwalkers

    # Validate against the walker count the sampler will really use, so a
    # mismatch is reported here rather than deep inside emcee.
    initial_positions = np.asarray(initial_positions)
    if initial_positions.shape != (self.nwalkers, self.ndim):
        raise ValueError(f"initial_positions must have shape ({self.nwalkers}, {self.ndim}), got {initial_positions.shape}")

    # Validate every walker position for astrophysical validity and prior compliance
    # (we don't want to start any chains in invalid parameter space)
    n_params = len(self.free_params_names)
    for i, walker_position in enumerate(initial_positions):
        # Split walker position into parameters and hyperparameters
        walker_params_dict = dict(zip(self.free_params_names, walker_position[:n_params]))
        walker_hyperparams_dict = dict(zip(self.free_hyperparams_names, walker_position[n_params:]))
        all_params_dict = self.fixed_params_values_dict | walker_params_dict
        all_hyperparams_dict = self.fixed_hyperparams_values_dict | walker_hyperparams_dict

        # Check astrophysical validity
        try:
            self._validate_astrophysical_validity(all_params_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid astrophysical parameters: {e}") from e

        # Check hyperparameter validity using GPKernel (internal method for float values)
        try:
            self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid hyperparameters: {e}") from e

        # Check prior compliance
        params_for_prior = gp_lp._convert_params_for_prior_evaluation(walker_params_dict)
        log_prior = gp_lp.log_prior(params_for_prior)
        if not np.isfinite(log_prior):
            raise ValueError(f"Walker {i} is outside prior bounds (log_prior = {log_prior})")

        # Check hyperprior compliance
        log_hyperprior = gp_lp.log_hyperprior(walker_hyperparams_dict)
        if not np.isfinite(log_hyperprior):
            raise ValueError(f"Walker {i} is outside hyperprior bounds (log_hyperprior = {log_hyperprior})")

    # Combine parameter names for sampler (needs it as one argument)
    all_param_names = self.free_params_names + self.free_hyperparams_names

    # Warn if convergence arguments provided but convergence checking disabled
    if not check_convergence:
        if convergence_check_interval != 1000 or convergence_check_start != 0:
            logging.warning(
                "Convergence checking arguments provided but check_convergence=False. "
                "These arguments will be ignored. Did you forget to set check_convergence=True?"
            )

    # Use 'spawn' instead of 'fork' to avoid issues on some Linux platforms
    pool = mp.get_context("spawn").Pool() if multiprocessing else None
    try:
        # TODO: parameter_names argument does slightly impact performance - but not sure if it can be avoided, we do need the names
        # and I'm not sure constructing the dictionary later ourselves manually is any quicker than passing parameter_names argument
        sampler = emcee.EnsembleSampler(
            self.nwalkers,
            self.ndim,
            gp_lp.log_probability,
            parameter_names=all_param_names,
            pool=pool,
        )

        if not check_convergence:
            # Fixed-length mode - run for exactly max_steps
            logging.info(f"Starting MCMC for {max_steps} steps...")
            sampler.run_mcmc(initial_state=initial_positions, nsteps=max_steps, progress=progress)
            logging.info("...MCMC done.")
        else:
            # Convergence checking - run up to max_steps, stopping early if converged
            logging.info(f"Starting MCMC with convergence checks. (Maximum {max_steps} steps, checking convergence every {convergence_check_interval} steps after iteration {convergence_check_start})...")

            # Initialize autocorrelation history storage
            self.autocorr_history = {}
            old_tau = np.inf

            for _ in sampler.sample(initial_state=initial_positions, iterations=max_steps, progress=progress):
                # Only check at specified intervals
                if sampler.iteration % convergence_check_interval != 0:
                    continue
                # Don't check before we have reached convergence_check_start
                if sampler.iteration < convergence_check_start:
                    continue

                # Get autocorrelation time estimate (tol=0 so short chains do not raise)
                tau = sampler.get_autocorr_time(tol=0)

                # Store autocorrelation history for plotting/diagnostics later
                self.autocorr_history[sampler.iteration] = tau.copy()

                # Log progress
                logging.info(f"Convergence check: Step {sampler.iteration}: mean(tau)={np.mean(tau):.1f}, max(tau)={np.max(tau):.1f}")

                # Check convergence criteria
                check_chain_length = np.all(sampler.iteration > 50 * tau)  # Chain length > 50 * tau
                check_stable_tau = np.all(np.abs(old_tau - tau) / tau < 0.01)  # Tau stable to 1 percent
                if check_chain_length and check_stable_tau:
                    logging.info(f"Converged at iteration {sampler.iteration}")
                    break

                logging.info(f"Not yet converged (N/50>tau check: {check_chain_length}, tau stability check: {check_stable_tau})")

                # Warn if approaching max steps without convergence
                if sampler.iteration > 0.8 * max_steps:
                    logging.warning(f"Approaching max iterations ({max_steps}) without convergence! (max tau={np.max(tau):.1f}, tau stability change={np.abs(old_tau - tau) / tau})")

                # Update old tau for next check
                old_tau = tau

            # Final log
            final_steps = sampler.iteration
            logging.info(f"MCMC complete: {final_steps} steps total")
    finally:
        # Always release worker processes, even if sampling raised
        if pool is not None:
            pool.close()
            pool.join()

    self.sampler = sampler
[docs]
def get_samples_np(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return a contiguous numpy array of MCMC samples.

    Samples can be discarded from the start and/or the end of the array. You can
    also thin (take only every n-th sample), and you can flatten the array
    so that each walker's chain is merged into one chain.

    This is the foundational method for accessing MCMC samples. All the other
    get_samples methods build on this.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    flat : bool, optional
        Whether to flatten each walker's chain into one chain. (default: False)
        If True, return flattened array with shape (nsteps_after_discard_thin * nwalkers, ndim)
        If False, return unflattened array with shape (nsteps_after_discard_thin, nwalkers, ndim)

    Returns
    -------
    np.ndarray
        Contiguous array of MCMC samples. Shape depends on `flat` parameter:
        - flat=False: (nsteps_after_discard_thin, nwalkers, ndim)
        - flat=True: (nsteps_after_discard_thin * nwalkers, ndim)

    Raises
    ------
    ValueError
        If `discard_start` or `discard_end` is negative, if `thin` is less
        than 1, or if the discards and thinning leave no samples.

    Notes
    -----
    We enforce np.ascontiguousarray() on the return, because np.reshape() does
    not guarantee a contiguous array in memory.
    """
    # Negative values would otherwise be interpreted as Python negative indices
    # (silently returning the wrong part of the chain), so reject them up front.
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )
    if thin < 1:
        raise ValueError(f"thin ({thin}) must be a positive integer.")

    # Get the full chain from emcee without any processing
    full_samples = self.sampler.get_chain(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end
    start_idx = discard_start + thin - 1
    end_idx = full_samples.shape[0] - discard_end

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    samples = full_samples[start_idx:end_idx:thin]

    # Flatten if requested (after discarding) - flatten steps and walkers into single dimension
    if flat:
        # (steps, walkers, ndim) -> (steps*walkers, ndim)
        nsteps, nwalkers, ndim = samples.shape
        samples = samples.reshape(nsteps * nwalkers, ndim)

    return np.ascontiguousarray(samples)
[docs]
def get_samples_df(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> pd.DataFrame:
    """Return the flattened MCMC samples as a pandas DataFrame.

    One row per sample, one column per free parameter or hyperparameter.
    The samples come from get_samples_np() with flat=True.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    pd.DataFrame
        DataFrame with shape (nsteps_after_discard_thin * nwalkers, ndim).
        Columns are free_params_names followed by free_hyperparams_names.
    """
    selection = {"discard_start": discard_start, "discard_end": discard_end, "thin": thin}
    sample_table = self.get_samples_np(flat=True, **selection)
    column_labels = self.free_params_names + self.free_hyperparams_names
    return pd.DataFrame(sample_table, columns=column_labels)
[docs]
def get_samples_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, np.ndarray]:
    """Return a dict of flattened MCMC samples.

    Each parameter and hyperparameter gets a 1D (flattened) contiguous array of all its samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary mapping parameter and hyperparameter names to 1D arrays of samples.
        Each array has shape (nsteps_after_discard_thin * nwalkers,) and is
        C-contiguous in memory.

    Examples
    --------
    >>> samples_dict = gpfitter.get_samples_dict(discard_start=1000)
    >>> K_b_samples = samples_dict['K_b']  # All samples for parameter K for planet b
    >>> gp_amp_samples = samples_dict['gp_amp']  # All samples for GP amplitude
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    param_names = self.free_params_names + self.free_hyperparams_names

    # A column slice flat_samples[:, i] of the (nsamples, ndim) array is a strided
    # view, not a contiguous array. Transposing into one contiguous (ndim, nsamples)
    # block makes every row - i.e. every parameter's samples - contiguous, at the
    # cost of a single copy.
    per_param = np.ascontiguousarray(flat_samples.T)
    return {name: per_param[i] for i, name in enumerate(param_names)}
[docs]
def get_sampler_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return the log probability at each step of the sampler.

    Uses the same discard/thin slicing as get_samples_np(), so the returned
    values line up with the samples returned for the same arguments.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    flat : bool, optional
        If True, return flattened array shape (nsteps_after_discard_thin * nwalkers)
        If False, return unflattened array shape (nsteps_after_discard_thin, nwalkers) (default: False)

    Returns
    -------
    np.ndarray
        Contiguous array of log probabilities of the function at each sample.

    Raises
    ------
    ValueError
        If `discard_start` or `discard_end` is negative, if `thin` is less
        than 1, or if the discards and thinning leave no samples.
    """
    # Negative values would otherwise be interpreted as Python negative indices
    # (silently returning the wrong part of the chain), so reject them up front.
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )
    if thin < 1:
        raise ValueError(f"thin ({thin}) must be a positive integer.")

    # Get the full log prob chain from emcee without any processing
    full_lnprob = self.sampler.get_log_prob(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end
    start_idx = discard_start + thin - 1
    end_idx = full_lnprob.shape[0] - discard_end

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    lnprob = full_lnprob[start_idx:end_idx:thin]

    # Flatten if requested (after discarding) - flatten steps and walkers into single dimension
    if flat:
        # (steps, walkers) -> (steps*walkers,)
        nsteps, nwalkers = lnprob.shape
        lnprob = lnprob.reshape(nsteps * nwalkers)

    return np.ascontiguousarray(lnprob)
[docs]
def get_mcmc_posterior_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> dict:
    """Return one dict holding fixed values plus MCMC samples for everything free.

    Fixed parameters and hyperparameters appear as the single values stored on
    the fitter; free ones appear as the flattened sample arrays returned by
    get_samples_dict(). This is the shape of input needed by functions such as
    calculate_mpsini that require every parameter (free or fixed) and should
    propagate the uncertainty carried by the free-parameter samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary of all parameters and hyperparameters:
        - Fixed parameters: single float values
        - Free parameters: 1D arrays of MCMC samples with shape (nsteps_after_discard_thin * nwalkers,)
        - Fixed hyperparameters: single float values
        - Free hyperparameters: 1D arrays of MCMC samples with shape (nsteps_after_discard_thin * nwalkers,)
    """
    fixed_values = self.fixed_params_values_dict | self.fixed_hyperparams_values_dict
    sampled_values = self.get_samples_dict(
        discard_start=discard_start,
        discard_end=discard_end,
        thin=thin,
    )
    # Free samples go last so they win on any (unexpected) key collision
    return fixed_values | sampled_values
[docs]
def calculate_log_likelihood(self, params_hyperparams_dict: Dict[str, float]) -> float:
    """Evaluate the GP log-likelihood at the given parameter and hyperparameter values.

    Prior probabilities are NOT included: this is the (log-) *likelihood* only,
    primarily for use in AICc & BIC calculation.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free)

    Returns
    -------
    float
        The log-likelihood value
    """
    def _select(names):
        # Pull the named entries out of the combined dict, preserving name order
        return {name: params_hyperparams_dict[name] for name in names}

    # Build the likelihood object from the fitter's data and model configuration
    likelihood = GPLogLikelihood(
        time=self.time,
        vel=self.vel,
        velerr=self.velerr,
        t0=self.t0,
        instrument=self.instrument,
        unique_instruments=self.unique_instruments,
        planet_letters=self.planet_letters,
        parameterisation=self.parameterisation,
        gp_kernel=self.gp_kernel,
    )

    # The likelihood takes parameters and hyperparameters as separate dicts
    params = _select(self.free_params_names + self.fixed_params_names)
    hyperparams = _select(self.free_hyperparams_names + self.fixed_hyperparams_names)
    return likelihood(params=params, hyperparams=hyperparams)
[docs]
def build_params_dict(self, free_params_hyperparams: np.ndarray | list | Dict[str, float]) -> Dict[str, float]:
"""Build complete parameter dictionary by combining free and fixed parameters and hyperparameters.
Takes free parameter and hyperparameter values from various sources (MAP results,
MCMC samples, or custom values) and combines them with the fixed parameter and
hyperparameter values to create a complete dictionary suitable for calculating
log-likelihood, chi2, AICc, and BIC.
Parameters
----------
free_params_hyperparams : list, np.ndarray, or dict
Free parameter and hyperparameter values from any source:
- list/array: values in order of self.free_params_names + self.free_hyperparams_names
- dict: mapping of free param/hyperparam names to values
Returns
-------
Dict[str, float]
Complete parameters and hyperparameters dict with both free and fixed values
Examples
--------
>>> # From MAP optimization result
>>> map_result = gpfitter.find_map_estimate()
>>> params = gpfitter.build_params_dict(map_result.x)
>>> aicc = gpfitter.calculate_aicc(params)
>>>
>>> # From best MCMC sample
>>> best_sample = gpfitter.get_sample_with_best_lnprob(discard_start=1000)
>>> params = gpfitter.build_params_dict(best_sample)
>>> bic = gpfitter.calculate_bic(params)
>>>
>>> # From custom array (params then hyperparams, in order of names)
>>> custom_values = [5.0, 50.0, 0.1, 0.0, 2450000.0, # params
... 10.0, 5.0, 0.5, 30.0] # hyperparams
>>> params = gpfitter.build_params_dict(custom_values)
>>> log_like = gpfitter.calculate_log_likelihood(params)
"""
if isinstance(free_params_hyperparams, dict):
# Validate that all expected free parameters and hyperparameters are present
expected_params = set(self.free_params_names)
expected_hyperparams = set(self.free_hyperparams_names)
expected_names = expected_params | expected_hyperparams
provided_names = set(free_params_hyperparams.keys())
missing = expected_names - provided_names
if missing:
raise ValueError(f"Missing required free parameters/hyperparameters: {missing}")
extra = provided_names - expected_names
if extra:
raise ValueError(f"Unexpected parameters/hyperparameters provided: {extra}")
return self.fixed_params_values_dict | self.fixed_hyperparams_values_dict | free_params_hyperparams
else:
# Validate that array/list has correct length
expected_length = len(self.free_params_names) + len(self.free_hyperparams_names)
if len(free_params_hyperparams) != expected_length:
raise ValueError(
f"Expected {expected_length} free parameter and hyperparameter values "
f"but got {len(free_params_hyperparams)} "
f"(expecting values for {self.free_params_names} + {self.free_hyperparams_names})"
)
all_free_names = self.free_params_names + self.free_hyperparams_names
free_dict = dict(zip(all_free_names, free_params_hyperparams))
return self.fixed_params_values_dict | self.fixed_hyperparams_values_dict | free_dict
[docs]
@staticmethod
@jax.jit
def _compute_gp_chi2(
    kernel: kernels.Kernel,
    time_array: jnp.ndarray,
    verr_squared_array: jnp.ndarray,
    residuals: jnp.ndarray
) -> float:
    """JIT-compiled chi-squared for residuals under a GP covariance.

    Evaluates chi^2 = r^T @ K^(-1) @ r, where K is the covariance built by
    tinygp from `kernel` at `time_array` with `verr_squared_array` on the
    diagonal.

    Parameters
    ----------
    kernel : tinygp.kernels.Kernel
        GP kernel
    time_array : jnp.ndarray
        Observation times
    verr_squared_array : jnp.ndarray
        Observational variances (error^2 + jitter^2)
    residuals : jnp.ndarray
        Data - mean_model

    Returns
    -------
    float
        Chi-squared value incorporating full covariance structure

    Notes
    -----
    The direct (and much slower) computation would be:
    # K = kernel(time_array, time_array) + jnp.diag(verr_squared_array)
    # K_inv = jnp.linalg.inv(K)
    # chi2 = jnp.dot(residuals, jnp.dot(K_inv, residuals))
    Instead the Cholesky factorisation K = L @ L^T is used:
    - alpha = L^(-1) @ residuals (via tinygp's private _get_alpha)
    - chi^2 = alpha^T @ alpha = r^T @ K^(-1) @ r
    """
    process = GaussianProcess(kernel=kernel, X=time_array, diag=verr_squared_array)
    # Whitened residuals: L^(-1) @ r, with K = L @ L^T
    whitened = process._get_alpha(residuals)
    # Squared norm of the whitened residuals equals r^T @ K^(-1) @ r
    return jnp.dot(whitened, whitened)
[docs]
def calculate_chi2(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Calculate chi-squared for given parameter and hyperparameter values.

    For GP models, this calculates:

    .. math::

        \chi^2 = \mathbf{r}^T \mathbf{K}^{-1} \mathbf{r}

    where :math:`\mathbf{K}` is the full covariance matrix including
    GP kernel and observational uncertainties. This properly accounts
    for correlated noise structure.

    Uses GPLogLikelihood._calculate_mean_model to avoid code duplication.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free).
        Must contain a ``jit_<instrument>`` entry for every instrument in
        ``self.unique_instruments``.

    Returns
    -------
    float
        Chi-squared value
    """
    # Create GPLogLikelihood instance to reuse mean model calculation
    gp_ll = GPLogLikelihood(
        time=self.time,
        vel=self.vel,
        velerr=self.velerr,
        t0=self.t0,
        instrument=self.instrument,
        unique_instruments=self.unique_instruments,
        planet_letters=self.planet_letters,
        parameterisation=self.parameterisation,
        gp_kernel=self.gp_kernel
    )

    # Extract parameter and hyperparameter dicts
    all_param_names = self.free_params_names + self.fixed_params_names
    params = {name: params_hyperparams_dict[name] for name in all_param_names}
    all_hyperparam_names = self.free_hyperparams_names + self.fixed_hyperparams_names
    hyperparams = {name: params_hyperparams_dict[name] for name in all_hyperparam_names}

    # Residuals of the data about the mean model
    mean_model = gp_ll._calculate_mean_model(params)
    residuals = gp_ll.jax_vel - mean_model

    # Build GP kernel with hyperparameters
    kernel = self.gp_kernel.build_kernel(hyperparams)

    # Per-observation variance: velerr^2 + jitter^2, with one jitter term per
    # instrument. The buffer is explicitly float so that integer-typed velerr
    # input cannot silently truncate the summed variances on assignment.
    velerr_jitter_squared = np.zeros_like(self.velerr, dtype=float)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        jit = params_hyperparams_dict[f"jit_{inst}"]
        velerr_jitter_squared[mask] = self.velerr[mask]**2 + jit**2

    # Calculate chi^2 = r^T K^(-1) r using full covariance matrix
    return float(self._compute_gp_chi2(kernel, gp_ll.jax_time, velerr_jitter_squared, residuals))
[docs]
def calculate_aicc(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Calculate corrected Akaike Information Criterion (AICc).

    .. math::

        \text{AICc} = 2k - 2\ln\mathcal{L} + \frac{2k^2 + 2k}{n - k - 1}

    where :math:`k` is the number of free parameters and hyperparameters,
    :math:`n` is the number of observations, and :math:`\mathcal{L}` is
    the likelihood. Converges to AIC for large :math:`n`.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free)

    Returns
    -------
    float
        AICc value

    Raises
    ------
    ValueError
        If ``n - k - 1 <= 0``. The small-sample correction is undefined
        (division by zero) or changes sign in that regime, so no
        meaningful AICc can be reported.
    """
    k = self.ndim
    n = len(self.time)

    # The correction term divides by (n - k - 1); guard it explicitly rather
    # than failing with a bare ZeroDivisionError or silently returning a
    # value with a negative "penalty".
    denominator = n - k - 1
    if denominator <= 0:
        raise ValueError(
            f"AICc is undefined when n - k - 1 <= 0 "
            f"(n={n} observations, k={k} free parameters and hyperparameters)"
        )

    log_like = self.calculate_log_likelihood(params_hyperparams_dict)
    aic = 2 * k - 2 * log_like  # traditional AIC
    correction = (2 * k**2 + 2 * k) / denominator  # small-sample correction
    return aic + correction
[docs]
def calculate_bic(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Return the Bayesian Information Criterion (BIC) for a set of values.

    .. math::

        \text{BIC} = k \ln n - 2 \ln \mathcal{L}

    Here :math:`k` is the number of free parameters and hyperparameters
    (``self.ndim``), :math:`n` is the number of observations
    (``len(self.time)``) and :math:`\mathcal{L}` is the likelihood evaluated
    at the supplied values.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free)

    Returns
    -------
    float
        BIC value
    """
    ln_likelihood = self.calculate_log_likelihood(params_hyperparams_dict)
    n_free = self.ndim
    n_obs = len(self.time)
    # Complexity penalty grows with the log of the sample size.
    complexity_penalty = n_free * np.log(n_obs)
    return complexity_penalty - 2 * ln_likelihood
[docs]
def get_sample_with_best_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, float]:
    """Return the MCMC sample that has the highest log probability.

    The flattened chain and the matching flattened log-probability array
    are fetched with identical discard/thin settings, so that an index into
    one is valid for the other. A one-line summary of the selected sample
    is printed.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    Dict[str, float]
        Mapping of free parameter and free hyperparameter names to their
        values in the best sample
    """
    # Same selection for both calls keeps samples and lnprob aligned.
    chain_selection = {
        "discard_start": discard_start,
        "discard_end": discard_end,
        "thin": thin,
        "flat": True,
    }
    samples = self.get_samples_np(**chain_selection)
    lnprob = self.get_sampler_lnprob(**chain_selection)

    best_idx = np.argmax(lnprob)
    best_lnprob = lnprob[best_idx]
    print(f"Best sample found with log probability {best_lnprob:.6f} at index {best_idx} of samples (with discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    # Sample columns are ordered as free parameters, then free hyperparameters.
    column_names = self.free_params_names + self.free_hyperparams_names
    winning_row = samples[best_idx]
    return dict(zip(column_names, winning_row))
[docs]
def plot_autocorr_estimates(
self,
params: list[str] | None = None,
hyperparams: list[str] | None = None,
plot_mean: bool = False,
show_legend: bool = True,
title: str | None = "Autocorrelation Time Estimates",
xlabel: str | None = "Step number",
ylabel: str | None = r"Autocorrelation time $\tau$",
save: bool = False,
fname: str = "autocorr_plot.png",
dpi: int = 100
) -> None:
r"""Plot autocorrelation time estimates from adaptive MCMC run.
Shows how autocorrelation time evolved during the MCMC run and
the convergence threshold line (N / 50).
Only available if run_mcmc was called with check_convergence=True.
Parameters
----------
params : list[str] or None, optional
List of parameter names to plot. If None, plots all free parameters (default: None)
hyperparams : list[str] or None, optional
List of hyperparameter names to plot. If None, plots all free hyperparameters (default: None)
plot_mean : bool, optional
If True, plot mean tau instead of individual parameter/hyperparameter taus.
Overrides params and hyperparams arguments (default: False)
show_legend : bool, optional
Whether to show legend (default: True)
title : str or None, optional
Plot title (default: "Autocorrelation Time Estimates"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label (default: "Step number"). Set to None or "" to skip.
ylabel : str or None, optional
Y-axis label (default: r"Autocorrelation time $\tau$"). Set to None or "" to skip.
save : bool, optional
Save the plot to path `fname` (default: False)
fname : str, optional
The path to save the plot to (default: "autocorr_plot.png")
dpi : int, optional
The dpi to save the image at (default: 100)
Raises
------
ValueError
If no autocorrelation history is available (run_mcmc was not called
with check_convergence=True, or has not been called yet)
"""
# Check if data available
if not hasattr(self, 'autocorr_history') or len(self.autocorr_history) == 0:
raise ValueError(
"No autocorrelation history available. "
"Please run run_mcmc() with check_convergence=True first."
)
iterations = np.array(list(self.autocorr_history.keys()))
max_iteration = np.max(iterations)
tau_history = np.array(list(self.autocorr_history.values())) # Shape: (n_checks, n_params + n_hyperparams)
# Create plot
fig, ax = plt.subplots(1, figsize=(10, 6))
if title:
fig.suptitle(title)
# Plot convergence threshold (N/50)
ax.plot([0, max_iteration], [0, max_iteration / 50], "--k", linewidth=2,
label="N/50")
if plot_mean:
# Plot mean tau
mean_tau = np.mean(tau_history, axis=1)
ax.plot(iterations, mean_tau, linewidth=2, label="Mean τ")
else:
# Determine which parameters/hyperparameters to plot
names_to_plot = []
indices_to_plot = []
# Handle parameters
if params is None:
# Include all free params
for i, param_name in enumerate(self.free_params_names):
names_to_plot.append(param_name)
indices_to_plot.append(i)
else:
# Include only specified params
for param in params:
if param in self.free_params_names:
idx = self.free_params_names.index(param)
names_to_plot.append(param)
indices_to_plot.append(idx)
else:
logging.warning(f"Parameter '{param}' not found in free parameters, skipping")
# Handle hyperparameters
if hyperparams is None:
# Include all free hyperparams
for i, hyperparam_name in enumerate(self.free_hyperparams_names):
names_to_plot.append(hyperparam_name)
indices_to_plot.append(len(self.free_params_names) + i)
else:
# Include only specified hyperparams
for hyperparam in hyperparams:
if hyperparam in self.free_hyperparams_names:
idx = len(self.free_params_names) + self.free_hyperparams_names.index(hyperparam)
names_to_plot.append(hyperparam)
indices_to_plot.append(idx)
else:
logging.warning(f"Hyperparameter '{hyperparam}' not found in free hyperparameters, skipping")
# Plot individual parameter/hyperparameter taus
for idx, name in zip(indices_to_plot, names_to_plot):
ax.plot(iterations, tau_history[:, idx], alpha=0.7, label=param_key_to_latex(name))
ax.set_xlim(0, iterations.max())
ax.set_ylim(bottom=0)
if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)
if show_legend:
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def plot_chains(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, truths: list = None, title: str | None = "Chains plot", xlabel: str | None = "Step number", save: bool = False, fname: str = "chains_plot.png", dpi: int = 100) -> None:
"""Plot MCMC chains for all free parameters and hyperparameters.
Displays the evolution of each free parameter and hyperparameter across MCMC steps
for all walkers. Useful for diagnosing convergence, burn-in, and mixing of the
MCMC chains. Each parameter/hyperparameter gets its own subplot showing all walker traces.
For GP fitting, this includes both planetary/trend parameters and GP kernel hyperparameters.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
truths : list, optional
List of true parameter/hyperparameter values to overplot as horizontal lines.
Must match the number of free parameters + hyperparameters. Use None for
parameters without known truth values (default: None)
title : str or None, optional
Plot title (default: "Chains plot"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label (default: "Step number"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "chains_plot.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Scale figure height to maintain consistent subplot size
subplot_height_inches = 1.25
fig, axes = plt.subplots(self.ndim, figsize=(10, self.ndim * subplot_height_inches),
sharex=True, constrained_layout=True)
if title:
fig.suptitle(title)
if self.ndim == 1:
axes = [axes]
if truths is not None:
if not len(truths) == self.ndim:
raise ValueError(f"Length of truths ({len(truths)}) must match number of free parameters and hyperparameters ({self.ndim})")
samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
param_names = self.free_params_names + self.free_hyperparams_names
for i in range(self.ndim):
ax = axes[i]
to_plot = samples[:, :, i]
ax.plot(to_plot, "k", alpha=0.3)
ax.set_xlim(0, len(samples))
ax.set_ylabel(param_key_to_latex(param_names[i]))
if truths is not None and truths[i] is not None:
ax.axhline(truths[i], color="tab:blue")
fig.align_ylabels(axes)
if xlabel:
axes[-1].set_xlabel(xlabel)
if save:
plt.savefig(fname=fname, dpi=dpi)
print(f"Saved {fname}")
plt.show()
[docs]
def plot_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Log Probability Traces", xlabel: str | None = "Step number", ylabel: str | None = "Log probability", save: bool = False, fname: str = "lnprob_plot.png", dpi: int = 100) -> None:
    """Draw the log-probability trace of every walker on a single axis.

    A diagnostic view for spotting unconverged or stuck walkers. Use
    `discard_start` and `discard_end` to zoom in on a range of steps.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    title : str or None, optional
        Plot title (default: "Log Probability Traces"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label (default: "Step number"). Set to None or "" to skip.
    ylabel : str or None, optional
        Y-axis label (default: "Log probability"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "lnprob_plot.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    fig, ax = plt.subplots(1, figsize=(10, 6))
    if title:
        fig.suptitle(title)

    lnprobs = self.get_sampler_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
    # Unflattened log-probabilities are indexed (step, walker)
    nsteps, _nwalkers = lnprobs.shape

    # Iterating the transpose yields one full trace per walker
    for walker_trace in lnprobs.T:
        ax.plot(walker_trace, "k", alpha=0.3)

    ax.set_xlim(0, nsteps)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_corner(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, plot_datapoints: bool = False, truths: list[float] | None = None, title: str | None = "Corner plots", save: bool = False, fname: str = "corner_plot.png", dpi: int = 100) -> None:
    """Create a corner plot of MCMC samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    plot_datapoints : bool, optional
        Show individual data points in addition to contours (default: False)
    truths : list of float, optional
        True values to overplot as vertical/horizontal lines (default: None).
        Must match the order of the plotted columns: free parameters
        followed by free hyperparameters.
    title : str or None, optional
        Figure title (default: "Corner plots"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "corner_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    param_labels = [param_key_to_latex(n) for n in self.free_params_names + self.free_hyperparams_names]

    fig = corner.corner(
        flat_samples, labels=param_labels, show_titles=True,
        plot_datapoints=plot_datapoints, quantiles=[0.1585, 0.5, 0.8415],
        truths=truths,
    )

    # Honour the caller's title (previously a hard-coded string was always
    # used, so the `title` argument had no effect).
    if title:
        fig.suptitle(title)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def _plot_rv(self, params_hyperparams: Dict[str, float], title: str | None = "RV Model", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "rv_plot.png", dpi: int = 100) -> None:
    """Helper function to plot RV model with given parameters.

    For GP fitting, this plots both the mean model (planets + trend)
    and the GP component. The GP component is predicted at smooth time points
    using the posterior conditioned on observed data residuals, including
    1-sigma uncertainty bands.

    Parameters
    ----------
    params_hyperparams : dict
        Dictionary of parameter and hyperparameter values (both free and fixed)
    title : str or None, optional
        Plot title (default: "RV Model"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "rv_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    # Step 1: smooth time grid for plotting (1% padding either side of the data)
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Step 2: mean model (planets + trend, no gamma - that's per-instrument),
    # evaluated on both the smooth grid (for plotting) and at the observed
    # times (for GP conditioning). Each Planet is built once and reused.
    planets = []
    for letter in self.planet_letters:
        planet_params = {
            par_name: params_hyperparams[f"{par_name}_{letter}"]
            for par_name in self.parameterisation.pars
        }
        planets.append(ravest.model.Planet(letter, self.parameterisation, planet_params))

    # System Trend (gd, gdd only - no gamma)
    trend_params = {"gd": params_hyperparams["gd"], "gdd": params_hyperparams["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    rv_trend_smooth = trend.radial_velocity(tsmooth)

    rv_mean_smooth = np.zeros(len(tsmooth))
    rv_mean_obs = np.zeros(len(self.time))
    for planet in planets:
        rv_mean_smooth += planet.radial_velocity(tsmooth)
        rv_mean_obs += planet.radial_velocity(self.time)
    rv_mean_smooth += rv_trend_smooth
    rv_mean_obs += trend.radial_velocity(self.time)

    # Per-instrument corrections: subtract gamma offsets from the data and
    # add jitter in quadrature to the measurement uncertainties.
    vel_corrected = self.vel.copy()
    jit2_verr2 = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        vel_corrected[mask] -= params_hyperparams[f"g_{inst}"]
        jit = params_hyperparams[f"jit_{inst}"]
        jit2_verr2[mask] = self.velerr[mask]**2 + jit**2
    # Error-bar length including jitter
    velerr_with_jit = np.sqrt(jit2_verr2)

    # Step 3: set up the GP and condition it on the gamma-corrected residuals
    hyperparams = {hp: params_hyperparams[hp] for hp in self.gp_kernel.expected_hyperparams}
    kernel = self.gp_kernel.build_kernel(hyperparams)
    residuals_obs = vel_corrected - rv_mean_obs
    gp_obs = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(jit2_verr2))

    # Predict on the smooth grid (curve + 1-sigma band) ...
    _, gp_cond_smooth = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(tsmooth))
    rv_gp_mean_smooth = np.asarray(gp_cond_smooth.mean)
    rv_gp_std_smooth = np.sqrt(np.asarray(gp_cond_smooth.variance))
    # ... and at the observed times (needed for the residuals panel)
    _, gp_cond_obs = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(self.time))
    rv_gp_mean_obs = np.asarray(gp_cond_obs.mean)

    # Step 4: generate the plot
    # Total model is mean + GP prediction
    rv_total_smooth = rv_mean_smooth + rv_gp_mean_smooth

    # Residuals for the lower panel (gamma-corrected data - model - GP)
    residuals = vel_corrected - (rv_mean_obs + rv_gp_mean_obs)

    # Create figure with subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5),
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Plot observed data with error bars (gamma-corrected, coloured by instrument)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=6, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=5)

    # Plot mean model component
    ax1.plot(tsmooth, rv_mean_smooth, label="Mean Model (Planets + Trend)", color="red", zorder=3)

    # Plot GP component
    # (We add the trend component so that the planet and GP signals start at the same baseline/gamma RV)
    ax1.plot(tsmooth, rv_gp_mean_smooth + rv_trend_smooth, color="blue", label='GP Component', zorder=2)

    # Plot total model (Planet + GP)
    ax1.plot(tsmooth, rv_total_smooth, label='Total Model (Planet + Trend + GP)', color="black", zorder=4)

    # Add 1-sigma uncertainty band around total model
    ax1.fill_between(tsmooth, rv_total_smooth - rv_gp_std_smooth, rv_total_smooth + rv_gp_std_smooth,
                     color='darkgrey', zorder=1)

    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    legend = ax1.legend(loc="upper right")
    legend.set_zorder(7)
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')  # Remove x-axis labels from top plot
    ax1.tick_params(axis='y', direction='in')

    # Set y-axis ticks automatically based on data range
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals subplot (coloured by instrument)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=6)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=5)
    ax2.axhline(0, color="black", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Set symmetric y-limits for residuals. The extent of a point plus its
    # error bar is |residual| + error; taking |residual + error| instead
    # under-estimates it for negative residuals and can push points off-axis.
    max_abs_residual = float(np.max(np.abs(residuals) + velerr_with_jit))
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)  # Add ticks on shared border

    # Set y-axis ticks automatically based on residuals range
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def _plot_phase(self, planet_letter: str, params_hyperparams: Dict[str, float], title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "phase_plot.png", dpi: int = 100) -> None:
    """Helper function to plot phase-folded RV model for a single planet with given parameters.

    For GP fitting, the GP component cannot be separated per planet. The
    plotted data points are therefore the gamma-corrected observations with
    the other planets, the trend and the conditioned GP mean all subtracted,
    so that only the target planet's signal remains to be phase-folded.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    params_hyperparams : dict
        Dictionary of parameter and hyperparameter values (both free and fixed)
    title : str or None, optional
        Plot title. None (the default) uses f"Planet {planet_letter} Phase Plot";
        set to "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "phase_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    if title is None:
        title = f"Planet {planet_letter} Phase Plot"

    # Smooth linear time curve for plotting the model (1% padding either side)
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Per-instrument corrections: subtract gamma offsets from the data and
    # add jitter in quadrature to the measurement uncertainties.
    vel_corrected = self.vel.copy()
    jit2_verr2 = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        vel_corrected[mask] -= params_hyperparams[f"g_{inst}"]
        jit = params_hyperparams[f"jit_{inst}"]
        jit2_verr2[mask] = self.velerr[mask]**2 + jit**2
    # Error-bar length including jitter (always floating point)
    velerr_with_jit = np.sqrt(jit2_verr2)

    # Parameters of the target planet, keyed by bare parameter name
    planet_params = {par: params_hyperparams[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}

    # Get period and time of conjunction for this planet
    P = params_hyperparams[f"P_{planet_letter}"]
    # Convert to tc if needed
    if "Tc" in self.parameterisation.pars:
        Tc = params_hyperparams[f"Tc_{planet_letter}"]
    elif "e" in self.parameterisation.pars and "w" in self.parameterisation.pars:
        _e = params_hyperparams[f"e_{planet_letter}"]
        _w = params_hyperparams[f"w_{planet_letter}"]
        _Tp = params_hyperparams[f"Tp_{planet_letter}"]
        Tc = self.parameterisation.convert_tp_to_tc(_Tp, P, _e, _w)
    else:
        # Fall back to default parameterisation conversion
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], P, default_params["e"], default_params["w"])

    # Phase fold both data and smooth times
    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # Calculate RV contribution from this planet only (for model curve)
    planet = ravest.model.Planet(planet_letter, self.parameterisation, planet_params)
    planet_rv_obs = planet.radial_velocity(self.time)
    planet_rv_smooth_sorted = planet.radial_velocity(tsmooth)[smooth_inds]

    # Everything else that must be subtracted from the observed data:
    # all other planets + system Trend (+ GP mean, added further below)
    other_rv_obs = np.zeros(len(self.time))
    for other_letter in self.planet_letters:
        if other_letter == planet_letter:
            continue
        other_params = {par: params_hyperparams[f"{par}_{other_letter}"] for par in self.parameterisation.pars}
        other_planet = ravest.model.Planet(other_letter, self.parameterisation, other_params)
        other_rv_obs += other_planet.radial_velocity(self.time)

    # Calculate trend (gd, gdd only - no gamma)
    trend_params = {"gd": params_hyperparams["gd"], "gdd": params_hyperparams["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    other_rv_obs += trend.radial_velocity(self.time)

    # The mean model for the GP is all of the planets + the trend
    rv_mean_obs = other_rv_obs + planet_rv_obs

    # Set up GP and condition it on the gamma-corrected residuals
    hyperparams = {hp: params_hyperparams[hp] for hp in self.gp_kernel.expected_hyperparams}
    kernel = self.gp_kernel.build_kernel(hyperparams)
    residuals_obs = vel_corrected - rv_mean_obs
    gp_obs = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(jit2_verr2))
    _, gp_cond = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(self.time))
    gp_mean_obs = np.array(gp_cond.mean)

    # Calculate data points for the plot: gamma-corrected data - (other planets + Trend + GP mean)
    data_minus_others_obs = vel_corrected - (other_rv_obs + gp_mean_obs)

    # Sort the data according to phase folding
    data_minus_others_obs_sorted = data_minus_others_obs[inds]
    verr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]

    # Create figure with subplots (main plot + residuals)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5),
                                   gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main phase plot (coloured by instrument)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_obs_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_obs_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    # Plot phase-folded model for this planet
    ax1.plot(tsmooth_fold_sorted, planet_rv_smooth_sorted, label="Planet Model", color="black", zorder=2)

    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    ax1.legend(loc="upper right")
    if title:
        ax1.set_title(title)
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')

    # Set y-axis ticks automatically based on phase data range
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Annotate with planet info
    K_value = params_hyperparams[f"K_{planet_letter}"]
    P_label = param_key_to_latex(f"P_{planet_letter}")
    K_label = param_key_to_latex(f"K_{planet_letter}")
    s = f"Planet {planet_letter}\n{P_label}={P:.2f} d\n{K_label}={K_value:.2f} m/s"
    ax1.annotate(s, xy=(0, 1), xycoords="axes fraction",
                 xytext=(+0.5, -0.5), textcoords="offset fontsize", va="top")

    # Residuals plot (phase-folded, coloured by instrument)
    residuals = data_minus_others_obs - planet_rv_obs
    residuals_sorted = residuals[inds]
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)

    # Set y-axis ticks automatically based on residuals range
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_posterior_rv(
    self,
    discard_start: int = 0,
    discard_end: int = 0,
    thin: int = 1,
    show_CI: bool = True,
    title: str | None = "Posterior predictions (with GP)",
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Time [days]",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "posterior_rv.png",
    dpi: int = 100,
    n_smooth: int = 1000,
) -> None:
    """Plot the posterior GP RV model with uncertainty bands from MCMC samples.

    For every retained MCMC sample the mean model (all planets + trend) is
    evaluated at the observed times and on a smooth time grid, and the GP
    predictive mean is obtained by conditioning on that sample's residuals.
    The top panel shows the gamma-corrected data together with the median of
    the full model (planets + trend + GP) and, optionally, the band between
    the 15.85th and 84.15th percentiles. The bottom panel shows the
    gamma-corrected data minus the median model at the observed times.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main RV plot (default: "Posterior predictions (with GP)"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_rv.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    n_smooth : int, optional
        Number of points in smooth time grid for plotting model curves (default: 1000).
        Reduce for faster plotting, increase for smoother curves.

    Notes
    -----
    The per-instrument offsets ``g_<instrument>`` are removed from the data
    using their median over the samples, and the same corrected data are used
    when conditioning the GP for every sample.
    """

    def _median_or_scalar(value):
        """Collapse an array of samples to its median; pass scalars through."""
        return np.median(value) if isinstance(value, np.ndarray) else value

    def _sample_or_scalar(value, index):
        """Pick one sample from an array of samples; pass scalars through."""
        return value[index] if isinstance(value, np.ndarray) else value

    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, n_smooth)
    n_obs = len(self.time)

    # Samples of the free parameters/hyperparameters, merged with the fixed values
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    params_hyperparams = samples_dict | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict

    # Report number of effective samples
    n_samples = len(samples_dict[list(samples_dict.keys())[0]])
    print(f"Processing {n_samples} effective samples (after discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    # Mean model (all planets + trend) per sample, at observed and smooth times
    rv_mean_matrix_obs = np.zeros((n_samples, n_obs))
    rv_mean_matrix_smooth = np.zeros((n_samples, len(tsmooth)))
    for planet_letter in self.planet_letters:
        print(f"Calculating planet {planet_letter} RV at {n_obs} observed times...")
        rv_mean_matrix_obs += self.calculate_rv_planet_from_samples(planet_letter, self.time, discard_start, discard_end, thin)
        print(f"Calculating planet {planet_letter} RV at {len(tsmooth)} smooth times...")
        rv_mean_matrix_smooth += self.calculate_rv_planet_from_samples(planet_letter, tsmooth, discard_start, discard_end, thin)
    print(f"Calculating trend RV at {n_obs} observed times...")
    rv_mean_matrix_obs += self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)
    print(f"Calculating trend RV at {len(tsmooth)} smooth times...")
    rv_mean_matrix_smooth += self.calculate_rv_trend_from_samples(tsmooth, discard_start, discard_end, thin)

    # The instrument masks do not depend on the sample, so build them once
    inst_masks = {inst: (self.instrument == inst) for inst in self.unique_instruments}

    # Subtract per-instrument gamma offsets from data (using median gamma values)
    vel_corrected = self.vel.copy()
    for inst, mask in inst_masks.items():
        vel_corrected[mask] -= _median_or_scalar(params_hyperparams[f"g_{inst}"])

    # GP mean at observed and smooth times, conditioned on each sample's residuals
    # NOTE(review): the residuals use the median gamma rather than each
    # sample's own gamma - confirm this approximation is intended.
    residuals_matrix_obs = vel_corrected - rv_mean_matrix_obs
    gp_mean_matrix_obs = np.zeros_like(residuals_matrix_obs)
    gp_mean_matrix_smooth = np.zeros((n_samples, len(tsmooth)))
    time_jax = jnp.array(self.time)
    tsmooth_jax = jnp.array(tsmooth)
    for i in tqdm(range(n_samples), desc="Computing GP predictions"):
        # Hyperparameters for this sample (i-th draw if free, scalar if fixed)
        sample_hyperparams = {
            hp: _sample_or_scalar(params_hyperparams[hp], i)
            for hp in self.gp_kernel.expected_hyperparams
        }
        kernel = self.gp_kernel.build_kernel(sample_hyperparams)

        # Diagonal noise term: measurement variance plus per-instrument jitter variance
        jit2_verr2 = np.zeros(n_obs)
        for inst, mask in inst_masks.items():
            jit_value = _sample_or_scalar(params_hyperparams[f"jit_{inst}"], i)
            jit2_verr2[mask] = self.velerr[mask]**2 + jit_value**2

        # One GP object per sample, reused for both sets of prediction times
        gp_obs = GaussianProcess(kernel=kernel, X=time_jax, diag=jnp.array(jit2_verr2))
        sample_residuals = jnp.array(residuals_matrix_obs[i])
        _, gp_cond_obs = gp_obs.condition(y=sample_residuals, X_test=time_jax)
        gp_mean_matrix_obs[i] = np.array(gp_cond_obs.mean)
        _, gp_cond_smooth = gp_obs.condition(y=sample_residuals, X_test=tsmooth_jax)
        gp_mean_matrix_smooth[i] = np.array(gp_cond_smooth.mean)

    # Full model: planets + trend + GP
    rv_full_matrix_obs = rv_mean_matrix_obs + gp_mean_matrix_obs
    rv_full_matrix_smooth = rv_mean_matrix_smooth + gp_mean_matrix_smooth

    # Median and 68.3% interval across samples
    rv_percentiles_smooth = np.percentile(rv_full_matrix_smooth, [15.85, 50, 84.15], axis=0)
    rv_percentiles_obs = np.percentile(rv_full_matrix_obs, [15.85, 50, 84.15], axis=0)

    # Residuals of the gamma-corrected data against the median model
    residuals = vel_corrected - rv_percentiles_obs[1]

    # Error bars with the median per-instrument jitter added in quadrature
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst, mask in inst_masks.items():
        jit_median = _median_or_scalar(params_hyperparams[f"jit_{inst}"])
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit_median**2)

    # Main panel above a shorter residuals panel, sharing no vertical gap
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main RV plot (gamma-corrected data, coloured by instrument)
    for inst, mask in inst_masks.items():
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)

    # Plot median model and uncertainty
    ax1.plot(tsmooth, rv_percentiles_smooth[1], label="Model (Mean + GP)", color="black", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth, rv_percentiles_smooth[0], rv_percentiles_smooth[2],
                         color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)
    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    # Set y-axis ticks
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals plot (coloured by instrument)
    for inst, mask in inst_masks.items():
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric y-limits for the residuals plot. Each error bar reaches
    # |residual| + error away from zero; taking the absolute value of the sum
    # instead would under-estimate points with negative residuals and clip
    # their lower error bars.
    max_abs_residual = np.max(np.abs(residuals) + velerr_with_jit)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    # Set y-axis ticks
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_posterior_phase(
    self,
    planet_letter: str,
    discard_start: int = 0,
    discard_end: int = 0,
    thin: int = 1,
    show_CI: bool = True,
    title: str | None = None,
    ylabel_main: str | None = "Radial velocity [m/s]",
    xlabel: str | None = "Orbital phase",
    ylabel_residuals: str | None = "Residuals [m/s]",
    save: bool = False,
    fname: str = "posterior_phase.png",
    dpi: int = 100,
    n_smooth: int = 50,
) -> None:
    """Plot phase-folded GP RV model with uncertainty bands from MCMC samples.

    The procedure is:

    1. For each sample, evaluate this planet's RV at the observed times and
       on a smooth time grid.
    2. For each sample, evaluate every other planet and the trend at the
       observed times.
    3. For each sample, condition the GP on the gamma-corrected data minus
       (all planets + trend) and take its mean at the observed times.
    4. Subtract the per-observation median of (other planets + trend + GP)
       from the gamma-corrected data; these are the plotted data points.
    5. Plot the median of this planet's RV on the smooth grid and, optionally,
       the band between the 15.85th and 84.15th percentiles.
    6. The residuals panel shows the plotted data points minus the median of
       this planet's RV at the observed times.

    Folding uses the median period and the median time of conjunction.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main phase plot. None (the default) uses
        "Posterior Phase Plot (with GP) - Planet {planet_letter}". Set to "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_phase.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    n_smooth : int, optional
        Number of points in smooth time grid for plotting model curves (default: 50).
        Reduce for faster plotting, increase for smoother curves.
    """

    def _median_or_scalar(value):
        """Collapse an array of samples to its median; pass scalars through."""
        return np.median(value) if isinstance(value, np.ndarray) else value

    def _sample_or_scalar(value, index):
        """Pick one sample from an array of samples; pass scalars through."""
        return value[index] if isinstance(value, np.ndarray) else value

    # Samples of the free parameters/hyperparameters, merged with the fixed values
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    params_hyperparams = samples_dict | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict

    # Report number of effective samples
    n_samples = len(samples_dict[list(samples_dict.keys())[0]])
    print(f"Processing {n_samples} effective samples (after discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    # Smooth time grid for the model curve, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, n_smooth)
    n_obs = len(self.time)

    # The instrument masks do not depend on the sample, so build them once
    inst_masks = {inst: (self.instrument == inst) for inst in self.unique_instruments}

    # Error bars with the median per-instrument jitter added in quadrature
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst, mask in inst_masks.items():
        jit_median = _median_or_scalar(params_hyperparams[f"jit_{inst}"])
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask]**2 + jit_median**2)

    # Subtract per-instrument gamma offsets from data (using median gamma values)
    vel_corrected = self.vel.copy()
    for inst, mask in inst_masks.items():
        vel_corrected[mask] -= _median_or_scalar(params_hyperparams[f"g_{inst}"])

    # Period and time of conjunction to fold around
    period_samples = params_hyperparams[f'P_{planet_letter}']
    if "Tc" in self.parameterisation.pars:
        tc_samples = params_hyperparams[f"Tc_{planet_letter}"]
    else:
        # Fall back to default parameterisation conversion (as it has P, e, w and Tp, so we can definitely get Tc)
        planet_params = {par: params_hyperparams[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        tc_samples = self.parameterisation.convert_tp_to_tc(default_params["Tp"], period_samples, default_params["e"], default_params["w"])
    # Just for the folding, take the median value of the P and Tc samples
    Tc = np.median(tc_samples)
    P = np.median(period_samples)

    # Phase fold both data and smooth times
    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # 1) This planet's RVs at observed and smooth times
    print(f"Calculating planet {planet_letter} RV at {n_obs} observed times...")
    rv_this_planet_matrix_obs = self.calculate_rv_planet_from_samples(planet_letter, self.time, discard_start, discard_end, thin)
    print(f"Calculating planet {planet_letter} RV at {len(tsmooth)} smooth times...")
    rv_this_planet_matrix_smooth = self.calculate_rv_planet_from_samples(planet_letter, tsmooth, discard_start, discard_end, thin)

    # 2) All the other planets' RVs at observed times
    rv_other_planets_matrix_obs = np.zeros((rv_this_planet_matrix_obs.shape[0], n_obs))
    for other_letter in self.planet_letters:
        if other_letter != planet_letter:
            print(f"Calculating planet {other_letter} RV at {n_obs} observed times...")
            rv_other_planets_matrix_obs += self.calculate_rv_planet_from_samples(other_letter, self.time, discard_start, discard_end, thin)

    # 3) Trend at observed times
    print(f"Calculating trend RV at {n_obs} observed times...")
    rv_trend_matrix_obs = self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)

    # 4) GP mean at observed times, conditioned on each sample's residuals
    #    (gamma-corrected data minus all planets + trend)
    rv_all_planets_trend_matrix_obs = rv_this_planet_matrix_obs + rv_other_planets_matrix_obs + rv_trend_matrix_obs
    residuals_matrix_obs = vel_corrected - rv_all_planets_trend_matrix_obs
    gp_mean_matrix_obs = np.zeros_like(residuals_matrix_obs)
    time_jax = jnp.array(self.time)
    for i in tqdm(range(residuals_matrix_obs.shape[0]), desc="Computing GP predictions"):
        # Hyperparameters for this sample (i-th draw if free, scalar if fixed)
        sample_hyperparams = {
            hp: _sample_or_scalar(params_hyperparams[hp], i)
            for hp in self.gp_kernel.expected_hyperparams
        }
        kernel = self.gp_kernel.build_kernel(sample_hyperparams)

        # Diagonal noise term: measurement variance plus per-instrument jitter variance
        jit2_verr2 = np.zeros(n_obs)
        for inst, mask in inst_masks.items():
            jit_value = _sample_or_scalar(params_hyperparams[f"jit_{inst}"], i)
            jit2_verr2[mask] = self.velerr[mask]**2 + jit_value**2

        gp_obs = GaussianProcess(kernel=kernel, X=time_jax, diag=jnp.array(jit2_verr2))
        _, gp_cond = gp_obs.condition(y=jnp.array(residuals_matrix_obs[i]), X_test=time_jax)
        gp_mean_matrix_obs[i] = np.array(gp_cond.mean)

    # 5) Everything except the target planet, per sample
    rv_all_other_matrix_obs = rv_other_planets_matrix_obs + rv_trend_matrix_obs + gp_mean_matrix_obs
    # 6) Median over samples gives one value per observation
    rv_all_other_median_obs = np.median(rv_all_other_matrix_obs, axis=0)
    # 7) Remove it from the gamma-corrected data
    data_minus_others = vel_corrected - rv_all_other_median_obs

    # 8) Reorder everything into phase order using the fold indices
    data_minus_others_folded = data_minus_others[inds]
    instrument_sorted = self.instrument[inds]
    verr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]

    # Create the figure only once all sample processing has succeeded, so a
    # failure above does not leave an empty figure open
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Set up per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Data points (coloured by instrument)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_folded[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # 9) Error bars including jitter (faded, no label)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_folded[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    # 10) Planet RV model curve: median and (optionally) the 68.3% interval band
    rv_planet_smooth_percs = np.percentile(rv_this_planet_matrix_smooth, [15.85, 50, 84.15], axis=0)
    ax1.plot(tsmooth_fold_sorted, rv_planet_smooth_percs[1][smooth_inds], color="black", label="Planet Model", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth_fold_sorted, rv_planet_smooth_percs[0][smooth_inds], rv_planet_smooth_percs[2][smooth_inds],
                         color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)

    # 11) Residuals panel
    residuals = data_minus_others - np.median(rv_this_planet_matrix_obs, axis=0)
    residuals_sorted = residuals[inds]
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title is None:
        ax1.set_title(f"Posterior Phase Plot (with GP) - Planet {planet_letter}")
    elif title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # Set x-ticks every 0.25

    # Symmetric y-limits for the residuals plot. Each error bar reaches
    # |residual| + error away from zero; taking the absolute value of the sum
    # instead would under-estimate points with negative residuals and clip
    # their lower error bars.
    max_abs_residual = np.max(np.abs(residuals_sorted) + velerr_with_jit_sorted)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    # Set y-axis ticks
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs]
def plot_MAP_rv(self, map_result: scipy.optimize.OptimizeResult, title: str | None = "MAP RV (with GP)", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_rv.png", dpi: int = 100) -> None:
"""Plot the MAP RV model.
Uses the Maximum A Posteriori (MAP) parameter estimates to plot the
GP model including both mean model (planets + trend) and GP component.
Parameters
----------
map_result : scipy.optimize.OptimizeResult
Result from find_map_estimate()
title : str or None, optional
Plot title (default: "MAP RV (with GP)"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "MAP_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get MAP parameter and hyperparameter values from the optimization result
# map_result.x contains both free parameters and free hyperparameters
all_free_names = self.free_params_names + self.free_hyperparams_names
map_params_and_hyperparams = dict(zip(all_free_names, map_result.x))
# Combine with fixed parameters and fixed hyperparameters
all_params = map_params_and_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict
# Use helper function to create the plot
self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_MAP_phase(self, planet_letter: str, map_result: scipy.optimize.OptimizeResult, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_phase.png", dpi: int = 100) -> None:
"""Plot the MAP phase model.
Uses the Maximum A Posteriori (MAP) parameter estimates to plot the
phase-folded GP model for a specific planet, including both mean model
(planets + trend) and GP component.
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
map_result : scipy.optimize.OptimizeResult
Result from find_map_estimate()
title : str or None, optional
Plot title (default: f"MAP Phase Plot (with GP) - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "MAP_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get MAP parameter and hyperparameter values from the optimization result
# map_result.x contains both free parameters and free hyperparameters
all_free_names = self.free_params_names + self.free_hyperparams_names
map_params_and_hyperparams = dict(zip(all_free_names, map_result.x))
# Combine with fixed parameters and fixed hyperparameters
all_params = map_params_and_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict
# Set default title if not provided
if title is None:
title = f"MAP Phase Plot (with GP) - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_custom_rv(self, params_hyperparams: dict, title: str | None = "Custom GP RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_rv.png", dpi: int = 100) -> None:
"""Plot GP radial velocity data and model using custom parameter and hyperparameter values.
Allows plotting with arbitrary parameter and hyperparameter values for exploring
parameter space or comparing theoretical models. The GP component will be
conditioned on the residuals from the mean model (planets + trend).
Parameters
----------
params_hyperparams : dict
Dictionary of parameter and hyperparameter values to use for plotting.
Keys should match parameter and hyperparameter names, values should be floats.
Must include all required parameters and hyperparameters for the current
parameterisation and GP kernel.
title : str or None, optional
Plot title (default: "Custom GP RV Plot"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "custom_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
Examples
--------
>>> # Plot with custom values (must include all required parameters + hyperparameters)
>>> gpfitter.plot_custom_rv({"P_b": 4.25, "K_b": 55.0, "e_b": 0.1,
... "w_b": 1.57, "Tc_b": 2456325.5,
... "g": -10.2, "gd": 0.0, "gdd": 0.0, "jit": 2.0,
... "gp_amp": 15.0, "gp_lambda_e": 50.0,
... "gp_lambda_p": 0.5, "gp_period": 25.0})
"""
# Validate that all required parameters and hyperparameters are present
expected_params = set(self.free_params_names + list(self.fixed_params_names))
expected_hyperparams = set(self.free_hyperparams_names + list(self.fixed_hyperparams_names))
expected_all = expected_params | expected_hyperparams
provided_params = set(params_hyperparams.keys())
missing_params = expected_all - provided_params
if missing_params:
raise ValueError(f"Missing required parameters/hyperparameters: {missing_params}")
# Use helper function to create the plot
self._plot_rv(params_hyperparams, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_custom_phase(self, planet_letter: str, params_hyperparams: dict, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_phase.png", dpi: int = 100) -> None:
"""Plot GP phase-folded radial velocity data and model using custom parameter and hyperparameter values.
Allows plotting phase-folded data with arbitrary parameter and hyperparameter values
for exploring parameter space or comparing theoretical models. The GP component
will be conditioned on the residuals from the mean model (planets + trend).
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
params_hyperparams : dict
Dictionary of parameter and hyperparameter values to use for plotting.
Keys should match parameter and hyperparameter names, values should be floats.
Must include all required parameters and hyperparameters for the current
parameterisation and GP kernel.
title : str or None, optional
Plot title (default: f"Custom GP Phase Plot - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "custom_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
Examples
--------
>>> # Plot phase curve with custom values
>>> gpfitter.plot_custom_phase("b", {"P_b": 4.25, "K_b": 55.0, "e_b": 0.1,
... "w_b": 1.57, "Tc_b": 2456325.5,
... "g": -10.2, "gd": 0.0, "gdd": 0.0, "jit": 2.0,
... "gp_amp": 15.0, "gp_lambda_e": 50.0,
... "gp_lambda_p": 0.5, "gp_period": 25.0})
"""
# Validate that all required parameters and hyperparameters are present
expected_params = set(self.free_params_names + list(self.fixed_params_names))
expected_hyperparams = set(self.free_hyperparams_names + list(self.fixed_hyperparams_names))
expected_all = expected_params | expected_hyperparams
provided_params = set(params_hyperparams.keys())
missing_params = expected_all - provided_params
if missing_params:
raise ValueError(f"Missing required parameters/hyperparameters: {missing_params}")
# Set default title if not provided
if title is None:
title = f"Custom GP Phase Plot - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, params_hyperparams, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_best_sample_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Best Sample RV Plot (with GP)", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_rv.png", dpi: int = 100) -> None:
"""Plot radial velocity data and model using parameter and hyperparameter values from the MCMC sample with highest log probability.
This is useful for comparing with plot_MAP_rv() to diagnose potential issues with
MAP convergence or MCMC mixing. The two plots should be very similar if both
MAP and MCMC are working correctly.
Parameters
----------
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
title : str or None, optional
Title for the main RV plot (default: "Best Sample RV Plot (with GP)"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "best_sample_rv.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get parameter and hyperparameter values from best sample
best_sample_params_hyperparams = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin)
# Combine with fixed parameters and fixed hyperparameters
all_params = best_sample_params_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict
# Use helper function to create the plot
self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs]
def plot_best_sample_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_phase.png", dpi: int = 100) -> None:
"""Plot phase-folded radial velocity data and model using parameter and hyperparameter values from the MCMC sample with highest log probability.
This is useful for comparing with plot_MAP_phase() to diagnose potential issues with
MAP convergence or MCMC mixing. The two plots should be very similar if both
MAP and MCMC are working correctly.
Parameters
----------
planet_letter : str
Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
discard_start : int, optional
Discard the first `discard_start` steps from the start of the chain (default: 0)
discard_end : int, optional
Discard the last `discard_end` steps from the end of the chain (default: 0)
thin : int, optional
Use only every `thin` steps from the chain (default: 1)
title : str or None, optional
Title for the main phase plot (default: "Best Sample Phase Plot (with GP) - Planet {planet_letter}"). Set to None or "" to skip.
ylabel_main : str or None, optional
Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
xlabel : str or None, optional
X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
ylabel_residuals : str or None, optional
Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
save : bool, optional
Save the plot (default: False)
fname : str, optional
Filename to save (default: "best_sample_phase.png")
dpi : int, optional
Resolution for saving (default: 100)
"""
# Get parameter and hyperparameter values from best sample
best_sample_params_hyperparams = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin)
# Combine with fixed parameters and fixed hyperparameters
all_params = best_sample_params_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict
# Set default title if not provided
if title is None:
title = f"Best Sample Phase Plot (with GP) - Planet {planet_letter}"
# Use helper function to create the plot
self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals,
save=save, fname=fname, dpi=dpi)
[docs]
def calculate_rv_planet_from_samples(self, planet_letter: str, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = True) -> np.ndarray:
    """Evaluate one planet's RV curve for every retained MCMC sample.

    Every flattened chain sample is expanded into a complete parameter
    dictionary and evaluated on its own, so parameter correlations within a
    sample are preserved. This is different from evaluating the model once at
    the median parameters, which need not correspond to any actual sample.

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: True)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - RV for each sample
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    n_samples = len(flat_chain)
    rv_per_sample = np.zeros((n_samples, len(times)))

    progress_bar = tqdm(
        enumerate(flat_chain),
        total=n_samples,
        disable=not progress,
        desc=f"Calculating planet {planet_letter} RV from samples",
    )
    for row, sample in progress_bar:
        # Expand the sample into a full params dict (hyperparams included)
        sample_params = self.build_params_dict(sample)
        rv_per_sample[row, :] = self.calculate_rv_planet_custom(planet_letter, times, sample_params)

    return rv_per_sample
[docs]
def calculate_rv_trend_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = True) -> np.ndarray:
    """Evaluate the trend RV for every retained MCMC sample.

    Every flattened chain sample is expanded into a complete parameter
    dictionary and evaluated on its own, so parameter correlations within a
    sample are preserved. This is different from evaluating the trend once at
    the median parameters, which need not correspond to any actual sample.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: True)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - Trend RV for each sample
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    n_samples = len(flat_chain)
    rv_per_sample = np.zeros((n_samples, len(times)))

    progress_bar = tqdm(
        enumerate(flat_chain),
        total=n_samples,
        disable=not progress,
        desc="Calculating trend RV from samples",
    )
    for row, sample in progress_bar:
        # Expand the sample into a full params dict (hyperparams included)
        sample_params = self.build_params_dict(sample)
        rv_per_sample[row, :] = self.calculate_rv_trend_custom(times, sample_params)

    return rv_per_sample
[docs]
def calculate_rv_gp_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = True) -> np.ndarray:
    """Calculate GP component for each MCMC sample.

    This calculates the GP mean at the specified times for each MCMC sample,
    preserving parameter and hyperparameter correlations from the posterior.
    For every sample, calculate_rv_gp_custom() conditions the GP on the
    residuals of the observed RV data and predicts at the requested `times`.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate GP at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: True)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - GP component for each sample
    """
    samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    gp_components = np.zeros((len(samples), len(times)))

    # No mean model is pre-computed here: calculate_rv_gp_custom() builds the
    # trend + planets model at the data times itself for each sample, so
    # evaluating it here as well would repeat every Keplerian evaluation (and
    # show extra progress bars) only to throw the result away.
    iterator = tqdm(enumerate(samples), total=len(samples), disable=not progress, desc="Calculating GP from samples")
    for i, combined_sample in iterator:
        # Build complete params dict for this sample (includes hyperparams)
        params = self.build_params_dict(combined_sample)
        gp_components[i, :] = self.calculate_rv_gp_custom(times, params)

    return gp_components
[docs]
def calculate_rv_total_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0, thin: int = 1, progress: bool = True) -> np.ndarray:
    """Calculate total RV (planets + trend + GP) for each MCMC sample.

    This calculates RV_total(params_i) for each MCMC sample i, preserving
    parameter correlations. This differs from using median parameters
    which may not represent actual samples from the posterior.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bars (default: True). The flag is forwarded to every
        component calculation (trend, each planet, and GP).

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - Total RV for each sample
    """
    # The same chain selection and progress setting go to every component, so
    # the per-sample rows line up and progress=False silences all the bars.
    component_kwargs = {
        "discard_start": discard_start,
        "discard_end": discard_end,
        "thin": thin,
        "progress": progress,
    }

    # Trend + planets at the requested times
    total_rvs = self.calculate_rv_trend_from_samples(times, **component_kwargs)
    for planet_letter in self.planet_letters:
        total_rvs += self.calculate_rv_planet_from_samples(planet_letter, times, **component_kwargs)

    # GP component evaluated at the requested times
    total_rvs += self.calculate_rv_gp_from_samples(times=times, **component_kwargs)

    return total_rvs
[docs]
def calculate_rv_planet_custom(self, planet_letter: str, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate one planet's RV curve for a single, explicit set of parameters.

    Parameters
    ----------
    planet_letter : str
        Planet letter to calculate RV for (e.g., 'b', 'c')
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (free + fixed parameters).
        Can be created using build_params_dict().

    Returns
    -------
    np.ndarray
        Planetary RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> planet_rv = gpfitter.calculate_rv_planet_custom('b', times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = gpfitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = gpfitter.build_params_dict(best_params)
    >>> planet_rv = gpfitter.calculate_rv_planet_custom('b', times, params)
    """
    # Pick out this planet's entries, dropping the "_<letter>" suffix so the
    # keys match the parameterisation's own parameter names
    orbital_elements = {
        name: params[f"{name}_{planet_letter}"]
        for name in self.parameterisation.pars
    }

    return ravest.model.Planet(
        planet_letter, self.parameterisation, orbital_elements
    ).radial_velocity(times)
[docs]
def calculate_rv_trend_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate the trend RV for a single, explicit set of parameters.

    Only the "gd" and "gdd" entries of `params` are used; the constant gamma
    offsets are per-instrument and are not part of the trend.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (free + fixed parameters).
        Can be created using build_params_dict().

    Returns
    -------
    np.ndarray
        Trend RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> trend_rv = gpfitter.calculate_rv_trend_custom(times, params)
    """
    trend_terms = {"gd": params["gd"], "gdd": params["gdd"]}
    return ravest.model.Trend(params=trend_terms, t0=self.t0).radial_velocity(times)
[docs]
def calculate_rv_gp_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate the GP component for a single, explicit set of parameters.

    The GP is conditioned on the residuals (gamma-corrected observed data
    minus the mean RV model), where the mean RV model is the sum of the trend
    and all planetary signals. The conditioned GP mean is then evaluated at
    the requested `times`.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate GP at
    params : dict[str, float]
        Complete parameter and hyperparameter dictionary.
        Must include both standard parameters (for trend/planets/jit) and
        GP hyperparameters (for kernel). Can be created using build_params_dict().

    Returns
    -------
    np.ndarray
        GP RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> gp_rv = gpfitter.calculate_rv_gp_custom(times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = gpfitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = gpfitter.build_params_dict(best_params)
    >>> gp_rv = gpfitter.calculate_rv_gp_custom(times, params)
    """
    # Kernel built from just the hyperparameters it expects
    kernel = self.gp_kernel.build_kernel(
        {name: params[name] for name in self.gp_kernel.expected_hyperparams}
    )

    # Per-observation variance: measurement error and the jitter of the
    # observing instrument, added in quadrature
    variance = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        from_this_inst = self.instrument == inst
        jitter = params[f"jit_{inst}"]
        variance[from_this_inst] = self.velerr[from_this_inst]**2 + jitter**2
    gp = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(variance))

    # Mean model (trend + all planets) evaluated at the data times
    mean_model = self.calculate_rv_trend_custom(self.time, params)
    for planet_letter in self.planet_letters:
        mean_model += self.calculate_rv_planet_custom(planet_letter, self.time, params)

    # Remove each instrument's gamma offset from a copy of the data
    gamma_corrected = self.vel.copy()
    for inst in self.unique_instruments:
        from_this_inst = self.instrument == inst
        gamma_corrected[from_this_inst] -= params[f"g_{inst}"]

    # Condition on the residuals and return the predictive mean at `times`
    _, conditioned_gp = gp.condition(
        y=jnp.array(gamma_corrected - mean_model),
        X_test=jnp.array(times),
    )
    return np.array(conditioned_gp.mean)
[docs]
def calculate_rv_total_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate the total RV (trend + all planets + GP) for a single, explicit set of parameters.

    Handy for evaluating the full model at specific values, e.g. the best
    lnprob sample, median parameters, or experimental values.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter and hyperparameter dictionary.
        Must include both standard parameters (for trend/planets/jit) and
        GP hyperparameters (for kernel). Can be created using build_params_dict().

    Returns
    -------
    np.ndarray
        Total RV values (trend + all planets + GP) at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> total_rv = gpfitter.calculate_rv_total_custom(times, params)
    """
    # Start from the trend, then accumulate each component in place
    rv_sum = self.calculate_rv_trend_custom(times, params)
    for letter in self.planet_letters:
        rv_sum += self.calculate_rv_planet_custom(letter, times, params)
    rv_sum += self.calculate_rv_gp_custom(times, params)
    return rv_sum
[docs]
class GPLogPosterior:
    """Log posterior probability for GP MCMC sampling.

    Combines the GP log likelihood with the log priors on the free parameters
    and the log hyperpriors on the free GP hyperparameters.
    """

    def __init__(
        self,
        planet_letters: list[str],
        parameterisation: Parameterisation,
        gp_kernel: GPKernel,
        priors: dict[str, Callable[[float], float]],
        hyperpriors: dict[str, Callable[[float], float]],
        fixed_params: dict[str, float],
        fixed_hyperparams: dict[str, float],
        free_params_names: list[str],
        free_hyperparams_names: list[str],
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        t0: float,
        instrument: np.ndarray,
        unique_instruments: list[str],
    ) -> None:
        """Initialize the GPLogPosterior object.

        Parameters
        ----------
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.
        gp_kernel : GPKernel
            The Gaussian Process kernel to use.
        priors : dict[str, Callable[[float], float]]
            Dictionary mapping parameter names to their prior probability functions.
        hyperpriors : dict[str, Callable[[float], float]]
            Dictionary mapping hyperparameter names to their prior probability functions.
        fixed_params : dict[str, float]
            Dictionary of fixed parameter values.
        fixed_hyperparams : dict[str, float]
            Dictionary of fixed hyperparameter values.
        free_params_names : list[str]
            List of free parameter names to sample.
        free_hyperparams_names : list[str]
            List of free hyperparameter names to sample.
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        t0 : float
            Reference time for the trend [days].
        instrument : np.ndarray
            Instrument label for each observation.
        unique_instruments : list[str]
            List of unique instrument names.
        """
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.gp_kernel = gp_kernel
        self.priors = priors
        self.hyperpriors = hyperpriors
        self.fixed_params = fixed_params
        self.fixed_hyperparams = fixed_hyperparams
        self.free_params_names = free_params_names
        self.free_hyperparams_names = free_hyperparams_names
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.t0 = t0
        self.instrument = instrument
        self.unique_instruments = unique_instruments

        # Create the GP log-likelihood object once; it is called on every evaluation
        self.gp_log_likelihood = GPLogLikelihood(
            time=self.time,
            vel=self.vel,
            velerr=self.velerr,
            t0=self.t0,
            instrument=self.instrument,
            unique_instruments=self.unique_instruments,
            planet_letters=self.planet_letters,
            parameterisation=self.parameterisation,
            gp_kernel=self.gp_kernel,
        )

        # Create LogPrior objects for parameters and hyperparameters
        self.log_prior = LogPrior(self.priors)
        self.log_hyperprior = LogPrior(self.hyperpriors)

    def _convert_params_for_prior_evaluation(self, free_params_dict: dict[str, float]) -> Dict[str, float]:
        """Convert free parameters for prior evaluation if needed.

        Parameters
        ----------
        free_params_dict : dict
            Free parameters in current parameterisation

        Returns
        -------
        dict
            Parameters with names/values converted for prior evaluation.
            When the prior keys already match the free parameter names the
            input dictionary itself is returned.

        Raises
        ------
        ValueError
            log_probability() expects this from the conversion to the default
            parameterisation when the values are invalid (e.g. unphysical
            eccentricity) - presumably raised by the parameterisation object.
        """
        # Three cases:
        # Case 1: User is fitting in transformed parameterisation, but priors are in same transformed parameterisation
        # Case 2: User is fitting in default parameterisation, and priors are also in default parameterisation
        # Case 3: User is fitting in transformed parameterisation, but priors are in default parameterisation

        # Simple detection: do prior keys match our current free parameter names?
        prior_keys = set(self.priors.keys())
        if prior_keys == set(self.free_params_names):
            # No conversion needed (Cases 1 & 2)
            return free_params_dict

        # Conversion needed (Case 3) - convert to default parameterisation equivalents.
        # Start with just the non-planetary parameters that match
        params_for_prior = {
            key: value for key, value in free_params_dict.items() if key in prior_keys
        }
        all_params = self.fixed_params | free_params_dict

        for planet_letter in self.planet_letters:
            # Current planet parameters, keyed without the "_<letter>" suffix
            planet_params = {
                par: all_params[f"{par}_{planet_letter}"]
                for par in self.parameterisation.pars
            }
            default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)

            # Keep only the converted values that actually have a prior
            for default_par, value in default_params.items():
                default_param_key = f"{default_par}_{planet_letter}"
                if default_param_key in prior_keys:
                    params_for_prior[default_param_key] = value

        return params_for_prior

    def log_probability(self, combined_params_hyperparams: Dict[str, float]) -> float:
        """Calculate log posterior probability for given free parameters and hyperparameters.

        Parameters
        ----------
        combined_params_hyperparams : Dict[str, float]
            Combined dictionary of free parameters and hyperparameters

        Returns
        -------
        float
            Log posterior probability (log likelihood + log prior + log hyperprior).
            -inf is returned when a jitter is negative, the hyperparameters fail
            kernel validation, a prior or hyperprior is non-finite, the
            parameter conversion raises ValueError, or the likelihood is NaN.
        """
        # Split the combined dictionary into parameters and hyperparameters
        free_params_dict = {name: combined_params_hyperparams[name] for name in self.free_params_names}
        free_hyperparams_dict = {name: combined_params_hyperparams[name] for name in self.free_hyperparams_names}

        # Merge fixed and free values once; the merged dicts are used both by
        # the validity checks below and by the likelihood at the end.
        all_params = self.fixed_params | free_params_dict
        all_hyperparams = self.fixed_hyperparams | free_hyperparams_dict

        # Fast fail for invalid jitter (before expensive prior/likelihood calculations).
        # Jitter doesn't directly contribute to the calculated RV, so unlike the
        # other parameters it is never validated when an RV is computed.
        for inst in self.unique_instruments:
            if all_params[f"jit_{inst}"] < 0:
                return -np.inf

        # Fast fail for invalid GP hyperparameters.
        # This is a check for unphysical values, not for whether they are within the hyperpriors
        try:
            self.gp_kernel._validate_hyperparams_values(all_hyperparams)
        except ValueError:
            return -np.inf

        # Evaluate priors on the free parameters, converting them first if the
        # user fits in a transformed parameterisation but defined the priors in
        # the default one. Fail fast before the expensive likelihood.
        try:
            params_for_prior = self._convert_params_for_prior_evaluation(free_params_dict)
            lp = self.log_prior(params_for_prior)
        except ValueError:
            # Invalid parameter conversion (e.g., unphysical eccentricity)
            return -np.inf
        if not np.isfinite(lp):
            return -np.inf

        # Evaluate hyperpriors on the free hyperparameters - fail fast if outside
        lhp = self.log_hyperprior(free_hyperparams_dict)
        if not np.isfinite(lhp):
            return -np.inf

        # Calculate GP log-likelihood with all parameters and hyperparameters
        ll = self.gp_log_likelihood(params=all_params, hyperparams=all_hyperparams)

        # A NaN likelihood is treated like any other invalid point and rejected.
        # NOTE(review): NaN is expected when the covariance factorisation fails
        # numerically, and emcee is understood to abort on NaN - confirm both.
        if np.isnan(ll):
            return -np.inf

        # Return combined log-posterior (log-likelihood + log-prior + log-hyperprior)
        return ll + lp + lhp

    def _negative_log_probability_for_MAP(self, combined_free_params_hyperparams_vals: list[float]) -> float:
        """For MAP: evaluate log_probability() from a flat list of values and negate it.

        The optimiser works with a flat sequence of values rather than a dict,
        so the values are assigned back to their names here and the resulting
        dict is passed to log_probability().

        The order of the values is not checked: the free parameter values are
        assumed to come first, in the order of `free_params_names`, followed by
        the free hyperparameter values in the order of `free_hyperparams_names`.

        Parameters
        ----------
        combined_free_params_hyperparams_vals : list
            Combined list of free parameter and free hyperparameter values

        Returns
        -------
        float
            Negative log posterior probability, or 1e30 when the log posterior
            is not finite.
        """
        # Split the list back into params values and hyperparams values
        n_params = len(self.free_params_names)
        params_values = combined_free_params_hyperparams_vals[:n_params]
        hyperparams_values = combined_free_params_hyperparams_vals[n_params:]

        # Create combined dict from the names and values
        # (Assumes the order of names matches the order of values)
        params_dict = dict(zip(self.free_params_names, params_values))
        hyperparams_dict = dict(zip(self.free_hyperparams_names, hyperparams_values))
        combined_dict = params_dict | hyperparams_dict

        # Calculate *negative* log_probability (MAP minimises, MCMC maximises)
        neg_logprob = -self.log_probability(combined_dict)

        # Handle non-finite values to prevent scipy RuntimeWarnings during optimisation:
        # the optimiser can't handle inf values in arithmetic operations.
        # (This does mean there is a non-zero chance we could end up returning
        # a solution that doesn't satisfy the prior functions)
        if not np.isfinite(neg_logprob):
            return 1e30  # Very large finite number instead of +inf
        return neg_logprob
[docs]
class GPLogLikelihood:
    """GP log likelihood of radial velocity data.

    Evaluates the log likelihood of the observed RVs given the RV model
    parameters (planets, trend, per-instrument gamma and jitter) and the GP
    kernel hyperparameters.
    """

    def __init__(
        self,
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        t0: float,
        instrument: np.ndarray,
        unique_instruments: list[str],
        planet_letters: list[str],
        parameterisation: Parameterisation,
        gp_kernel: GPKernel,
    ) -> None:
        """Store the data and model description, and precompute per-call constants.

        Parameters
        ----------
        time, vel, velerr : np.ndarray
            Observation times, radial velocities and their uncertainties.
        t0 : float
            Reference time passed to the trend model.
        instrument : np.ndarray
            Instrument label of each observation.
        unique_instruments : list[str]
            The distinct instrument names; their order defines the integer
            index assigned to each instrument.
        planet_letters : list[str]
            Planet identifiers used as parameter-name suffixes.
        parameterisation : Parameterisation
            The orbital parameterisation of the planet parameters.
        gp_kernel : GPKernel
            Object whose build_kernel() turns hyperparameters into a kernel.
        """
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.t0 = t0
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.gp_kernel = gp_kernel

        # JAX copies of the data for tinygp
        self.jax_time = jnp.array(self.time)
        self.jax_vel = jnp.array(self.vel)
        self.jax_velerr = jnp.array(self.velerr)

        # Integer instrument index per observation, e.g.
        #   unique_instruments = ["HARPS", "ESPRESSO"]
        #   instrument         = ["HARPS", "HARPS", "ESPRESSO", "HARPS", ...]
        #   indices            = [0, 0, 1, 0, ...]
        # Indexing a length-K array of per-instrument values with this array
        # expands it to length N in one step, with no boolean-mask loop.
        position_of = {label: position for position, label in enumerate(self.unique_instruments)}
        self._instrument_indices = jnp.array([position_of[label] for label in self.instrument])

        # Parameter names for gamma and jitter (e.g. "g_HARPS", "jit_ESPRESSO")
        # never change for this object, so build the strings only once.
        self._gamma_keys = [f"g_{inst}" for inst in self.unique_instruments]
        self._jitter_keys = [f"jit_{inst}" for inst in self.unique_instruments]

        # The squared measurement errors are constant too
        self._velerr_sq = self.jax_velerr ** 2

    def _calculate_mean_model(self, params: Dict[str, float]) -> jnp.ndarray:
        """Calculate the RV mean function for the GP at the observation times.

        The mean model is the sum of every planet's RV, the system trend, and
        the gamma offset of the instrument that took each observation.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values

        Returns
        -------
        jnp.ndarray
            Mean model RV values at observation times. If building or
            evaluating a planet raises ValueError, the scalar -np.inf is
            returned instead so the caller can reject the point.
        """
        model_rv = jnp.zeros(len(self.time))

        # Planets: one Keplerian contribution per planet letter
        for letter in self.planet_letters:
            # This planet's parameters with the "_<letter>" suffix removed
            orbital_elements = {
                name: params[f"{name}_{letter}"]
                for name in self.parameterisation.pars
            }
            try:
                planet = ravest.model.Planet(letter, self.parameterisation, orbital_elements)
                planet_rv = planet.radial_velocity(self.time)
            except ValueError:
                # Invalid planet parameters: signal failure to the caller
                return -np.inf  # fail-fast: return -inf log-likelihood
            model_rv = model_rv + planet_rv

        # System trend (gamma is excluded here because it is per-instrument)
        trend = ravest.model.Trend(params={"gd": params["gd"], "gdd": params["gdd"]}, t0=self.t0)
        model_rv = model_rv + jnp.array(trend.radial_velocity(self.time))

        # Per-instrument gamma: a length-K array of offsets expanded to length
        # N by indexing with the precomputed instrument indices
        gammas = jnp.array([params[key] for key in self._gamma_keys])
        return model_rv + gammas[self._instrument_indices]

    @staticmethod
    @jax.jit
    def _compute_gp_log_likelihood(
        kernel: kernels.Kernel,
        time_array: jnp.ndarray,
        vel_array: jnp.ndarray,
        verr_squared_array: jnp.ndarray,
        mean_model: jnp.ndarray
    ) -> float:
        """JIT-compiled GP log likelihood computation.

        Builds a GaussianProcess on `time_array` with `verr_squared_array` as
        its diagonal and returns the log probability of the residuals
        `vel_array - mean_model`.
        """
        gp = GaussianProcess(kernel=kernel, X=time_array, diag=verr_squared_array)
        return gp.log_probability(y=vel_array - mean_model)

    def __call__(self, params: Dict[str, float], hyperparams: Dict[str, float]) -> float:
        """Calculate GP log likelihood for given parameters and hyperparameters.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values
        hyperparams : Dict[str, float]
            Dictionary of all hyperparameter values

        Returns
        -------
        float
            Log likelihood value; -inf when the mean model is not finite
            everywhere.
        """
        # Mean model: planets + system trend + per-instrument gamma
        mean_model = self._calculate_mean_model(params)

        # Skip the expensive GP work if the mean model could not be computed
        if not jnp.isfinite(mean_model).all():
            return -np.inf

        kernel = self.gp_kernel.build_kernel(hyperparams)

        # Per-observation variance: gather the jitter values (one per
        # instrument, length K), expand them to one per observation (length N)
        # and add them in quadrature to the measurement errors.
        # N.B. no sqrt here - tinygp's diag wants variance, not stddev
        jitters = jnp.array([params[key] for key in self._jitter_keys])
        variance = self._velerr_sq + jitters[self._instrument_indices]**2

        # Hand the numerical work to the JIT-compiled helper
        return self._compute_gp_log_likelihood(
            kernel=kernel,
            time_array=self.jax_time,
            vel_array=self.jax_vel,
            verr_squared_array=variance,
            mean_model=mean_model
        )