# Source code for ravest.fit

"""Radial velocity fitting using MCMC and MAP estimation.

This module provides the main Fitter class for fitting radial velocity data
to planetary models using various parameterisations.
"""
# fit.py
import logging
import multiprocessing as mp
import os
import warnings
from typing import Callable, Dict, Optional

# NumPy is often linked against a multithreaded BLAS (OpenBLAS or MKL). Those
# internal thread pools can clash with the process-level parallelism we use to
# speed up emcee, so limit them to a single thread. This is done before the
# numerical libraries below are imported.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})

import corner
import emcee
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from matplotlib.ticker import AutoLocator, AutoMinorLocator, MultipleLocator
from scipy.optimize import minimize
from tinygp import GaussianProcess, kernels
from tqdm import tqdm

import ravest.model
from ravest.gp import GPKernel
from ravest.model import _njit_kepler_rv
from ravest.param import Parameter, Parameterisation, param_key_to_latex

# Enable 64-bit precision for better numerical accuracy.
# This is a global JAX configuration flag, applied once at import time.
jax.config.update("jax_enable_x64", True)

# NOTE(review): basicConfig configures the *root* logger for whichever
# application imports this module (and is a no-op if the root logger already
# has handlers). Libraries usually leave logging configuration to the
# application; kept here to preserve existing behaviour.
logging.basicConfig(level=logging.INFO)


class Fitter:
    """Fit radial velocity data with planetary orbit models.

    Offers MAP estimation and MCMC sampling under a chosen orbital
    parameterisation, for systems with one or more planets, a trend, and
    per-instrument offset and jitter parameters.
    """

    def __init__(self, planet_letters: list[str], parameterisation: Parameterisation) -> None:
        """Create a Fitter for the given planets and parameterisation.

        Parameters
        ----------
        planet_letters : list[str]
            Single-character planet identifiers (e.g. ['b', 'c', 'd']). They
            are used as suffixes to distinguish each planet's parameters.
        parameterisation : Parameterisation
            Orbital parameterisation used for fitting; its ``pars`` define the
            orbital elements each planet is described by.
        """
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation

        # Run the numba-compiled Kepler RV kernel once on throwaway inputs so the
        # JIT compilation happens here rather than during MCMC.
        warmup_anomalies = np.linspace(0, 2 * np.pi, 10)
        _njit_kepler_rv(warmup_anomalies, 0.3, 10.0, 0.5)

        # Empty stores, filled later through the ``params`` / ``priors`` setters.
        self._params: Dict[str, Parameter] = {}
        self._priors: Dict[str, Callable[[float], float]] = {}
[docs] def add_data( self, time: np.ndarray, vel: np.ndarray, velerr: np.ndarray, instrument: np.ndarray, t0: float, ) -> None: """Add the data to the Fitter object. Parameters ---------- time : array-like Time of each observation [days] vel : array-like Radial velocity at each time [m/s] velerr : array-like Uncertainty on the radial velocity at each time [m/s] instrument : array-like Instrument name for each observation (e.g., "HARPS", "HIRES") t0 : float Reference time for the trend [days]. Recommended to set this as mean or median of input `time` array. """ if not (len(time) == len(vel) == len(velerr) == len(instrument)): raise ValueError( "Time, velocity, uncertainty, and instrument arrays must be the same length." ) self.time = np.ascontiguousarray(time) self.vel = np.ascontiguousarray(vel) self.velerr = np.ascontiguousarray(velerr) self.instrument = np.asarray(instrument) self.unique_instruments = np.unique(self.instrument) self.t0 = t0
@property def params(self) -> Dict[str, Parameter]: """Parameters dictionary. Set via: fitter.params = param_dict.""" return self._params @params.setter def params(self, new_params: Dict[str, Parameter]) -> None: """Set parameters with a dict, checking all required params are present. You can update all or some of the parameters at once, example: >>> fitter.params = {"g": Parameter(1.0, "m/s"), "gd": Parameter(0.1, "m/s/d")} # only update trend parameters >>> fitter.params = {"P_c": Parameter(5.0, "d"), "K_c": Parameter(3.5, "m/s")} # only update some of planet C parameters Parameters ---------- new_params : dict Dictionary of new parameter values to set. The keys of this dictionary should match the parameter names expected by the Fitter object: all required parameters for the chosen parameterisation, with planet letters (not required for Trend or jitter parameters.) Raises ------ ValueError If any of the required parameters are missing or invalid. """ # Update the current _params dict with the new entries merged_params = dict(self._params) merged_params.update(new_params) # Validate the complete parameter set self._validate_complete_params(merged_params) # If validation passes, update the actual params self._params.update(new_params) # Update ndim based on new free parameters self.ndim = len(self.free_params_values) if self.ndim == 0: warnings.warn( "All parameters are fixed. MCMC methods (find_map_estimate, " "generate_initial_walker_positions_*, run_mcmc) require at least one " "free parameter (fixed=False).", UserWarning, stacklevel=2 ) @property def priors(self) -> dict: """Priors dictionary. Set via: fitter.priors = prior_dict.""" return self._priors @priors.setter def priors(self, new_priors: dict[str, Callable[[float], float]]) -> None: """Set prior functions using a dict, checking all required priors are present. Priors must be provided for all free parameters. You can set all priors at once or update individual priors. 
Parameters ---------- new_priors : dict Dictionary of prior functions to set. Keys should be parameter names that match free parameters, values should be callable prior functions. Examples -------- >>> from ravest.prior import Uniform >>> fitter.priors = {"K_b": Uniform(0, 100), "P_b": Uniform(1, 30)} Raises ------ ValueError If any required priors are missing, unexpected priors are provided, or initial parameter values are outside prior bounds. """ self._set_priors_with_validation(new_priors)
[docs] def _validate_complete_params(self, params: Dict[str, Parameter]) -> None: """Validate that params dict has required parameters, astrophysically valid values.""" # Require add_data() to have been called first (need unique_instruments) if not hasattr(self, "unique_instruments"): raise RuntimeError( "add_data() must be called before setting params " "(need instrument list for per-instrument parameters)" ) # Build complete set of expected parameters expected_params = set() # Add planetary parameters for planet_letter in self.planet_letters: for par_name in self.parameterisation.pars: expected_params.add(f"{par_name}_{planet_letter}") # Add trend parameters (system-wide, no gamma offset here) expected_params.update(["gd", "gdd"]) # Add per-instrument gamma offset and jitter parameters for inst in self.unique_instruments: expected_params.add(f"g_{inst}") expected_params.add(f"jit_{inst}") # Convert to sets for easy comparison provided_params = set(params.keys()) # Check for unexpected parameters unexpected_params = provided_params - expected_params if unexpected_params: # Give a specific hint if user is passing legacy single-instrument parameters legacy_params = unexpected_params & {"g", "jit"} if legacy_params: raise ValueError( f"Unexpected parameters: {unexpected_params}. " f"Single-instrument 'g' and 'jit' parameters are no longer supported. " f"Use per-instrument names instead, e.g. " f"{[f'g_{inst}' for inst in self.unique_instruments]} and " f"{[f'jit_{inst}' for inst in self.unique_instruments]}, " f"matching the instrument names passed to add_data()." ) raise ValueError( f"Unexpected parameters: {unexpected_params}. " f"Expected {len(expected_params)} parameters, got {len(provided_params)}" ) # Check for missing parameters missing_params = expected_params - provided_params if missing_params: raise ValueError( f"Missing required parameters: {missing_params}. 
" f"Expected {len(expected_params)} parameters, got {len(provided_params)}" ) # Validate astrophysical validity of all parameters params_values = {name: param.value for name, param in params.items()} self._validate_astrophysical_validity(params_values) # Validate parameter coupling constraints # i.e. if two parameters both need to be fixed or free together self._validate_parameter_coupling(params)
[docs] def _validate_astrophysical_validity(self, params_values: Dict[str, float]) -> None: """Validate that all parameter values are astrophysically valid.""" # First, check that ALL parameters are finite (not NaN or infinite) invalid_params = { name: value for name, value in params_values.items() if not np.isfinite(value) } if invalid_params: raise ValueError( "Invalid parameters detected: " + ", ".join(f"{k}={v}" for k, v in invalid_params.items()) ) # Validate planetary parameters for each planet for planet_letter in self.planet_letters: planet_params = {} for par_name in self.parameterisation.pars: key = f"{par_name}_{planet_letter}" planet_params[par_name] = params_values[key] # Validate this planet's parameters in current parameterisation self.parameterisation.validate_planetary_params(planet_params) # Validate trend parameters are finite real numbers (already checked above, but kept for clarity) for trend_param in ["gd", "gdd"]: trend_value = params_values[trend_param] if not np.isfinite(trend_value): raise ValueError(f"Invalid trend parameter {trend_param}: {trend_value} is not a finite real number") # Validate per-instrument parameters for inst in self.unique_instruments: # Gamma offset must be finite g_key = f"g_{inst}" if not np.isfinite(params_values[g_key]): raise ValueError(f"Invalid gamma offset {g_key}: {params_values[g_key]} is not finite") # Jitter must be >= 0 jit_key = f"jit_{inst}" if params_values[jit_key] < 0: raise ValueError(f"Invalid jitter {jit_key}: {params_values[jit_key]} < 0")
[docs] def _validate_parameter_coupling(self, params: Dict[str, Parameter]) -> None: """Validate parameter coupling constraints (e.g., secosw/sesinw must both be free or both fixed).""" for planet_letter in self.planet_letters: # Check secosw/sesinw coupling secosw_key = f"secosw_{planet_letter}" sesinw_key = f"sesinw_{planet_letter}" if secosw_key in params and sesinw_key in params: secosw_fixed = params[secosw_key].fixed sesinw_fixed = params[sesinw_key].fixed if secosw_fixed != sesinw_fixed: raise ValueError(f"Parameters {secosw_key} and {sesinw_key} must both be fixed or both be free") # Check ecosw/esinw coupling ecosw_key = f"ecosw_{planet_letter}" esinw_key = f"esinw_{planet_letter}" if ecosw_key in params and esinw_key in params: ecosw_fixed = params[ecosw_key].fixed esinw_fixed = params[esinw_key].fixed if ecosw_fixed != esinw_fixed: raise ValueError(f"Parameters {ecosw_key} and {esinw_key} must both be fixed or both be free")
    def _set_priors_with_validation(self, new_priors: dict[str, Callable[[float], float]]) -> None:
        """Validate and store prior functions; supports partial updates.

        ``new_priors`` is merged over the currently stored priors, and the merged
        set is checked against the current free parameters. Each free parameter
        needs a prior either under its own name (current parameterisation) or,
        when ``_get_default_parameterisation_equivalent_free_param_name`` returns
        equivalents for it, under all of those default-parameterisation names.

        Raises
        ------
        ValueError
            If priors are given for both a current parameter and its default
            equivalent; if any free parameter lacks a prior; if a prior is given
            for a parameter that does not need one; or if a current parameter
            value has a non-finite log-prior under its prior.

        Notes
        -----
        Only ``new_priors`` is written into ``self._priors``, and only after all
        checks pass, so a rejected update leaves the stored priors unchanged.
        """
        # Create merged priors dict (in case user is only updating some priors, not all)
        merged_priors_dict = dict(self._priors)  # get existing priors
        merged_priors_dict.update(new_priors)  # overwrite with newer functions, if supplied
        provided_prior_param_names = set(merged_priors_dict.keys())

        # There are two possibilities for priors:
        # 1. The prior has been given for the parameter, in the current parameterisation
        #    (this can also include if the user is fitting in the default parameterisation)
        # 2. The prior has been given for the Default parameterisation's equivalent parameter instead
        #    (e.g. e & w instead of secosw & sesinw, or Tp instead of Tc)
        # If not, then prior isn't given for either the Current or Default parameterisation, raise an Exception
        validated_priors = {}
        missing_priors = []
        conflicts = []

        # in the current parameterisation, which (free) parameters do we expect priors for?
        current_parameterisation_free_param_names = set(self.free_params_names)

        for free_param_name in current_parameterisation_free_param_names:
            if free_param_name in provided_prior_param_names:
                # Prior was provided for the param in the current parameterisation
                validated_priors[free_param_name] = merged_priors_dict[free_param_name]

                # Check if user ALSO provided equivalent default priors (conflict!)
                default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)
                if default_parameterisation_equivalent_free_param_names:
                    for equiv_param in default_parameterisation_equivalent_free_param_names:
                        if equiv_param in provided_prior_param_names:
                            conflicts.append((free_param_name, equiv_param))
            else:
                # We haven't been provided the prior for the free parameter in the current parameterisation
                # So let's check if we were given the prior for the equivalent parameter in the default parameterisation instead
                default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name)

                # remember that one parameter in current parameterisation (e.g. secosw) might map to more than one equivalent in default parameterisation (e.g. both e & w)
                if default_parameterisation_equivalent_free_param_names and all(eq in provided_prior_param_names for eq in default_parameterisation_equivalent_free_param_names):
                    # Found all required default equivalents
                    for equiv in default_parameterisation_equivalent_free_param_names:
                        validated_priors[equiv] = merged_priors_dict[equiv]
                else:
                    # Missing prior for a free parameter in both the current parameterisation, and its equivalent in the default parameterisation
                    if default_parameterisation_equivalent_free_param_names:
                        missing_priors.append(f"{free_param_name} (or equivalent {default_parameterisation_equivalent_free_param_names})")
                    else:
                        missing_priors.append(free_param_name)

        # Check for conflicts after processing all parameters
        if conflicts:
            conflict_strs = [f"{current} vs {default}" for current, default in conflicts]
            raise ValueError(f"Conflicting priors provided for both current and default parameterisations: {', '.join(conflict_strs)}. Please provide priors for either the current parameterisation OR the equivalent default parameterisation, but not both.")

        if missing_priors:
            raise ValueError(f"Missing priors for parameters: {missing_priors}")

        # Check for unexpected priors - only allow priors that were validated above
        # (this also rejects stale priors kept in self._priors for parameters that are no longer free)
        expected_prior_param_names = set(validated_priors.keys())
        unexpected_prior_param_names = provided_prior_param_names - expected_prior_param_names
        if unexpected_prior_param_names:
            raise ValueError(
                f"Unexpected priors supplied for parameters: {unexpected_prior_param_names}. "
                f"Priors expected only for parameters: {expected_prior_param_names}"
            )

        # Check parameter values work with priors
        # (default-parameterisation priors are evaluated on converted values)
        self._check_params_values_against_priors(validated_priors, current_parameterisation_free_param_names)

        # Update the priors with the new values
        self._priors.update(new_priors)
[docs] def _get_default_parameterisation_equivalent_free_param_name(self, free_param: str) -> Optional[list[str]]: """Get the names of the default parameterisation equivalent parameter(s), for a single free parameter from the current parameterisation. Note this can be more than one: e.g. if you have secosw, this affects both e & w in the default parameterisation Whereas Tc just maps to Tp alone. Returns ------- list[str] | None - list[str]: equivalent parameter(s) in the default parameterisation - None: no mapping needed / no alternative priors to look for Raises ------ ValueError If `free_param` is not a recognised planet, instrument, or trend parameter. """ # No underscore (expected to be a system trend parameter) if '_' not in free_param: if free_param in ['gd', 'gdd']: # These are the same in all parameterisations return None else: raise ValueError(f"Unknown free parameter: {free_param}") # Contains underscore: Planetary or instrument parameters (with underscore before either planet letter or instrument name) # e.g. 
P_b, or Tc_c, or jit_HARPS else: base_param, suffix = free_param.split('_', 1) # split only on first underscore (some instrument names may have underscores too) # Planetary parameters: suffix is a planet letter if suffix in self.planet_letters: planet_letter = suffix if base_param in ['secosw', 'sesinw']: # Both secosw and sesinw map to e,w equivalents partner_param = 'sesinw' if base_param == 'secosw' else 'secosw' partner_key = f"{partner_param}_{planet_letter}" if partner_key in self.free_params_names: return [f"e_{planet_letter}", f"w_{planet_letter}"] elif base_param in ['ecosw', 'esinw']: # Both ecosw and esinw map to e,w equivalents partner_param = 'esinw' if base_param == 'ecosw' else 'ecosw' partner_key = f"{partner_param}_{planet_letter}" if partner_key in self.free_params_names: return [f"e_{planet_letter}", f"w_{planet_letter}"] elif base_param == 'Tc': # Tc can use Tp equivalent return [f"Tp_{planet_letter}"] elif base_param in ['P', 'K', 'e', 'w', 'Tp']: # These are default parameterisation parameters anyway return None else: # Suffix is a valid planet letter, but base parameter is unrecognised, so raise an error raise ValueError(f"Free parameter {free_param} has known planet letter {planet_letter} but unrecognised base parameter {base_param}.") # Instrument parameters: suffix is an instrument name elif suffix in self.unique_instruments: # The only instrument parameters are g and jit if base_param in ['g', 'jit']: # Per-instrument parameter (e.g., g_HARPS, jit_HIRES) # These are the same in all parameterisations return None else: raise ValueError(f"Free parameter {free_param} has known instrument name {suffix} but unrecognised base parameter {base_param} (expected 'g' or 'jit' only)") # Unknown: Suffix is present, but not a planet letter or instrument, so raise an error else: raise ValueError(f"Free parameter {free_param} has unrecognised suffix {suffix}, expected one of planet letters {self.planet_letters} or instrument names 
{self.unique_instruments}.")
[docs] def _check_params_values_against_priors(self, validated_priors: dict[str, Callable[[float], float]], current_free_param_names: list[str]) -> None: """Check parameter values against priors (including if Prior is for the Default parameterisation equivalent parameter).""" for prior_param_name, prior_function in validated_priors.items(): if prior_param_name in current_free_param_names: # This prior is in current parameterisation - check directly param_value = self.params[prior_param_name].value log_prior_probability = prior_function(param_value) if not np.isfinite(log_prior_probability): raise ValueError(f"Initial value {param_value} of parameter {prior_param_name} is invalid for prior {prior_function}.") else: # This prior is in default parameterisation - need to convert parameter value # Get the current parameter value and convert to default default_param_value = self._convert_single_param_to_default(prior_param_name) log_prior_probability = prior_function(default_param_value) if not np.isfinite(log_prior_probability): raise ValueError(f"Initial value {default_param_value} of parameter {prior_param_name} (in default parameterisation) is invalid for prior {prior_function}.")
[docs] def _convert_single_param_to_default(self, default_param_name: str) -> float: """Convert a single parameter from current to default parameterisation.""" # Extract planet letter if this is a planetary parameter if '_' in default_param_name: base_param, planet_letter = default_param_name.rsplit('_', 1) if planet_letter in self.planet_letters: # Get all current parameters for this planet (we need all five parameters to do a conversion) planet_params_dict = {} for par_name in self.parameterisation.pars: param_key = f"{par_name}_{planet_letter}" planet_params_dict[par_name] = self.params[param_key].value # Convert all the planetary parameters to the default parameterisation default_planet_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params_dict) # Return just the requested parameter in the default parameterisation return default_planet_params[base_param] # For non-planetary parameters (g, gd, gdd, jit), they're the same in all parameterisations if default_param_name in self.params: return self.params[default_param_name].value raise ValueError(f"Cannot convert parameter {default_param_name} to default parameterisation")
@property def free_params_dict(self) -> Dict[str, Parameter]: """Free parameters as dict.""" free_pars = {} for par in self.params: if self.params[par].fixed is False: free_pars[par] = self.params[par] return free_pars @property def free_params_values(self) -> list[float]: """Values of free parameters as list.""" return [param.value for param in self.free_params_dict.values()] @property def free_params_names(self) -> list[str]: """Names of free parameters as list.""" return list(self.free_params_dict.keys()) @property def fixed_params_dict(self) -> Dict[str, Parameter]: """Fixed parameters as dict, mapping names to Parameter objects.""" fixed_pars = {} for par in self.params: if self.params[par].fixed is True: fixed_pars[par] = self.params[par] return fixed_pars @property def fixed_params_values(self) -> list[float]: """Values of fixed parameters, as list.""" return [param.value for param in self.fixed_params_dict.values()] @property def fixed_params_names(self) -> list[str]: """Names of fixed parameters, as list.""" return list(self.fixed_params_dict.keys()) @property def fixed_params_values_dict(self) -> Dict[str, float]: """Fixed parameters as dict mapping names to just the values.""" return dict(zip(self.fixed_params_names, self.fixed_params_values))
[docs] def find_map_estimate(self, method: str = "Powell") -> scipy.optimize.OptimizeResult: """Find Maximum A Posteriori (MAP) estimate of parameters. Parameters ---------- method : str, optional Optimization method to use (default: "Powell") Returns ------- scipy.optimize.OptimizeResult The optimization result containing the MAP estimate Raises ------ Warning If MAP optimization fails to converge """ # Initialize log-posterior object lp = LogPosterior( self.planet_letters, self.parameterisation, self.priors, self.fixed_params_values_dict, self.free_params_names, self.time, self.vel, self.velerr, self.instrument, self.unique_instruments, self.t0, ) initial_guess = self.free_params_values if len(initial_guess) == 0: raise ValueError( "Cannot run MAP optimisation: no free parameters to optimise. " "At least one parameter must be set as free (fixed=False) before calling find_map_estimate()." ) # Perform MAP optimization def negative_log_posterior(*args: float) -> float: return lp._negative_log_probability_for_MAP(*args) map_results = minimize(negative_log_posterior, initial_guess, method=method) if map_results.success is False: print(map_results) warnings.warn("MAP did not succeed. Check the initial values of the parameters, and the prior functions.") # Print results as dictionary (to show param names too) map_results_dict = dict(zip(self.free_params_names, map_results.x)) print("MAP parameter results:", map_results_dict) # Return the scipy OptimizeResult object so that user can inspect fully if needed return map_results
[docs] def generate_initial_walker_positions_random(self, nwalkers: int, verbose: bool = False, max_attempts: int = 1000) -> np.ndarray: """Generate random initial walker positions that satisfy priors and are astrophysically valid. Creates random starting positions for MCMC walkers by sampling from appropriate distributions based on each parameter's prior type. Ensures that parameter combinations are astrophysically valid (e.g., eccentricity < 1). Parameters ---------- nwalkers : int Number of MCMC walkers to generate positions for verbose : bool, default False If True, print walker positions during generation max_attempts : int, default 1000 Maximum attempts to generate a valid walker position Returns ------- np.ndarray Array of shape (nwalkers, ndim) where ndim is the number of free parameters. Each row represents the starting position for one walker in the order of free_params_names. Raises ------ ValueError If a prior type is not supported for walker generation or if unable to generate valid positions after max_attempts Examples -------- >>> # Generate positions for 40 walkers >>> nwalkers = 10 * len(fitter.free_params_names) >>> initial_positions = fitter.generate_initial_walker_positions_random(nwalkers) >>> fitter.run_mcmc(initial_positions, nwalkers, max_steps=2000) """ if len(self.free_params_values) == 0: raise ValueError( "Cannot generate walker positions: no free parameters to sample. " "At least one parameter must be set as free (fixed=False)." 
) if verbose: print("Free parameters:", self.free_params_names) mcmc_init = [] for walker_idx in range(nwalkers): attempts = 0 while attempts < max_attempts: walker_position = [] for param_name in self.free_params_names: # Check if we have a direct prior for this parameter # (because user may be fitting in a transformed parameterisation, but gave priors in the default parameterisation instead) if param_name in self.priors: prior = self.priors[param_name] if isinstance(prior, ravest.prior.Normal): walker_position.append(np.random.normal(loc=prior.mean, scale=2*prior.std)) elif isinstance(prior, ravest.prior.HalfNormal): walker_position.append(np.abs(np.random.normal(loc=0, scale=2*prior.std))) elif isinstance(prior, ravest.prior.Uniform): walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper)) elif isinstance(prior, ravest.prior.TruncatedNormal): walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper)) elif isinstance(prior, ravest.prior.Beta): walker_position.append(np.random.uniform(low=0, high=1)) elif isinstance(prior, ravest.prior.EccentricityUniform): walker_position.append(np.random.uniform(low=0, high=prior.upper)) else: raise ValueError(f"Unsupported prior type for walker generation: {type(prior)}") else: # No direct prior for this parameter (this happens if fitting in a transformed parameterisation, but prior is in Default) # Instead use current value + small perturbation centre_val = self.params[param_name].value # Add small random perturbation (10% of current value + small fixed amount for near-zero values) perturbation = np.random.normal(0, abs(centre_val) * 0.1 + 0.01) walker_position.append(centre_val + perturbation) # Check astrophysical validity and prior compliance try: # Convert walker position to full parameter dict (free + fixed) free_params_dict = dict(zip(self.free_params_names, walker_position)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity 
self._validate_astrophysical_validity(all_params_dict) # Check prior compliance using LogPosterior, rather than calling priors direct # (because it handles Transformed->Default parameter transformations already, if needed) lp = LogPosterior( self.planet_letters, self.parameterisation, self.priors, self.fixed_params_values_dict, self.free_params_names, self.time, self.vel, self.velerr, self.instrument, self.unique_instruments, self.t0, ) # Check the log-prior probability is finite (i.e. proposed initial values are within prior bounds) params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) if not np.isfinite(log_prior): raise ValueError(f"Outside prior bounds (log_prior = {log_prior})") # If both astrophysical and priors validations pass, we have a valid walker position break except ValueError: # Validation failed. Generate a new set of values and try again. attempts += 1 continue if attempts >= max_attempts: raise ValueError(f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. " f"Consider relaxing priors or checking parameter constraints.") if verbose: print(f"Walker {walker_idx} position: {walker_position} (valid after {attempts + 1} attempts)") mcmc_init.append(walker_position) mcmc_init = np.array(mcmc_init) if verbose: print(f"Generated MCMC initial positions with shape: {mcmc_init.shape}") return mcmc_init
[docs] def generate_initial_walker_positions_around_point( self, centre: np.ndarray | list, nwalkers: int, scale: float = 1e-4, relative: bool = True, verbose: bool = False, max_attempts: int = 1000 ) -> np.ndarray: """Generate initial walker positions in a ball around a supplied centre point. Creates starting positions for MCMC walkers clustered around a centre point (e.g., MAP estimate). Each walker is generated by adding small random perturbations to the centre values. Validates that both the centre point and all generated walker positions satisfy priors and are astrophysically valid. Parameters ---------- centre : np.ndarray or list Centre point for walker positions. Must have length equal to the number of free parameters and be in the order of free_params_names. nwalkers : int Number of MCMC walkers to generate positions for scale : float, default 1e-4 Scale of perturbations around centre point relative : bool, default True If True, perturbations scale with parameter values (scale * centre * random). If False, perturbations are absolute (scale * random). verbose : bool, default False If True, print walker positions during generation max_attempts : int, default 1000 Maximum attempts to generate a valid walker position Returns ------- np.ndarray Array of shape (nwalkers, ndim) where ndim is the number of free parameters. Each row represents the starting position for one walker in the order of free_params_names. Raises ------ ValueError If centre has wrong length, if centre point is invalid, or if unable to generate valid positions after max_attempts Examples -------- >>> # Generate walkers around MAP estimate >>> map_result = fitter.find_map_estimate() >>> initial_positions = fitter.generate_initial_walker_positions_around_point( ... centre=map_result.x, nwalkers=40, scale=1e-4 ... 
) >>> fitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000) """ if len(self.free_params_values) == 0: raise ValueError( "Cannot generate walker positions: no free parameters to sample. " "At least one parameter must be set as free (fixed=False)." ) centre = np.asarray(centre) if len(centre) != len(self.free_params_names): raise ValueError( f"Centre must have length {len(self.free_params_names)} " f"(number of free parameters), got {len(centre)}" ) if verbose: print("Free parameters:", self.free_params_names) print(f"Centre values: {centre}") # Validate centre point first try: free_params_dict = dict(zip(self.free_params_names, centre)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity self._validate_astrophysical_validity(all_params_dict) # Check prior compliance lp = LogPosterior( self.planet_letters, self.parameterisation, self.priors, self.fixed_params_values_dict, self.free_params_names, self.time, self.vel, self.velerr, self.instrument, self.unique_instruments, self.t0, ) params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) if not np.isfinite(log_prior): raise ValueError(f"Centre point outside prior bounds (log_prior = {log_prior})") if verbose: print(f"Centre point validated (log_prior = {log_prior})") except ValueError as e: raise ValueError(f"Supplied centre point is not valid: {e}") # Generate walker positions around centre mcmc_init = [] if verbose and relative and np.any(centre == 0.0): zero_names = [self.free_params_names[i] for i in range(len(centre)) if centre[i] == 0.0] print(f"Note: centre value is exactly 0.0 for {zero_names}; " f"using absolute perturbation (scale={scale}) for these parameters.") for walker_idx in range(nwalkers): attempts = 0 while attempts < max_attempts: # Generate perturbation random_vals = np.random.randn(len(centre)) if relative: # Relative perturbation: scales with parameter values. 
# When a centre value is exactly 0.0, the relative # perturbation (scale * randn * |0|) is always zero, # producing identical walker values in that dimension. # This causes emcee to reject the walkers as linearly # dependent (condition number check). Fall back to # absolute perturbation for those parameters. perturbation = np.empty(len(centre)) for i in range(len(centre)): if centre[i] == 0.0: perturbation[i] = scale * random_vals[i] else: perturbation[i] = scale * random_vals[i] * np.abs(centre[i]) else: # Absolute perturbation: same scale for all parameters perturbation = scale * random_vals walker_position = centre + perturbation # Validate this walker position try: free_params_dict = dict(zip(self.free_params_names, walker_position)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity self._validate_astrophysical_validity(all_params_dict) # Check prior compliance params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) if not np.isfinite(log_prior): raise ValueError(f"Outside prior bounds (log_prior = {log_prior})") # If validation passes, we have a valid walker position break except ValueError: # Validation failed, try again attempts += 1 continue if attempts >= max_attempts: raise ValueError( f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. " f"Consider using a larger scale parameter or checking that the centre point is not too close to prior/physical boundaries." ) if verbose: print(f"Walker {walker_idx} position: {walker_position} (valid after {attempts + 1} attempts)") mcmc_init.append(walker_position) mcmc_init = np.array(mcmc_init) if verbose: print(f"Generated MCMC initial positions with shape: {mcmc_init.shape}") return mcmc_init
def generate_initial_walker_positions_from_map(
    self,
    map_result: scipy.optimize.OptimizeResult,
    nwalkers: int,
    scale: float = 1e-4,
    relative: bool = True,
    verbose: bool = False,
    max_attempts: int = 1000,
) -> np.ndarray:
    """Generate initial MCMC walker positions clustered around a MAP estimate.

    Convenience wrapper around
    ``generate_initial_walker_positions_around_point``: the optimiser's
    best-fit vector ``map_result.x`` becomes the centre point, and every other
    argument is forwarded unchanged.

    Parameters
    ----------
    map_result : scipy.optimize.OptimizeResult
        Result from find_map_estimate(); only its ``x`` attribute is used.
    nwalkers : int
        Number of MCMC walkers to generate positions for.
    scale : float, default 1e-4
        Scale of perturbations around the MAP values.
    relative : bool, default True
        If True, perturbations scale with parameter values.
        If False, perturbations are absolute.
    verbose : bool, default False
        If True, print walker positions during generation.
    max_attempts : int, default 1000
        Maximum attempts to generate a valid position for each walker.

    Returns
    -------
    np.ndarray
        Array of shape (nwalkers, ndim), where ndim is the number of free
        parameters. Each row is the starting position of one walker.

    Raises
    ------
    ValueError
        If there are no free parameters, or valid positions cannot be generated.

    Examples
    --------
    >>> map_result = fitter.find_map_estimate()
    >>> initial_positions = fitter.generate_initial_walker_positions_from_map(
    ...     map_result=map_result, nwalkers=40
    ... )
    >>> fitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
    """
    nothing_to_sample = len(self.free_params_values) == 0
    if nothing_to_sample:
        raise ValueError(
            "Cannot generate walker positions: no free parameters to sample. "
            "At least one parameter must be set as free (fixed=False)."
        )

    forwarded_options = dict(
        nwalkers=nwalkers,
        scale=scale,
        relative=relative,
        verbose=verbose,
        max_attempts=max_attempts,
    )
    return self.generate_initial_walker_positions_around_point(centre=map_result.x, **forwarded_options)
def run_mcmc(self, initial_positions: np.ndarray, nwalkers: int, max_steps: int = 5000, progress: bool = True,
             multiprocessing: bool = False, check_convergence: bool = False,
             convergence_check_interval: int = 1000, convergence_check_start: int = 0) -> None:
    """Run MCMC sampling from given initial walker positions.

    Parameters
    ----------
    initial_positions : np.ndarray
        Starting positions for all MCMC walkers, shape (nwalkers, ndim), where
        ndim is the number of free parameters. Each row is one walker's starting
        position, in the order of free_params_names. Any array-like accepted by
        ``np.asarray`` (e.g. a nested list) also works.
    nwalkers : int
        Number of MCMC walkers. Must equal the first dimension of
        initial_positions and be at least 2 * ndim.
    max_steps : int, optional
        Maximum number of MCMC steps to run. If check_convergence=False, runs for
        exactly this many steps. If check_convergence=True, runs up to this many
        steps, stopping early when the convergence criteria are met (default: 5000)
    progress : bool, optional
        Whether to show a progress bar during MCMC (default: True)
    multiprocessing : bool, optional
        Whether to use a multiprocessing pool for MCMC (default: False)
    check_convergence : bool, optional
        If True, check for convergence and stop early when the criteria are met.
        Convergence requires: chain length > 50 times max autocorrelation time,
        and autocorrelation time estimate stable to 1 percent.
        If False, run for exactly max_steps (default: False)
    convergence_check_interval : int, optional
        Steps between convergence checks (only used if check_convergence=True)
        (default: 1000)
    convergence_check_start : int, optional
        Minimum iteration before starting convergence checks. Set this sensibly
        (e.g. 2x burn-in) to avoid inaccurate tau estimates on short chains
        (default: 0)

    Raises
    ------
    ValueError
        If there are no free parameters, nwalkers < 2 * ndim, initial_positions
        has the wrong shape, or any walker starts at an astrophysically invalid
        point or outside the prior bounds.

    Notes
    -----
    The sampler is stored in ``self.sampler``. ``self.autocorr_history`` is reset
    at the start of every run, so it only ever describes the most recent run (it
    stays empty for fixed-length runs). When multiprocessing=True the worker pool
    is always shut down, even if sampling raises.
    """
    if len(self.free_params_values) == 0:
        raise ValueError(
            "Cannot run MCMC: no free parameters to sample. "
            "At least one parameter must be set as free (fixed=False)."
        )

    initial_positions = np.asarray(initial_positions, dtype=float)

    # Enforce the minimum number of walkers (users ideally should have many more).
    # This used to log a warning and bump self.nwalkers to 2 * ndim, but the walker
    # count is fixed by initial_positions, so emcee would then always reject the
    # initial state with a dimension mismatch. Fail early with a clear message.
    if nwalkers < 2 * self.ndim:
        raise ValueError(
            f"nwalkers must be at least 2 * ndim ({2 * self.ndim}). You have {nwalkers} walkers and "
            f"{self.ndim} dimensions: generate more initial positions and increase nwalkers."
        )

    if initial_positions.shape != (nwalkers, self.ndim):
        raise ValueError(f"initial_positions must have shape ({nwalkers}, {self.ndim}), got {initial_positions.shape}")
    self.nwalkers = nwalkers

    # Log-posterior object used both for validation and for sampling
    lp = LogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.priors,
        self.fixed_params_values_dict,
        self.free_params_names,
        self.time,
        self.vel,
        self.velerr,
        self.instrument,
        self.unique_instruments,
        self.t0,
    )

    # Validate every walker for astrophysical validity and prior compliance
    # (we don't want to start any chains in invalid parameter space)
    for i, walker_position in enumerate(initial_positions):
        walker_params_dict = dict(zip(self.free_params_names, walker_position))
        all_params_dict = self.fixed_params_values_dict | walker_params_dict

        try:
            self._validate_astrophysical_validity(all_params_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid astrophysical parameters: {e}") from e

        params_for_prior = lp._convert_params_for_prior_evaluation(walker_params_dict)
        log_prior = lp.log_prior(params_for_prior)
        if not np.isfinite(log_prior):
            raise ValueError(f"Walker {i} is outside prior bounds (log_prior = {log_prior})")

    # Warn if convergence arguments were provided but convergence checking is disabled
    if not check_convergence and (convergence_check_interval != 1000 or convergence_check_start != 0):
        logging.warning(
            "Convergence checking arguments provided but check_convergence=False. "
            "These arguments will be ignored. Did you forget to set check_convergence=True?"
        )

    # Drop any history left over from a previous run so diagnostics match this sampler
    self.autocorr_history = {}

    # 'spawn' instead of 'fork' avoids issues on some Linux platforms
    pool = mp.get_context("spawn").Pool() if multiprocessing else None
    try:
        # TODO: parameter_names has a small performance cost, but we need the names
        sampler = emcee.EnsembleSampler(self.nwalkers, self.ndim, lp.log_probability,
                                        parameter_names=self.free_params_names, pool=pool)

        if not check_convergence:
            # Fixed-length mode: run for exactly max_steps
            logging.info(f"Starting MCMC for {max_steps} steps...")
            sampler.run_mcmc(initial_state=initial_positions, nsteps=max_steps, progress=progress)
            logging.info("...MCMC done.")
        else:
            # Convergence mode: run up to max_steps, stopping early if converged
            logging.info(f"Starting MCMC with convergence checks. (Maximum {max_steps} steps, checking convergence every {convergence_check_interval} steps after iteration {convergence_check_start})...")
            old_tau = np.inf
            for _ in sampler.sample(initial_state=initial_positions, iterations=max_steps, progress=progress):
                # Only check at the requested interval, and not before convergence_check_start
                if sampler.iteration % convergence_check_interval != 0:
                    continue
                if sampler.iteration < convergence_check_start:
                    continue

                tau = sampler.get_autocorr_time(tol=0)
                self.autocorr_history[sampler.iteration] = tau.copy()
                logging.info(f"Convergence check: Step {sampler.iteration}: mean(tau)={np.mean(tau):.1f}, max(tau)={np.max(tau):.1f}")

                check_chain_length = np.all(sampler.iteration > 50 * tau)  # Chain length > 50 * tau
                check_stable_tau = np.all(np.abs(old_tau - tau) / tau < 0.01)  # Tau stable to 1 percent
                if check_chain_length and check_stable_tau:
                    logging.info(f"Converged at iteration {sampler.iteration}")
                    break

                logging.info(f"Not yet converged (N/50>tau check: {check_chain_length}, tau stability check: {check_stable_tau})")
                if sampler.iteration > 0.8 * max_steps:
                    logging.warning(f"Approaching max iterations ({max_steps}) without convergence! (max tau={np.max(tau):.1f}, tau stability change={np.abs(old_tau - tau) / tau})")
                old_tau = tau

            logging.info(f"MCMC complete: {sampler.iteration} steps total")
    finally:
        if pool is not None:
            # terminate() (not close()) so an interrupted run never waits on busy
            # workers; after a normal run every map call has already returned.
            pool.terminate()
            pool.join()

    self.sampler = sampler
def get_samples_np(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return a contiguous numpy array of MCMC samples.

    Samples can be discarded from the start and/or the end of the chain, the
    chain can be thinned (keep only every ``thin``-th step), and the walkers'
    chains can be flattened into one. This is the foundational accessor for MCMC
    samples; the other get_samples methods build on it.

    Parameters
    ----------
    discard_start : int, optional
        Number of steps to discard from the start of the chain. Must be
        non-negative (default: 0)
    discard_end : int, optional
        Number of steps to discard from the end of the chain. Must be
        non-negative (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain. Must be >= 1 (default: 1)
    flat : bool, optional
        Whether to flatten each walker's chain into one chain (default: False)

    Returns
    -------
    np.ndarray
        C-contiguous array of MCMC samples. Shape depends on `flat`:
        - flat=False: (nsteps_after_discard_thin, nwalkers, ndim)
        - flat=True: (nsteps_after_discard_thin * nwalkers, ndim)

    Raises
    ------
    ValueError
        If discard_start or discard_end is negative, thin < 1, or the requested
        window contains no steps.

    Notes
    -----
    The start index matches emcee's slicing logic (``discard + thin - 1``).
    We enforce np.ascontiguousarray() on the return, because slicing and
    reshaping do not guarantee a contiguous array in memory.
    """
    # Negative values would silently slice from the wrong end of the chain
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )
    if thin < 1:
        raise ValueError(f"thin must be at least 1, got {thin}.")

    # Full, unprocessed chain from emcee: (nsteps, nwalkers, ndim)
    full_samples = self.sampler.get_chain(discard=0, thin=1, flat=False)

    # emcee slices as [discard + thin - 1 :: thin]; we also allow trimming the end.
    # An absolute end index avoids the `x[:-0]` pitfall when discard_end == 0.
    start_idx = discard_start + thin - 1
    end_idx = full_samples.shape[0] - discard_end
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    samples = full_samples[start_idx:end_idx:thin]

    if flat:
        # (steps, walkers, ndim) -> (steps * walkers, ndim)
        samples = samples.reshape(-1, samples.shape[-1])

    return np.ascontiguousarray(samples)
def get_samples_df(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> pd.DataFrame:
    """Return the flattened MCMC samples as a pandas DataFrame.

    Thin wrapper over get_samples_np(flat=True): one row per sample, one column
    per free parameter (named after free_params_names).

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    pd.DataFrame
        DataFrame with shape (nsteps_after_discard_thin * nwalkers, ndim).
        Columns are parameter names.
    """
    return pd.DataFrame(
        self.get_samples_np(
            discard_start=discard_start,
            discard_end=discard_end,
            thin=thin,
            flat=True,
        ),
        columns=self.free_params_names,
    )
def get_samples_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, np.ndarray]:
    """Return a dict of flattened MCMC samples.

    Each free parameter gets a 1D, C-contiguous array of all its samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary mapping parameter names (in free_params_names order) to 1D
        arrays of samples, each with shape (nsteps_after_discard_thin * nwalkers,)

    Notes
    -----
    Columns of the C-ordered (nsamples, ndim) array from get_samples_np are
    strided views, not contiguous arrays. We therefore copy the transpose into
    a C-ordered (ndim, nsamples) array, whose rows are contiguous.

    Examples
    --------
    >>> samples_dict = fitter.get_samples_dict(discard_start=1000)
    >>> K_b_samples = samples_dict['K_b']  # All samples for parameter K for planet b
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)

    # One contiguous row per parameter (see Notes)
    per_param = np.ascontiguousarray(flat_samples.T)
    return {name: per_param[i] for i, name in enumerate(self.free_params_names)}
def get_sampler_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return the log probability at each step of the sampler.

    Uses the same discard/thin window as get_samples_np, so the results line up
    element-for-element with the samples it returns.

    Parameters
    ----------
    discard_start : int, optional
        Number of steps to discard from the start of the chain. Must be
        non-negative (default: 0)
    discard_end : int, optional
        Number of steps to discard from the end of the chain. Must be
        non-negative (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain. Must be >= 1 (default: 1)
    flat : bool, optional
        If True, return flattened array shape (nsteps_after_discard_thin * nwalkers)
        If False, return unflattened array shape (nsteps_after_discard_thin, nwalkers)
        (default: False)

    Returns
    -------
    np.ndarray
        C-contiguous array of log probabilities at each sample.

    Raises
    ------
    ValueError
        If discard_start or discard_end is negative, thin < 1, or the requested
        window contains no steps.
    """
    # Negative values would silently slice from the wrong end of the chain
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start ({discard_start}) and discard_end ({discard_end}) must be non-negative."
        )
    if thin < 1:
        raise ValueError(f"thin must be at least 1, got {thin}.")

    # Full, unprocessed log-prob chain from emcee: (nsteps, nwalkers)
    full_lnprob = self.sampler.get_log_prob(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic [discard + thin - 1 :: thin], plus end trimming.
    # An absolute end index avoids the `x[:-0]` pitfall when discard_end == 0.
    start_idx = discard_start + thin - 1
    end_idx = full_lnprob.shape[0] - discard_end
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    lnprob = full_lnprob[start_idx:end_idx:thin]

    if flat:
        # (steps, walkers) -> (steps * walkers,)
        lnprob = lnprob.reshape(-1)

    return np.ascontiguousarray(lnprob)
def get_mcmc_posterior_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> dict:
    """Merge fixed parameter values with MCMC samples of the free parameters.

    The result holds every model parameter: fixed ones as their single value,
    free ones as 1D arrays of MCMC samples. This suits functions such as
    calculate_mpsini that need all parameters and should propagate uncertainty
    from the free-parameter samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Fixed parameters first (single float values), followed by free
        parameters (1D arrays with shape (nsteps_after_discard_thin * nwalkers,)).
    """
    posterior = dict(self.fixed_params_values_dict)
    posterior.update(
        self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    )
    return posterior
def calculate_log_likelihood(self, params_dict: Dict[str, float]) -> float:
    """Evaluate the log-likelihood of the data for the given parameters.

    Priors are NOT included: this is only the (log-)likelihood, used mainly
    for AICc and BIC.

    Parameters
    ----------
    params_dict : dict
        Values for every parameter, both fixed and free.

    Returns
    -------
    float
        The log-likelihood value.
    """
    # Observed data and its metadata, forwarded by keyword to LogLikelihood
    data_fields = ("time", "vel", "velerr", "instrument", "unique_instruments", "t0")
    data_kwargs = {field: getattr(self, field) for field in data_fields}

    likelihood = LogLikelihood(
        **data_kwargs,
        planet_letters=self.planet_letters,
        parameterisation=self.parameterisation,
    )
    return likelihood(params_dict)
[docs] def build_params_dict(self, free_params: np.ndarray | list | Dict[str, float]) -> Dict[str, float]: """Build a params dict by providing free param vals, combine with fixed param vals. Takes free parameter float values (which can be from any source e.g. MAP results, MCMC posteriors, or any custom values) and combines them with the fixed parameter values to create a complete parameter dictionary. This dict is ideal for calculating chi2, log-likelihood, AICc, and BIC. This is designed for a single value per parameter. For combining the MCMC posterior chains for free parameters and the fixed values for fixed parameters, use `get_mcmc_posterior_dict` method. Parameters ---------- free_params : list, np.ndarray, or dict Free parameter values from any source: - list/array: values in order of self.free_params_names - dict: mapping of free param names to values Returns ------- Dict[str, float] Complete parameters dict with both free and fixed parameter values Examples -------- >>> # From MAP optimization result >>> map_result = fitter.find_map_estimate() >>> params = fitter.build_params_dict(map_result.x) >>> aicc = fitter.calculate_aicc(params) >>> >>> # From best MCMC sample >>> best_sample = fitter.get_sample_with_best_lnprob(discard_start=1000) >>> params = fitter.build_params_dict(best_sample) >>> bic = fitter.calculate_bic(params) >>> >>> # From custom array of values (in order of free_params_names) >>> custom_values = [5.0, 50.0, 0.1, 0.0, 2450000.0] # example values >>> params = fitter.build_params_dict(custom_values) >>> log_like = fitter.calculate_log_likelihood(params) """ if isinstance(free_params, dict): # Validate that all expected free parameters are present expected_names = set(self.free_params_names) provided_names = set(free_params.keys()) missing = expected_names - provided_names if missing: raise ValueError(f"Missing required free parameters: {missing}") extra = provided_names - expected_names if extra: raise ValueError(f"Unexpected parameters 
provided: {extra}") return self.fixed_params_values_dict | free_params else: # Validate that array/list has correct length if len(free_params) != len(self.free_params_names): raise ValueError( f"Expected {len(self.free_params_names)} free parameter values " f"but got {len(free_params)} " f"(expecting {len(self.free_params_names)} values for {self.free_params_names})" ) free_dict = dict(zip(self.free_params_names, free_params)) return self.fixed_params_values_dict | free_dict
def calculate_chi2(self, params_dict: Dict[str, float]) -> float:
    r"""Calculate chi-squared for given parameter values.

    Rather than recomputing the RV model, this reuses LogLikelihood and
    inverts its definition:

    .. math::
        \ell = -\frac{1}{2} \left( \chi^2 + \text{penalty} \right)
        \quad\Rightarrow\quad
        \chi^2 = -2\ell - \text{penalty}

    with :math:`\text{penalty} = \sum_i \ln\left(2\pi(\sigma_i^2 + \sigma_{\text{jit}}^2)\right)`.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        Chi-squared value:

        .. math::
            \chi^2 = \sum_i \frac{(d_i - m_i)^2}{\sigma_i^2 + \sigma_{\text{jit}}^2}
    """
    likelihood = LogLikelihood(
        self.time, self.vel, self.velerr, self.instrument,
        self.unique_instruments, self.t0, self.planet_letters, self.parameterisation,
    )
    log_like = likelihood(params_dict)

    # Per-observation total variance: measurement error plus that instrument's jitter
    total_variance = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        in_inst = self.instrument == inst
        total_variance[in_inst] = self.velerr[in_inst] ** 2 + params_dict[f"jit_{inst}"] ** 2

    normalisation = np.log(2 * np.pi * total_variance).sum()
    return -2 * log_like - normalisation
def calculate_aicc(self, params_dict: Dict[str, float]) -> float:
    r"""Calculate corrected Akaike Information Criterion (AICc).

    .. math::
        \text{AICc} = 2k - 2\ln\mathcal{L} + \frac{2k^2 + 2k}{n - k - 1}

    where :math:`k` is the number of free parameters, :math:`n` is the number of
    observations, and :math:`\mathcal{L}` is the likelihood. Converges to AIC
    for large :math:`n`.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        AICc value

    Raises
    ------
    ValueError
        If :math:`n \le k + 1`. The small-sample correction then divides by
        zero or becomes negative, so AICc is undefined.
    """
    k = self.ndim
    n = len(self.time)

    dof = n - k - 1
    if dof <= 0:
        raise ValueError(
            f"AICc is undefined for n <= k + 1 (n={n} observations, k={k} free parameters). "
            "Use more observations or fewer free parameters."
        )

    log_like = self.calculate_log_likelihood(params_dict)
    aic = 2 * k - 2 * log_like  # traditional AIC
    correction = (2 * k**2 + 2 * k) / dof  # small-sample correction
    return aic + correction
def calculate_bic(self, params_dict: Dict[str, float]) -> float:
    r"""Calculate the Bayesian Information Criterion (BIC) for given parameters.

    .. math::
        \text{BIC} = k \ln n - 2 \ln \mathcal{L}

    where :math:`k` is the number of free parameters (``self.ndim``),
    :math:`n` is the number of observations, and :math:`\mathcal{L}` is the
    likelihood.

    Parameters
    ----------
    params_dict : dict
        Dictionary of all parameter values (both fixed and free parameters)

    Returns
    -------
    float
        BIC value
    """
    two_log_like = 2 * self.calculate_log_likelihood(params_dict)
    complexity_penalty = self.ndim * np.log(len(self.time))
    return complexity_penalty - two_log_like
def get_sample_with_best_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, float]:
    """Return the free-parameter values of the highest-log-probability MCMC sample.

    The samples and log probabilities are read with the same discard/thin
    window, so their flattened indices correspond. The winning sample's log
    probability and flat index are printed.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    Dict[str, float]
        Mapping of free parameter names to their values in the best sample.
    """
    window = dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    flat_samples = self.get_samples_np(**window, flat=True)
    flat_lnprob = self.get_sampler_lnprob(**window, flat=True)

    winner = np.argmax(flat_lnprob)
    print(f"Best sample found with log probability {flat_lnprob[winner]:.6f} at index {winner} of samples (with discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    return {name: value for name, value in zip(self.free_params_names, flat_samples[winner])}
[docs] def plot_autocorr_estimates( self, params: list[str] | None = None, plot_mean: bool = False, show_legend: bool = True, title: str | None = "Autocorrelation Time Estimates", xlabel: str | None = "Step number", ylabel: str | None = r"Autocorrelation time $\tau$", save: bool = False, fname: str = "autocorr_plot.png", dpi: int = 100 ) -> None: r"""Plot autocorrelation time estimates from adaptive MCMC run. Shows how autocorrelation time evolved during the MCMC run and the convergence threshold line (N / 50). Only available if run_mcmc was called with check_convergence=True. Parameters ---------- params : list[str] or None, optional List of parameter names to plot. If None, plots all free parameters (default: None) plot_mean : bool, optional If True, plot mean tau instead of individual parameter taus. Overrides params argument (default: False) show_legend : bool, optional Whether to show legend (default: True) title : str or None, optional Plot title (default: "Autocorrelation Time Estimates"). Set to None or "" to skip. xlabel : str or None, optional X-axis label (default: "Step number"). Set to None or "" to skip. ylabel : str or None, optional Y-axis label (default: r"Autocorrelation time $\tau$"). Set to None or "" to skip. save : bool, optional Save the plot to path `fname` (default: False) fname : str, optional The path to save the plot to (default: "autocorr_plot.png") dpi : int, optional The dpi to save the image at (default: 100) Raises ------ ValueError If no autocorrelation history is available (run_mcmc was not called with check_convergence=True, or has not been called yet) """ # Check if data available if not hasattr(self, 'autocorr_history') or len(self.autocorr_history) == 0: raise ValueError( "No autocorrelation history available. " "Please run run_mcmc() with check_convergence=True first." 
) iterations = np.array(list(self.autocorr_history.keys())) max_iteration = np.max(iterations) tau_history = np.array(list(self.autocorr_history.values())) # Shape: (n_checks, n_params) # Create plot fig, ax = plt.subplots(1, figsize=(10, 6)) if title: fig.suptitle(title) # Plot convergence threshold (N/50) ax.plot([0, max_iteration], [0, max_iteration / 50], "--k", linewidth=2, label="N/50") if plot_mean: # Plot mean tau mean_tau = np.mean(tau_history, axis=1) ax.plot(iterations, mean_tau, linewidth=2, label="Mean τ") else: # Determine which parameters to plot if params is None: params_to_plot = self.free_params_names indices_to_plot = range(len(self.free_params_names)) else: params_to_plot = [] indices_to_plot = [] for param in params: if param in self.free_params_names: idx = self.free_params_names.index(param) params_to_plot.append(param) indices_to_plot.append(idx) else: logging.warning(f"Parameter '{param}' not found in free parameters, skipping") # Plot individual parameter taus for i, param_name in zip(indices_to_plot, params_to_plot): ax.plot(iterations, tau_history[:, i], alpha=0.7, label=param_key_to_latex(param_name)) ax.set_xlim(0, iterations.max()) ax.set_ylim(bottom=0) if xlabel: ax.set_xlabel(xlabel) if ylabel: ax.set_ylabel(ylabel) if show_legend: ax.legend(loc='upper left') ax.grid(True, alpha=0.3) if save: plt.savefig(fname=fname, dpi=dpi) print(f"Saved {fname}") plt.show()
[docs] def plot_chains(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, truths: list = None, title: str | None = "Chains plot", xlabel: str | None = "Step number", save: bool = False, fname: str = "chains_plot.png", dpi: int = 100) -> None: """Plot MCMC chains for all free parameters. Displays the evolution of each free parameter across MCMC steps for all walkers. Useful for diagnosing convergence, burn-in, and mixing of the MCMC chains. Each parameter gets its own subplot showing all walker traces. Parameters ---------- discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) truths : list, optional List of true parameter values to overplot as horizontal lines. Must match the number of free parameters. Use None for parameters without known truth values (default: None) title : str or None, optional Plot title (default: "Chains plot"). Set to None or "" to skip. xlabel : str or None, optional X-axis label (default: "Step number"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "chains_plot.png") dpi : int, optional Resolution for saving (default: 100) """ # Scale figure height to maintain consistent subplot size subplot_height_inches = 1.25 fig, axes = plt.subplots(self.ndim, figsize=(10, self.ndim * subplot_height_inches), sharex=True, constrained_layout=True) if title: fig.suptitle(title) if self.ndim == 1: axes = [axes] if truths is not None: if not len(truths) == self.ndim: raise ValueError(f"Length of truths ({len(truths)}) must match number of free parameters ({self.ndim})") samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False) for i in range(self.ndim): ax = axes[i] ax.set_xlim(0, len(samples)) ax.set_ylabel(param_key_to_latex(self.free_params_names[i])) to_plot = samples[:, :, i] ax.plot(to_plot, "k", alpha=0.3) if truths is not None and truths[i] is not None: ax.axhline(truths[i], color="tab:blue") fig.align_ylabels(axes) if xlabel: axes[-1].set_xlabel(xlabel) if save: plt.savefig(fname=fname, dpi=dpi) print(f"Saved {fname}") plt.show()
[docs] def plot_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Log Probability Traces", xlabel: str | None = "Step number", ylabel: str | None = "Log probability", save: bool = False, fname: str = "lnprob_plot.png", dpi: int = 100) -> None: """Plot log probability traces for all walkers. Useful for diagnosing MCMC convergence and identifying problematic walkers/parameters. You can use `discard_start` and `discard_end` to focus in on specific steps in the chains. Parameters ---------- discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) title : str or None, optional Plot title (default: "Log Probability Traces"). Set to None or "" to skip. xlabel : str or None, optional X-axis label (default: "Step number"). Set to None or "" to skip. ylabel : str or None, optional Y-axis label (default: "Log probability"). Set to None or "" to skip. save : bool, optional Save the plot to path `fname` (default: False) fname : str, optional The path to save the plot to (default: "lnprob_plot.png") dpi : int, optional The dpi to save the image at (default: 100) """ fig, ax = plt.subplots(1, figsize=(10, 6)) if title: fig.suptitle(title) lnprobs = self.get_sampler_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False) nsteps, nwalkers = lnprobs.shape for i in range(nwalkers): to_plot = lnprobs[:, i] ax.plot(to_plot, "k", alpha=0.3) ax.set_xlim(0, nsteps) if xlabel: ax.set_xlabel(xlabel) if ylabel: ax.set_ylabel(ylabel) if save: plt.savefig(fname=fname, dpi=dpi) print(f"Saved {fname}") plt.show()
[docs] def plot_corner(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, plot_datapoints: bool = False, truths: list[float] = None, title: str | None = "Corner plots", save: bool = False, fname: str = "corner_plot.png", dpi: int = 100) -> None: """Create a corner plot of MCMC samples. Parameters ---------- discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) plot_datapoints : bool, optional Show individual data points in addition to contours (default: False) truths : list of float, optional True parameter values to overplot as vertical/horizontal lines (default: None). Must match the order of free parameters if provided. title : str or None, optional Plot title (default: "Corner plots"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "corner_plot.png") dpi : int, optional Resolution for saving (default: 100) """ flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True) param_labels = [param_key_to_latex(n) for n in self.free_params_names] fig = corner.corner( flat_samples, labels=param_labels, show_titles=True, plot_datapoints=plot_datapoints, quantiles=[0.1585, 0.5, 0.8415], truths=truths, ) if title: fig.suptitle(title) if save: plt.savefig(fname=fname, dpi=dpi) print(f"Saved {fname}") plt.show()
def _plot_rv(self, params: Dict[str, float], title: str | None = "RV Model",
             ylabel_main: str | None = "Radial velocity [m/s]",
             xlabel: str | None = "Time [days]",
             ylabel_residuals: str | None = "Residuals [m/s]",
             save: bool = False, fname: str = "rv_plot.png", dpi: int = 100) -> None:
    """Plot the RV data, the model for one parameter set, and the residuals.

    The model drawn is the sum of all planets plus the trend. Per-instrument
    gamma offsets are not added to the model. Instead each instrument's
    ``g_{inst}`` is subtracted from its data, so data and model share one
    offset-free scale. Error bars are drawn twice per point: solid for the
    reported uncertainties, and faded for the uncertainties with the
    per-instrument jitter ``jit_{inst}`` added in quadrature.

    Parameters
    ----------
    params : dict
        Values for all parameters (free and fixed). Must contain
        ``"{par}_{letter}"`` for every planet parameter, ``"gd"`` and ``"gdd"``
        for the trend, and ``"g_{inst}"`` / ``"jit_{inst}"`` for every instrument.
    title : str or None, optional
        Title of the RV panel (default: "RV Model"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to `fname` before showing it (default: False)
    fname : str, optional
        Filename to save (default: "rv_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)

    Returns
    -------
    None
        The figure is displayed (and optionally saved); nothing is returned.
    """
    # Smooth time grid spanning the data, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, 1000)

    # Sum of all planetary signals, on the smooth grid and at the observation times
    rv_planets_smooth = np.zeros(len(tsmooth))
    rv_planets_obs = np.zeros(len(self.time))
    for letter in self.planet_letters:
        planet_params = {par: params[f"{par}_{letter}"] for par in self.parameterisation.pars}
        planet = ravest.model.Planet(letter, self.parameterisation, planet_params)
        rv_planets_smooth += planet.radial_velocity(tsmooth)
        rv_planets_obs += planet.radial_velocity(self.time)

    # Trend only: gamma offsets are per instrument and are removed from the data below
    trend = ravest.model.Trend(params={"gd": params["gd"], "gdd": params["gdd"]}, t0=self.t0)
    rv_total_smooth = rv_planets_smooth + trend.radial_velocity(tsmooth)
    rv_total_obs = rv_planets_obs + trend.radial_velocity(self.time)

    # Gamma-corrected data and jitter-inflated uncertainties, per instrument
    vel_corrected = self.vel.copy()
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        vel_corrected[mask] -= params[f"g_{inst}"]
        jit = params[f"jit_{inst}"]
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask] ** 2 + jit ** 2)

    residuals = vel_corrected - rv_total_obs

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5),
                                   gridspec_kw={"height_ratios": [3, 1], "hspace": 0})
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # --- Main RV panel ---
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        color = inst_colors[inst]
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=color, ecolor=color, linestyle="None",
                     markersize=8, zorder=4, label=inst)
        # Jitter-inflated error bars (faded, no legend entry)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=color, linestyle="None", alpha=0.5, zorder=3)
    ax1.plot(tsmooth, rv_total_smooth, label="Model", color="black", zorder=2)
    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    # Hide x tick labels on the top panel; the residual panel carries them
    ax1.tick_params(axis="x", labelbottom=False, bottom=True, top=False, direction="in")
    ax1.tick_params(axis="y", direction="in")
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis="y", which="minor", direction="in", length=3)

    # --- Residuals panel ---
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        color = inst_colors[inst]
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=color, ecolor=color, linestyle="None",
                     markersize=8, zorder=4)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=color, linestyle="None", alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric limits (0 centred) containing every point AND its full error bar.
    # Use |r| + err, not |r + err|: the latter under-estimates the extent of
    # negative residuals and clips their error bars.
    max_extent = np.max(np.abs(residuals) + velerr_with_jit)
    if np.isfinite(max_extent) and max_extent > 0:
        ax2.set_ylim(-max_extent * 1.1, max_extent * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis="x", direction="in")
    ax2.tick_params(axis="y", direction="in")
    ax2.tick_params(axis="x", top=True, labeltop=False)  # ticks on the shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis="y", which="minor", direction="in", length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def _plot_phase(self, planet_letter: str, params: Dict[str, float], title: str | None = None,
                ylabel_main: str | None = "Radial velocity [m/s]",
                xlabel: str | None = "Orbital phase",
                ylabel_residuals: str | None = "Residuals [m/s]",
                save: bool = False, fname: str = "phase_plot.png", dpi: int = 100) -> None:
    """Plot the phase-folded RV signal of one planet for a given parameter set.

    All other contributions (other planets, the trend and the per-instrument
    gamma offsets) are subtracted from the data. The data are then folded on
    this planet's period around its time of conjunction and compared with the
    planet's own model, with a residuals panel underneath.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    params : dict
        Values for all parameters (free and fixed), keyed like ``"P_b"``,
        ``"gd"``, ``"gdd"``, ``"g_{inst}"`` and ``"jit_{inst}"``.
    title : str or None, optional
        Plot title. None (the default) uses ``f"Planet {planet_letter} Phase Plot"``;
        "" omits the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to `fname` before showing it (default: False)
    fname : str, optional
        Filename to save (default: "phase_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    if title is None:
        title = f"Planet {planet_letter} Phase Plot"

    # Smooth time grid spanning the data, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, 1000)

    # Jitter-inflated uncertainties, per instrument
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        jit = params[f"jit_{inst}"]
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask] ** 2 + jit ** 2)

    pars = self.parameterisation.pars
    planet_params = {par: params[f"{par}_{planet_letter}"] for par in pars}

    # Period and time of conjunction used for folding
    P = params[f"P_{planet_letter}"]
    if "Tc" in pars:
        Tc = params[f"Tc_{planet_letter}"]
    elif all(name in pars for name in ("e", "w", "Tp")):
        # Only take this shortcut when Tp really is a parameter; otherwise the
        # "Tp_{letter}" lookup would raise KeyError.
        Tc = self.parameterisation.convert_tp_to_tc(params[f"Tp_{planet_letter}"], P,
                                                    params[f"e_{planet_letter}"],
                                                    params[f"w_{planet_letter}"])
    else:
        # Generic route via the default parameterisation (which provides Tp, e and w).
        # Pass a copy so planet_params stays untouched for the model below.
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(dict(planet_params))
        Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], P,
                                                    default_params["e"], default_params["w"])

    # fold_time_series returns the folded times in sorted order plus the indices
    # that put the inputs into that order (used below to sort the data arrays).
    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # This planet's model
    planet = ravest.model.Planet(planet_letter, self.parameterisation, planet_params)
    planet_rv_sorted = planet.radial_velocity(tsmooth)[smooth_inds]
    planet_rv_data = planet.radial_velocity(self.time)

    # Everything else at the data times: other planets + trend + per-instrument gamma
    other_rv = np.zeros(len(self.time))
    for other_letter in self.planet_letters:
        if other_letter == planet_letter:
            continue
        other_params = {par: params[f"{par}_{other_letter}"] for par in pars}
        other_planet = ravest.model.Planet(other_letter, self.parameterisation, other_params)
        other_rv += other_planet.radial_velocity(self.time)
    trend = ravest.model.Trend(params={"gd": params["gd"], "gdd": params["gdd"]}, t0=self.t0)
    other_rv += trend.radial_velocity(self.time)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        other_rv[mask] += params[f"g_{inst}"]

    data_minus_others = self.vel - other_rv
    residuals = data_minus_others - planet_rv_data

    # Put everything into phase order
    data_minus_others_sorted = data_minus_others[inds]
    verr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]
    residuals_sorted = residuals[inds]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5),
                                   gridspec_kw={"height_ratios": [3, 1], "hspace": 0})
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # --- Main phase panel ---
    for inst in self.unique_instruments:
        mask = instrument_sorted == inst
        color = inst_colors[inst]
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=color, ecolor=color,
                     markersize=8, zorder=4, label=inst)
        # Jitter-inflated error bars (faded, no legend entry)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_sorted[mask],
                     yerr=velerr_with_jit_sorted[mask], marker="None", linestyle="None",
                     ecolor=color, alpha=0.5, zorder=3)
    ax1.plot(tsmooth_fold_sorted, planet_rv_sorted, label="Model", color="black", zorder=2)
    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    ax1.legend(loc="upper right")
    if title:
        ax1.set_title(title)
    # Hide x tick labels on the top panel; the residual panel carries them
    ax1.tick_params(axis="x", labelbottom=False, bottom=True, top=False, direction="in")
    ax1.tick_params(axis="y", direction="in")
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis="y", which="minor", direction="in", length=3)

    # Annotate with the planet's period and semi-amplitude
    K_value = params[f"K_{planet_letter}"]
    P_label = param_key_to_latex(f"P_{planet_letter}")
    K_label = param_key_to_latex(f"K_{planet_letter}")
    s = f"Planet {planet_letter}\n{P_label}={P:.2f} d\n{K_label}={K_value:.2f} m/s"
    ax1.annotate(s, xy=(0, 1), xycoords="axes fraction", xytext=(+0.5, -0.5),
                 textcoords="offset fontsize", va="top")

    # --- Residuals panel ---
    for inst in self.unique_instruments:
        mask = instrument_sorted == inst
        color = inst_colors[inst]
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=color, ecolor=color,
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask],
                     yerr=velerr_with_jit_sorted[mask], marker="None", linestyle="None",
                     ecolor=color, alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))

    # Symmetric limits containing every point AND its full error bar (|r| + err,
    # not |r + err|, which would clip the error bars of negative residuals).
    max_extent = np.max(np.abs(residuals_sorted) + velerr_with_jit_sorted)
    if np.isfinite(max_extent) and max_extent > 0:
        ax2.set_ylim(-max_extent * 1.1, max_extent * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis="x", direction="in")
    ax2.tick_params(axis="y", direction="in")
    ax2.tick_params(axis="x", top=True, labeltop=False)  # ticks on the shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis="y", which="minor", direction="in", length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_posterior_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1,
                      show_CI: bool = True, title: str | None = "Posterior RV",
                      ylabel_main: str | None = "Radial velocity [m/s]",
                      xlabel: str | None = "Time [days]",
                      ylabel_residuals: str | None = "Residuals [m/s]",
                      save: bool = False, fname: str = "posterior_rv.png", dpi: int = 100) -> None:
    """Plot the posterior RV model with its 68.3% credible band and the residuals.

    The planets + trend model is evaluated for every retained MCMC sample, and
    the 15.85th / 50th / 84.15th percentiles are taken at each time. The
    median curve is drawn, optionally with the 15.85-84.15 percentile band.
    Per-instrument gamma offsets are not part of the model. Instead, the
    posterior median of each ``g_{inst}`` (or its fixed value) is subtracted
    from that instrument's data. Residuals are the gamma-corrected data minus
    the median model at the observation times. Faded error bars include the
    median per-instrument jitter in quadrature.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main RV plot (default: "Posterior RV"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_rv.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    # Smooth time grid spanning the data, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, 1000)

    # Evaluate the per-sample model (planets + trend, no gamma) on the smooth grid
    # and the observation times together, so the per-sample loop runs once
    # instead of twice. Each time column is computed independently, so the
    # result is identical to two separate calls.
    n_smooth = len(tsmooth)
    rv_matrix = self.calculate_rv_total_from_samples(times=np.concatenate([tsmooth, self.time]),
                                                     discard_start=discard_start,
                                                     discard_end=discard_end, thin=thin)
    rv_percentiles = np.percentile(rv_matrix, [15.85, 50, 84.15], axis=0)
    rv_percentiles_smooth = rv_percentiles[:, :n_smooth]
    rv_median_obs = rv_percentiles[1, n_smooth:]

    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end,
                                         thin=thin)

    def _median_or_fixed(key):
        # Posterior median when the parameter is free, otherwise its fixed value
        if key in samples_dict:
            return np.median(samples_dict[key])
        return self.fixed_params_values_dict[key]

    # Gamma-corrected data and jitter-inflated uncertainties, per instrument
    vel_corrected = self.vel.copy()
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        vel_corrected[mask] -= _median_or_fixed(f"g_{inst}")
        jit_median = _median_or_fixed(f"jit_{inst}")
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask] ** 2 + jit_median ** 2)

    residuals = vel_corrected - rv_median_obs

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5),
                                   gridspec_kw={"height_ratios": [3, 1], "hspace": 0})
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # --- Main RV panel ---
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        color = inst_colors[inst]
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=color, ecolor=color, linestyle="None",
                     markersize=8, zorder=4, label=inst)
        # Jitter-inflated error bars (faded, no legend entry)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=color, linestyle="None", alpha=0.5, zorder=3)

    ax1.plot(tsmooth, rv_percentiles_smooth[1], label="Model", color="black", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth, rv_percentiles_smooth[0], rv_percentiles_smooth[2],
                         color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)
    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis="x", labelbottom=False, bottom=True, top=False, direction="in")
    ax1.tick_params(axis="y", direction="in")
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis="y", which="minor", direction="in", length=3)

    # --- Residuals panel ---
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        color = inst_colors[inst]
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=color, ecolor=color, linestyle="None",
                     markersize=8, zorder=4)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=color, linestyle="None", alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric limits (0 centred) containing every point AND its full error bar.
    # |r| + err, not |r + err|, which would clip the error bars of negative residuals.
    max_extent = np.max(np.abs(residuals) + velerr_with_jit)
    if np.isfinite(max_extent) and max_extent > 0:
        ax2.set_ylim(-max_extent * 1.1, max_extent * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis="x", direction="in")
    ax2.tick_params(axis="y", direction="in")
    ax2.tick_params(axis="x", top=True, labeltop=False)  # ticks on the shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis="y", which="minor", direction="in", length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_posterior_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0,
                         thin: int = 1, show_CI: bool = True, title: str | None = None,
                         ylabel_main: str | None = "Radial velocity [m/s]",
                         xlabel: str | None = "Orbital phase",
                         ylabel_residuals: str | None = "Residuals [m/s]",
                         save: bool = False, fname: str = "posterior_phase.png",
                         dpi: int = 100) -> None:
    """Plot one planet's phase-folded RV signal with posterior uncertainty bands.

    The target planet's model is evaluated for every retained MCMC sample.
    Its median curve is drawn, optionally with the 15.85-84.15 percentile
    (68.3%) band. The per-sample median of all other contributions (trend +
    other planets + median per-instrument gamma) is subtracted from the data.
    Data and model are folded on the posterior median period around the
    posterior median time of conjunction.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main phase plot. None (the default) uses
        ``f"Posterior Phase Plot - Planet {planet_letter}"``; "" omits the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_phase.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end,
                                         thin=thin)
    # Free-parameter samples merged with fixed values (fixed values win on key clashes)
    params = samples_dict | self.fixed_params_values_dict

    def _median_or_fixed(key):
        # Posterior median when the parameter is free, otherwise its fixed value
        if key in samples_dict:
            return np.median(samples_dict[key])
        return self.fixed_params_values_dict[key]

    # Smooth time grid spanning the data, padded by 1% of the baseline on each side
    t_min, t_max = self.time.min(), self.time.max()
    t_range = t_max - t_min
    tsmooth = np.linspace(t_min - 0.01 * t_range, t_max + 0.01 * t_range, 1000)

    # Jitter-inflated uncertainties using the median jitter per instrument
    velerr_with_jit = np.zeros_like(self.velerr)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        jit_med = _median_or_fixed(f"jit_{inst}")
        velerr_with_jit[mask] = np.sqrt(self.velerr[mask] ** 2 + jit_med ** 2)

    # Period and Tc (sample arrays if free, scalars if fixed) for folding
    P_values = params[f"P_{planet_letter}"]
    if "Tc" in self.parameterisation.pars:
        Tc_values = params[f"Tc_{planet_letter}"]
    else:
        # The default parameterisation has P, e, w and Tp, from which Tc follows
        planet_params = {par: params[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        Tc_values = self.parameterisation.convert_tp_to_tc(default_params["Tp"], P_values,
                                                           default_params["e"], default_params["w"])
    # Fold on a single (median) period and conjunction time
    P_med = np.median(P_values)
    Tc_med = np.median(Tc_values)

    # Folded times in sorted order plus the indices that sort the inputs
    t_fold, inds = ravest.model.fold_time_series(self.time, P_med, Tc_med)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P_med, Tc_med)

    # Target planet at the observation times and on the smooth grid, in one
    # pass over the samples (matrix of n_samples x (n_obs + n_smooth)).
    n_obs = len(self.time)
    rv_planet_all = self.calculate_rv_planet_from_samples(planet_letter,
                                                          np.concatenate([self.time, tsmooth]),
                                                          discard_start, discard_end, thin)
    rv_planet_percs = np.percentile(rv_planet_all, [15.85, 50, 84.15], axis=0)
    rv_planet_data_median = rv_planet_percs[1, :n_obs]
    rv_planet_smooth_percs = rv_planet_percs[:, n_obs:]

    # Every other contribution at the observation times, per sample:
    # trend + other planets, plus median per-instrument gamma offsets
    rv_others = self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)
    for other_letter in self.planet_letters:
        if other_letter != planet_letter:
            rv_others = rv_others + self.calculate_rv_planet_from_samples(
                other_letter, self.time, discard_start, discard_end, thin)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        rv_others[:, mask] += _median_or_fixed(f"g_{inst}")

    # Data with everything except this planet removed, and its residuals
    data_minus_others = self.vel - np.median(rv_others, axis=0)
    residuals = data_minus_others - rv_planet_data_median

    # Phase-ordered views of the data arrays
    data_sorted = data_minus_others[inds]
    residuals_sorted = residuals[inds]
    velerr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), sharex=True,
                                   gridspec_kw={"height_ratios": [3, 1], "hspace": 0})
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # --- Main phase panel ---
    for inst in self.unique_instruments:
        mask = instrument_sorted == inst
        color = inst_colors[inst]
        ax1.errorbar(t_fold[mask], data_sorted[mask], yerr=velerr_sorted[mask],
                     marker=".", linestyle="None", color=color, ecolor=color,
                     markersize=8, zorder=4, label=inst)
        # Jitter-inflated error bars (faded, no legend entry)
        ax1.errorbar(t_fold[mask], data_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=color, alpha=0.5, zorder=3)

    ax1.plot(tsmooth_fold_sorted, rv_planet_smooth_percs[1][smooth_inds], linestyle="-",
             color="black", zorder=2, label="Model")
    if show_CI:
        ax1.fill_between(tsmooth_fold_sorted, rv_planet_smooth_percs[0][smooth_inds],
                         rv_planet_smooth_percs[2][smooth_inds], color="tab:gray", alpha=0.3,
                         edgecolor="none", label="68.3% CI", zorder=1)
    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title is None:
        ax1.set_title(f"Posterior Phase Plot - Planet {planet_letter}")
    elif title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis="x", labelbottom=False, bottom=True, top=False, direction="in")
    ax1.tick_params(axis="y", direction="in")
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis="y", which="minor", direction="in", length=3)

    # --- Residuals panel ---
    for inst in self.unique_instruments:
        mask = instrument_sorted == inst
        color = inst_colors[inst]
        ax2.errorbar(t_fold[mask], residuals_sorted[mask], yerr=velerr_sorted[mask],
                     marker=".", linestyle="None", color=color, ecolor=color,
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=color, alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))

    # Symmetric limits containing every point AND its full error bar (|r| + err,
    # not |r + err|, which would clip the error bars of negative residuals).
    max_extent = np.max(np.abs(residuals) + velerr_with_jit)
    if np.isfinite(max_extent) and max_extent > 0:
        ax2.set_ylim(-max_extent * 1.1, max_extent * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis="x", direction="in")
    ax2.tick_params(axis="y", direction="in")
    ax2.tick_params(axis="x", top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis="y", which="minor", direction="in", length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def calculate_rv_planet_from_samples(self, planet_letter: str, times: np.ndarray,
                                     discard_start: int = 0, discard_end: int = 0,
                                     thin: int = 1, progress: bool = False) -> np.ndarray:
    """Evaluate one planet's RV curve once per retained MCMC sample.

    Every flattened sample is expanded into a full parameter dictionary with
    `build_params_dict` and passed to `calculate_rv_planet_custom`. Each row
    is therefore the curve of one actual posterior draw, which keeps the
    correlations between parameters, unlike plugging in per-parameter medians.

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Times at which the RV is evaluated
    discard_start : int, optional
        Steps removed from the start of the chain (default: 0)
    discard_end : int, optional
        Steps removed from the end of the chain (default: 0)
    thin : int, optional
        Keep every `thin`-th step (default: 1)
    progress : bool, optional
        Display a tqdm progress bar (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)); row i is the planet's RV for sample i.
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end,
                                       thin=thin, flat=True)
    n_samples = len(flat_samples)
    rv_matrix = np.zeros((n_samples, len(times)))

    sample_iter = tqdm(flat_samples, total=n_samples, disable=not progress,
                       desc=f"Calculating planet {planet_letter} RV from samples")
    for row_idx, sample in enumerate(sample_iter):
        sample_params = self.build_params_dict(sample)
        rv_matrix[row_idx] = self.calculate_rv_planet_custom(planet_letter, times, sample_params)
    return rv_matrix
def calculate_rv_trend_from_samples(self, times: np.ndarray, discard_start: int = 0,
                                    discard_end: int = 0, thin: int = 1,
                                    progress: bool = False) -> np.ndarray:
    """Evaluate the trend RV model once per retained MCMC sample.

    Every flattened sample is expanded into a full parameter dictionary with
    `build_params_dict` and passed to `calculate_rv_trend_custom`. Each row
    is therefore the trend of one actual posterior draw, which keeps the
    correlations between parameters, unlike plugging in per-parameter medians.

    Parameters
    ----------
    times : np.ndarray
        Times at which the trend is evaluated
    discard_start : int, optional
        Steps removed from the start of the chain (default: 0)
    discard_end : int, optional
        Steps removed from the end of the chain (default: 0)
    thin : int, optional
        Keep every `thin`-th step (default: 1)
    progress : bool, optional
        Display a tqdm progress bar (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)); row i is the trend RV for sample i.
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end,
                                       thin=thin, flat=True)
    n_samples = len(flat_samples)
    rv_matrix = np.zeros((n_samples, len(times)))

    sample_iter = tqdm(flat_samples, total=n_samples, disable=not progress,
                       desc="Calculating trend RV from samples")
    for row_idx, sample in enumerate(sample_iter):
        rv_matrix[row_idx] = self.calculate_rv_trend_custom(times, self.build_params_dict(sample))
    return rv_matrix
def calculate_rv_total_from_samples(self, times: np.ndarray, discard_start: int = 0,
                                    discard_end: int = 0, thin: int = 1,
                                    progress: bool = False) -> np.ndarray:
    """Evaluate the combined planets + trend RV once per retained MCMC sample.

    The per-sample trend matrix is computed first. Each planet's per-sample
    matrix is then added in the order of ``self.planet_letters``. Because the
    rows of every matrix come from the same sample selection, row i is the
    total RV of posterior draw i. Per-instrument gamma offsets are not
    included.

    Parameters
    ----------
    times : np.ndarray
        Times at which the RV is evaluated
    discard_start : int, optional
        Steps removed from the start of the chain (default: 0)
    discard_end : int, optional
        Steps removed from the end of the chain (default: 0)
    thin : int, optional
        Keep every `thin`-th step (default: 1)
    progress : bool, optional
        Display progress bars (default: False)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)); row i is the total RV for sample i.
    """
    shared_args = (times, discard_start, discard_end, thin, progress)
    trend_rvs = self.calculate_rv_trend_from_samples(*shared_args)
    return sum(
        (self.calculate_rv_planet_from_samples(letter, *shared_args) for letter in self.planet_letters),
        start=trend_rvs,
    )
def calculate_rv_planet_custom(self, planet_letter: str, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute one planet's RV curve from a single explicit parameter set.

    Handy for evaluating the model at hand-picked values, for instance the
    best-lnprob sample, the median parameters, or trial values.

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters),
        keyed like ``"P_b"``. Can be created using build_params_dict() or
        manually constructed.

    Returns
    -------
    np.ndarray
        RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> rv = fitter.calculate_rv_planet_custom('b', times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = fitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = fitter.build_params_dict(best_params)
    >>> rv = fitter.calculate_rv_planet_custom('b', times, params)
    """
    # Select this planet's entries and drop the "_<letter>" suffix
    # (e.g. "P_b" -> "P"), which is the key form ravest.model.Planet takes.
    suffix = f"_{planet_letter}"
    planet_params = {name: params[name + suffix] for name in self.parameterisation.pars}
    return ravest.model.Planet(planet_letter, self.parameterisation, planet_params).radial_velocity(times)
def calculate_rv_trend_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute the system trend RV from a single explicit parameter set.

    Handy for evaluating the model at hand-picked values, for instance the
    best-lnprob sample, the median parameters, or trial values.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters); only
        the ``"gd"`` and ``"gdd"`` entries are read here. Can be created using
        build_params_dict() or manually constructed.

    Returns
    -------
    np.ndarray
        Trend RV values at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> trend_rv = fitter.calculate_rv_trend_custom(times, params)
    """
    # Only the drift terms go into the trend; gamma offsets are per-instrument
    # and are not included here.
    trend_params = {"gd": params["gd"], "gdd": params["gdd"]}
    system_trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    return system_trend.radial_velocity(times)
def calculate_rv_total_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Compute the full model RV (trend plus every planet) for one parameter set.

    Handy for evaluating the model at hand-picked values, for instance the
    best-lnprob sample, the median parameters, or trial values.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    params : dict[str, float]
        Complete parameter dictionary (both free and fixed parameters).
        Can be created using build_params_dict() or manually constructed.

    Returns
    -------
    np.ndarray
        Total RV values (trend + all planets) at the requested times

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = fitter.find_map_estimate()
    >>> params = fitter.build_params_dict(map_result.x)
    >>> total_rv = fitter.calculate_rv_total_custom(times, params)
    >>>
    >>> # Using median parameters
    >>> samples_df = fitter.get_samples_df(discard_start=1000)
    >>> median_values = samples_df.median().to_dict()
    >>> params = fitter.build_params_dict(median_values)
    >>> total_rv = fitter.calculate_rv_total_custom(times, params)
    """
    # Start from the trend curve and accumulate each planet onto it in turn.
    rv_sum = self.calculate_rv_trend_custom(times, params)
    for letter in self.planet_letters:
        rv_from_planet = self.calculate_rv_planet_custom(letter, times, params)
        rv_sum += rv_from_planet
    return rv_sum
[docs] def plot_MAP_rv(self, map_result: scipy.optimize.OptimizeResult, title: str | None = "MAP RV", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_rv.png", dpi: int = 100) -> None: """Plot radial velocity data and model using MAP parameter estimates. Parameters ---------- map_result : scipy.optimize.OptimizeResult Result from find_map_estimate() containing the MAP parameters title : str or None, optional Plot title (default: "MAP RV"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "MAP_rv.png") dpi : int, optional Resolution for saving (default: 100) """ # Get MAP parameter values from the optimization result map_params = dict(zip(self.free_params_names, map_result.x)) # Combine with fixed parameters all_params = self.fixed_params_values_dict | map_params # Use helper function to create the plot self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_MAP_phase(self, planet_letter: str, map_result: scipy.optimize.OptimizeResult, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_phase.png", dpi: int = 100) -> None: """Plot phase-folded radial velocity data and model using MAP parameter estimates. Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') map_result : scipy.optimize.OptimizeResult Result from find_map_estimate() containing the MAP parameters title : str or None, optional Plot title (default: f"MAP Phase Plot - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "MAP_phase.png") dpi : int, optional Resolution for saving (default: 100) """ # Get MAP parameter values from the optimization result map_params = dict(zip(self.free_params_names, map_result.x)) # Combine with fixed parameters all_params = self.fixed_params_values_dict | map_params # Set default title if not provided if title is None: title = f"MAP Phase Plot - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_custom_rv(self, params: dict, title: str | None = "Custom RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_rv.png", dpi: int = 100) -> None: """Plot radial velocity data and model using custom parameter values. Allows plotting with arbitrary parameter values for exploring parameter space or comparing theoretical models. Parameters ---------- params : dict Dictionary of parameter values to use for plotting. Keys should match parameter names, values should be floats. Must include all required parameters for the current parameterisation. title : str or None, optional Plot title (default: "Custom RV Plot"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "custom_rv.png") dpi : int, optional Resolution for saving (default: 100) Examples -------- >>> # Plot with custom values (must include all required parameters) >>> fitter.plot_custom_rv({"P_b": 4.25, "K_b": 55.0, "e_b": 0.1, ... "w_b": 1.57, "Tc_b": 2456325.5, ... "g_HARPS": -10.2, "jit_HARPS": 2.0, ... 
"gd": 0.0, "gdd": 0.0}) """ # Validate that all required parameters are present expected_params = set(self.free_params_names + list(self.fixed_params_names)) provided_params = set(params.keys()) missing_params = expected_params - provided_params if missing_params: raise ValueError(f"Missing required parameters: {missing_params}") # Use helper function to create the plot self._plot_rv(params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_custom_phase(self, planet_letter: str, params: dict, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_phase.png", dpi: int = 100) -> None: """Plot phase-folded radial velocity data and model using custom parameter values. Allows plotting phase-folded data with arbitrary parameter values for exploring parameter space or comparing theoretical models. Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') params : dict Dictionary of parameter values to use for plotting. Keys should match parameter names, values should be floats. Must include all required parameters for the current parameterisation. title : str or None, optional Plot title (default: f"Custom Phase Plot - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "custom_phase.png") dpi : int, optional Resolution for saving (default: 100) Examples -------- >>> # Plot phase curve with custom values >>> fitter.plot_custom_phase("b", {"P_b": 4.25, "K_b": 55.0, "e_b": 0.1, ... "w_b": 1.57, "Tc_b": 2456325.5, ... "g_HARPS": -10.2, "jit_HARPS": 2.0, ... 
"gd": 0.0, "gdd": 0.0}) """ # Validate that all required parameters are present expected_params = set(self.free_params_names + list(self.fixed_params_names)) provided_params = set(params.keys()) missing_params = expected_params - provided_params if missing_params: raise ValueError(f"Missing required parameters: {missing_params}") # Set default title if not provided if title is None: title = f"Custom Phase Plot - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_best_sample_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Best Sample RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_rv.png", dpi: int = 100) -> None: """Plot radial velocity data and model using parameter values from the MCMC sample with highest log probability. This is useful for comparing with plot_MAP_rv() to diagnose potential issues with MAP convergence or MCMC mixing. The two plots should be very similar if both MAP and MCMC are working correctly. Parameters ---------- discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) title : str or None, optional Title for the main RV plot (default: "Best Sample RV Plot"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "best_sample_rv.png") dpi : int, optional Resolution for saving (default: 100) """ # Get parameter values from best sample best_sample_params = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin) # Combine with fixed parameters all_params = self.fixed_params_values_dict | best_sample_params # Use helper function to create the plot self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_best_sample_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_phase.png", dpi: int = 100) -> None: """Plot phase-folded radial velocity data and model using parameter values from the MCMC sample with highest log probability. This is useful for comparing with plot_MAP_phase() to diagnose potential issues with MAP convergence or MCMC mixing. The two plots should be very similar if both MAP and MCMC are working correctly. Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) title : str or None, optional Title for the main phase plot (default: "Best Sample Phase Plot - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "best_sample_phase.png") dpi : int, optional Resolution for saving (default: 100) """ # Get parameter values from best sample best_sample_params = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin) # Combine with fixed parameters all_params = self.fixed_params_values_dict | best_sample_params # Set default title if not provided if title is None: title = f"Best Sample Phase Plot - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
class LogPosterior:
    """Log posterior probability for MCMC sampling.

    Combines a ``LogLikelihood`` and a ``LogPrior`` (both built in
    ``__init__``) for Bayesian parameter estimation.
    """

    def __init__(
        self,
        planet_letters: list[str],
        parameterisation: "Parameterisation",
        priors: dict[str, Callable[[float], float]],
        fixed_params: dict[str, float],
        free_params_names: list[str],
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        instrument: np.ndarray,
        unique_instruments: np.ndarray,
        t0: float,
    ) -> None:
        """Initialize the LogPosterior object.

        Parameters
        ----------
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.
        priors : dict[str, Callable[[float], float]]
            Dictionary mapping parameter names to their prior probability functions.
        fixed_params : dict[str, float]
            Dictionary of fixed parameter values.
        free_params_names : list[str]
            List of free parameter names to sample.
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        instrument : np.ndarray
            Instrument name for each observation.
        unique_instruments : np.ndarray
            Unique instrument names in the data.
        t0 : float
            Reference time for the trend [days].
        """
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.priors = priors
        self.fixed_params = fixed_params
        self.free_params_names = free_params_names
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.t0 = t0

        # Create log-likelihood and log-prior objects for later
        self.log_likelihood = LogLikelihood(
            time=self.time,
            vel=self.vel,
            velerr=self.velerr,
            instrument=self.instrument,
            unique_instruments=self.unique_instruments,
            t0=self.t0,
            planet_letters=self.planet_letters,
            parameterisation=self.parameterisation,
        )
        self.log_prior = LogPrior(self.priors)

    def _convert_params_for_prior_evaluation(self, free_params_dict: dict[str, float]) -> Dict[str, float]:
        """Convert free parameters for prior evaluation if needed.

        Parameters
        ----------
        free_params_dict : dict
            Free parameters in current parameterisation

        Returns
        -------
        dict
            Parameters with names/values converted for prior evaluation. If
            the prior keys exactly match the free parameter names, the input
            dict is returned unchanged (same object).

        Raises
        ------
        ValueError
            Propagated from
            ``parameterisation.convert_pars_to_default_parameterisation``
            (``log_probability`` treats it as an invalid point).
        """
        # Three cases:
        # Case 1: User is fitting in transformed parameterisation, but priors are in same transformed parameterisation
        # Case 2: User is fitting in default parameterisation, and priors are also in default parameterisation
        # Case 3: User is fitting in transformed parameterisation, but priors are in default parameterisation

        # Simple detection: do prior keys match our current free parameter names?
        prior_keys = set(self.priors.keys())
        free_param_keys = set(self.free_params_names)

        if prior_keys == free_param_keys:
            # No conversion needed (Cases 1 & 2)
            return free_params_dict

        # Conversion needed (Case 3) - convert to default parameterisation equivalents
        # Start with just the non-planetary parameters that match
        params_for_prior = {key: value for key, value in free_params_dict.items() if key in prior_keys}
        all_params = self.fixed_params | free_params_dict

        # Convert each planet's parameters
        for planet_letter in self.planet_letters:
            # Get current planet parameters (suffix stripped, e.g. "P_b" -> "P")
            planet_params = {par: all_params[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}

            # Convert to default parameterisation
            default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)

            # Add the converted parameter values for priors that need them
            for default_par, value in default_params.items():
                default_param_key = f"{default_par}_{planet_letter}"
                if default_param_key in prior_keys:  # Only add if we have a prior for it
                    params_for_prior[default_param_key] = value

        return params_for_prior

    def log_probability(self, free_params_dict: Dict[str, float]) -> float:
        """Calculate log posterior probability for given free parameters.

        Parameters
        ----------
        free_params_dict : Dict[str, float]
            Dictionary of free parameter values

        Returns
        -------
        float
            Log posterior probability (log likelihood + log prior), or
            ``-np.inf`` when any jitter is negative, the prior is not finite,
            the parameter conversion raises ValueError, or the log likelihood
            is not finite (including NaN).
        """
        # Fast fail for invalid jitter (before expensive prior/likelihood calculations)
        # We have to check jitter specifically because all other params will ultimately
        # get checked/raise Exceptions when they are used to calculate an RV.
        # Jitter doesn't directly contribute to calculated RV, so needs to be checked manually.
        _all_params_for_ll = self.fixed_params | free_params_dict
        for inst in self.unique_instruments:
            if _all_params_for_ll[f"jit_{inst}"] < 0:
                return -np.inf

        # Evaluate priors on the free parameters. If any parameters are outside priors
        # (i.e. priors are infinite), then fail fast by returning -inf early (before expensive likelihood calc).
        # We attempt to convert free parameters (if needed) for prior evaluation
        # This is for if the user is fitting in transformed parameterisation,
        # but defining their priors in the default parameterisation
        try:
            params_for_prior = self._convert_params_for_prior_evaluation(free_params_dict)
            lp = self.log_prior(params_for_prior)
        except ValueError:
            # Invalid parameter conversion (e.g., unphysical eccentricity)
            return -np.inf

        if not np.isfinite(lp):
            return -np.inf

        # Calculate log-likelihood with all parameters
        ll = self.log_likelihood(_all_params_for_ll)

        # A NaN likelihood (e.g. zero total variance, or a model RV that came
        # out NaN) would otherwise propagate into the posterior, and emcee
        # raises on NaN log-probabilities. Treat any non-finite value as an
        # impossible point.
        if not np.isfinite(ll):
            return -np.inf

        # Return combined log-posterior (log-likelihood + log-prior)
        return ll + lp

    def _negative_log_probability_for_MAP(self, free_params_vals: list[float]) -> float:
        """Return the negative log posterior for ``scipy.optimize.minimize``.

        scipy passes the free parameters as a flat sequence of values rather
        than a dict, so they are zipped back onto ``self.free_params_names``
        and passed to ``log_probability``. The order of the values is assumed
        to match ``self.free_params_names``; this is not checked.

        Parameters
        ----------
        free_params_vals : list
            float values of the free parameters

        Returns
        -------
        float
            ``-log_probability(...)``, or ``1e30`` when that value is not
            finite, so the optimiser never receives ``inf``/``nan``.
        """
        # Create dicts from the names and values
        # (Assumes the order of names matches the order of values)
        free_params_dict = dict(zip(self.free_params_names, free_params_vals))

        # Calculate *negative* log_probability (MAP is backwards from MCMC)
        logprob = self.log_probability(free_params_dict)
        neg_logprob = -logprob

        # Handle -inf log_probability to prevent scipy RuntimeWarnings during optimisation
        # scipy's optimizer can't handle -inf values in arithmetic operations
        # (This does mean there is a non-zero chance we could end up returning a solution that doesn't satisfy the prior functions)
        if not np.isfinite(neg_logprob):
            return 1e30  # Very large finite number instead of +inf

        return neg_logprob
class LogLikelihood:
    """Log likelihood calculation for radial velocity data.

    Calculates a Gaussian log likelihood of the RV data given model
    parameters, with a per-instrument gamma offset and jitter term.
    """

    def __init__(
        self,
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        instrument: np.ndarray,
        unique_instruments: np.ndarray,
        t0: float,
        planet_letters: list[str],
        parameterisation: "Parameterisation",
    ) -> None:
        """Initialize the LogLikelihood object.

        Parameters
        ----------
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        instrument : np.ndarray
            Instrument name for each observation.
        unique_instruments : np.ndarray
            Unique instrument names in the data.
        t0 : float
            Reference time for the trend [days].
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.

        Raises
        ------
        KeyError
            If an entry of ``instrument`` is not in ``unique_instruments``.
        """
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.t0 = t0
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation

        # Precompute a per-observation integer index array.
        # For each observation, store which instrument it came from as an integer:
        #   e.g. unique_instruments = ["HARPS", "ESPRESSO"]
        #        instrument          = ["HARPS", "HARPS", "ESPRESSO", "HARPS", ...]
        #        _instrument_indices = [0, 0, 1, 0, ...]
        # This lets us use numpy fancy indexing (array[integer_array]) to look up
        # per-instrument values for all observations in one vectorised operation,
        # instead of looping over instruments with boolean mask slices.
        # dtype is forced to an integer index type: np.array([]) would otherwise
        # default to float64, which numpy refuses as an index array.
        _inst_to_idx = {inst: i for i, inst in enumerate(self.unique_instruments)}
        self._instrument_indices = np.array(
            [_inst_to_idx[inst] for inst in self.instrument], dtype=np.intp
        )

        # Precompute parameter key strings for gamma and jitter lookups.
        # These strings (e.g. "g_HARPS", "jit_ESPRESSO") are constant for the lifetime
        # of this object — precomputing them avoids rebuilding f-strings on every call.
        self._gamma_keys = [f"g_{inst}" for inst in self.unique_instruments]
        self._jitter_keys = [f"jit_{inst}" for inst in self.unique_instruments]

        # Precompute log(2*pi) — it's a constant, no need to recalculate every time
        self._log_2pi = np.log(2 * np.pi)

        # Precompute velerr squared — constant (observed data doesn't change) so no need to recalculate every time
        self._velerr_sq = self.velerr ** 2

    def __call__(self, params: Dict[str, float]) -> float:
        """Calculate log likelihood for given parameters.

        The model RV is the sum of every planet's RV, the system trend and
        each observation's instrument gamma offset. With
        ``s2 = velerr**2 + jit**2`` (per observation), the result is
        ``-0.5 * sum((model - vel)**2 / s2 + log(2*pi) + log(s2))``.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values

        Returns
        -------
        float
            Log likelihood value, or ``-np.inf`` if constructing or
            evaluating a planet raises ValueError.
        """
        rv_total = np.zeros(len(self.time))

        # Step 1: Calculate RV contributions from each planet
        for letter in self.planet_letters:
            # get just the parameters for this planet (and strip the _letter suffix from the keys)
            _this_planet_params = {
                par: params[f"{par}_{letter}"] for par in self.parameterisation.pars
            }
            try:
                _this_planet = ravest.model.Planet(letter, self.parameterisation, _this_planet_params)
                _this_planet_rv = _this_planet.radial_velocity(self.time)
            except ValueError:
                # Planet.__init__ validates parameters and raises ValueError for invalid params
                return -np.inf  # fail-fast: return -inf log-likelihood

            # add this planet's RV contribution to the total
            rv_total += _this_planet_rv

        # Step 2: Calculate and add the RV from the system Trend (no gamma offset)
        _trend_params = {"gd": params["gd"], "gdd": params["gdd"]}
        _this_trend = ravest.model.Trend(params=_trend_params, t0=self.t0)
        _rv_trend = _this_trend.radial_velocity(self.time)
        rv_total += _rv_trend

        # Step 3: Add per-instrument gamma offsets using vectorised fancy indexing.
        # Build a small array of gamma values, one per instrument (length K), then use
        # _instrument_indices to select the right gamma for each of the N observations.
        gamma_per_instrument = np.array([params[k] for k in self._gamma_keys])
        gamma_at_each_obs = gamma_per_instrument[self._instrument_indices]
        rv_total += gamma_at_each_obs

        # Step 4: Calculate log-likelihood with per-instrument jitter using vectorised fancy indexing.
        # 1. Build a small array of jitter values, one per instrument (length K)
        # 2. Use _instrument_indices to select the right jitter for each observation (length N)
        jitter_per_instrument = np.array([params[k] for k in self._jitter_keys])
        jitter_at_each_obs = jitter_per_instrument[self._instrument_indices]

        velerr_jit_sq = self._velerr_sq + jitter_at_each_obs**2
        penalty_term = self._log_2pi + np.log(velerr_jit_sq)
        residuals = rv_total - self.vel
        chi2 = residuals**2 / velerr_jit_sq
        ll = -0.5 * np.sum(chi2 + penalty_term)
        return ll
class LogPrior:
    """Log prior probability calculation.

    Holds one prior callable per parameter name. Calling the object sums the
    log prior contributions of the parameters it is given.
    """

    def __init__(self, priors: dict[str, Callable[[float], float]]) -> None:
        """Store the mapping of parameter name -> prior callable.

        Parameters
        ----------
        priors : dict[str, Callable[[float], float]]
            For each parameter name, a callable that takes a value of that
            parameter and returns its log prior probability.
        """
        self.priors = priors

    def __call__(self, params: Dict[str, float]) -> float:
        """Sum the log prior probabilities of the supplied parameters.

        Parameters
        ----------
        params : Dict[str, float]
            Parameter values keyed by name. Only these parameters are
            evaluated, each with its own entry from ``self.priors``.

        Returns
        -------
        float
            Sum of the individual log prior values (``0`` for an empty dict).

        Raises
        ------
        KeyError
            If a parameter in ``params`` has no prior in ``self.priors``.
        """
        # Look up each parameter's prior callable, evaluate it at the
        # parameter's value, and add up the results.
        return sum(self.priors[name](value) for name, value in params.items())
class GPFitter:
    """Gaussian Process fitter for exoplanet radial velocity data.

    Exposes an interface similar to the Fitter class, but models correlated
    noise in the data with a Gaussian Process whose mean function is the
    planetary RV model. Supports MCMC sampling, MAP estimation and various
    parameterisations, and handles multiple planets, trends, jitter
    parameters and GP hyperparameters.
    """

    def __init__(self, planet_letters: list[str], parameterisation: Parameterisation, gp_kernel: GPKernel) -> None:
        """Set up the GPFitter with its planets, parameterisation and GP kernel.

        Parameters
        ----------
        planet_letters : list[str]
            Single-character planet identifiers (e.g., ['b', 'c', 'd']), used
            to tell apart the parameters of different planets in the system.
        parameterisation : Parameterisation
            The orbital parameterisation used for fitting. It defines which
            orbital elements act as free/fixed parameters (e.g., 'Default',
            'EccentricityWind').
        gp_kernel : GPKernel
            The Gaussian Process kernel used to model correlated noise in the data.
        """
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.gp_kernel = gp_kernel

        # Run the numba-compiled Kepler RV routine once on throwaway inputs so
        # its JIT compilation happens here, before MCMC starts.
        warmup_mean_anomaly = np.linspace(0, 2 * np.pi, 10)
        _njit_kepler_rv(warmup_mean_anomaly, 0.3, 10.0, 0.5)

        # Empty storage for parameters, priors, hyperparameters and hyperpriors
        self._params: Dict[str, Parameter] = {}
        self._priors: Dict[str, Callable[[float], float]] = {}
        self._hyperparams: Dict[str, Parameter] = {}
        self._hyperpriors: Dict[str, Callable[[float], float]] = {}
def add_data(
    self,
    time: np.ndarray,
    vel: np.ndarray,
    velerr: np.ndarray,
    instrument: np.ndarray,
    t0: float,
) -> None:
    """Attach the observed RV data set to this GPFitter.

    Parameters
    ----------
    time : array-like
        Time of each observation [days]
    vel : array-like
        Radial velocity at each time [m/s]
    velerr : array-like
        Uncertainty on the radial velocity at each time [m/s]
    instrument : array-like
        Instrument name for each observation (e.g., "HARPS", "HIRES")
    t0 : float
        Reference time for the trend [days]. Recommended to set this as mean
        or median of input `time` array.

    Raises
    ------
    ValueError
        If the four input arrays do not all have the same length.
    """
    same_length = len(time) == len(vel) == len(velerr) == len(instrument)
    if not same_length:
        raise ValueError(
            "Time, velocity, uncertainty, and instrument arrays must be the same length."
        )

    # Numeric columns are stored as C-contiguous numpy arrays
    self.time, self.vel, self.velerr = (np.ascontiguousarray(arr) for arr in (time, vel, velerr))
    self.instrument = np.asarray(instrument)
    # np.unique returns the distinct instrument names in sorted order
    self.unique_instruments = np.unique(self.instrument)
    self.t0 = t0
@property def params(self) -> Dict[str, Parameter]: """Parameters dictionary. Set via: gpfitter.params = param_dict.""" return self._params @params.setter def params(self, new_params: Dict[str, Parameter]) -> None: """Set parameters with a dict, checking all required params are present. You can update all or some of the parameters at once, example: >>> gpfitter.params = {"g": Parameter(1.0, "m/s"), "gd": Parameter(0.1, "m/s/d")} # only update trend parameters >>> gpfitter.params = {"P_c": Parameter(5.0, "d"), "K_c": Parameter(3.5, "m/s")} # only update some of planet C parameters Parameters ---------- new_params : dict Dictionary of new parameter values to set. The keys of this dictionary should match the parameter names expected by the GPFitter object: all required parameters for the chosen parameterisation, with planet letters (not required for Trend or jitter parameters.) Raises ------ ValueError If any of the required parameters are missing or invalid. """ # Update the current _params dict with the new entries merged_params = dict(self._params) merged_params.update(new_params) # Validate the complete parameter set self._validate_complete_params(merged_params) # If validation passes, update the actual params self._params.update(new_params) # Update ndim to total number of free parameters (hyperparams are added to ndim when they're set later on) self.ndim = len(self.free_params_values) if self.ndim == 0: warnings.warn( "All parameters are fixed. MCMC methods require at least one " "free parameter or hyperparameter (fixed=False).", UserWarning, stacklevel=2 ) @property def hyperparams(self) -> Dict[str, Parameter]: """Hyperparameters dictionary. 
Set via: gpfitter.hyperparams = hyperparam_dict.""" return self._hyperparams @hyperparams.setter def hyperparams(self, new_hyperparams: Dict[str, Parameter]) -> None: """Set hyperparameters with validation.""" # Update the current _hyperparams dict with the new entries merged_hyperparams = dict(self._hyperparams) merged_hyperparams.update(new_hyperparams) # Validate using GPKernel (as the params required depend on the specific kernel) self.gp_kernel.validate_hyperparams(merged_hyperparams) # If validation passes, update the actual hyperparams self._hyperparams.update(new_hyperparams) # Update ndim to include hyperparameters (total free params + hyperparams) self.ndim = len(self.free_params_values) + len(self.free_hyperparams_values) if self.ndim == 0: warnings.warn( "All parameters and hyperparameters are fixed. MCMC methods require " "at least one free parameter or hyperparameter (fixed=False).", UserWarning, stacklevel=2 ) @property def priors(self) -> dict: """Priors dictionary. Set via: gpfitter.priors = prior_dict.""" return self._priors @priors.setter def priors(self, new_priors: dict[str, Callable[[float], float]]) -> None: """Set prior functions using a dict, checking all required priors are present. Priors must be provided for all free parameters. You can set all priors at once or update individual priors. Parameters ---------- new_priors : dict Dictionary of prior functions to set. Keys should be parameter names that match free parameters, values should be callable prior functions. Examples -------- >>> from ravest.prior import Uniform >>> gpfitter.priors = {"K_b": Uniform(0, 100), "P_b": Uniform(1, 30)} Raises ------ ValueError If any required priors are missing, unexpected prio rs are provided, or initial parameter values are outside prior bounds. """ self._set_priors_with_validation(new_priors) @property def hyperpriors(self) -> dict: """Hyperpriors dictionary. 
Set via: gpfitter.hyperpriors = hyperprior_dict.""" return self._hyperpriors @hyperpriors.setter def hyperpriors(self, new_hyperpriors: dict[str, Callable[[float], float]]) -> None: """Set hyperprior functions with validation.""" self._set_hyperpriors_with_validation(new_hyperpriors)
[docs] def _validate_complete_params(self, params: Dict[str, Parameter]) -> None: """Validate that params dict has required parameters, astrophysically valid values.""" # Build complete set of expected parameters expected_params = set() # Add planetary parameters for planet_letter in self.planet_letters: for par_name in self.parameterisation.pars: expected_params.add(f"{par_name}_{planet_letter}") # Add trend parameters (no gamma - that's per-instrument) expected_params.update(["gd", "gdd"]) # Add per-instrument gamma offset and jitter parameters for inst in self.unique_instruments: expected_params.add(f"g_{inst}") expected_params.add(f"jit_{inst}") # Validate same as Fitter provided_params = set(params.keys()) # Check for unexpected parameters unexpected_params = provided_params - expected_params if unexpected_params: # Give a specific hint if user is passing legacy single-instrument parameters legacy_params = unexpected_params & {"g", "jit"} if legacy_params: raise ValueError( f"Unexpected parameters: {unexpected_params}. " f"Single-instrument 'g' and 'jit' parameters are no longer supported. " f"Use per-instrument names instead, e.g. " f"{[f'g_{inst}' for inst in self.unique_instruments]} and " f"{[f'jit_{inst}' for inst in self.unique_instruments]}, " f"matching the instrument names passed to add_data()." ) raise ValueError( f"Unexpected parameters: {unexpected_params}. " f"Expected {len(expected_params)} parameters, got {len(provided_params)}" ) # Check for missing parameters missing_params = expected_params - provided_params if missing_params: raise ValueError( f"Missing required parameters: {missing_params}. " f"Expected {len(expected_params)} parameters, got {len(provided_params)}" ) # Validate astrophysical validity of all parameters params_values = {name: param.value for name, param in params.items()} self._validate_astrophysical_validity(params_values) # Validate parameter coupling constraints # i.e. 
if two parameters both need to be fixed or free together self._validate_parameter_coupling(params)
[docs] def _validate_astrophysical_validity(self, params_values: Dict[str, float]) -> None: """Validate that all parameter values are astrophysically valid.""" # First, check that ALL parameters are finite (not NaN or infinite) invalid_params = { name: value for name, value in params_values.items() if not np.isfinite(value) } if invalid_params: raise ValueError( "Invalid parameters detected: " + ", ".join(f"{k}={v}" for k, v in invalid_params.items()) ) # Validate planetary parameters for each planet for planet_letter in self.planet_letters: planet_params = {} for par_name in self.parameterisation.pars: key = f"{par_name}_{planet_letter}" planet_params[par_name] = params_values[key] # Validate this planet's parameters in current parameterisation self.parameterisation.validate_planetary_params(planet_params) # Validate trend parameters are finite real numbers (already checked above, but kept for clarity) for trend_param in ["gd", "gdd"]: trend_value = params_values[trend_param] if not np.isfinite(trend_value): raise ValueError(f"Invalid trend parameter {trend_param}: {trend_value} is not a finite real number") # Validate per-instrument parameters for inst in self.unique_instruments: # Gamma offset must be finite g_key = f"g_{inst}" if not np.isfinite(params_values[g_key]): raise ValueError(f"Invalid gamma offset {g_key}: {params_values[g_key]} is not finite") # Jitter must be >= 0 jit_key = f"jit_{inst}" if params_values[jit_key] < 0: raise ValueError(f"Invalid jitter {jit_key}: {params_values[jit_key]} < 0")
[docs] def _validate_parameter_coupling(self, params: Dict[str, Parameter]) -> None: """Validate parameter coupling constraints (e.g., secosw/sesinw must both be free or both fixed).""" for planet_letter in self.planet_letters: # Check secosw/sesinw coupling secosw_key = f"secosw_{planet_letter}" sesinw_key = f"sesinw_{planet_letter}" if secosw_key in params and sesinw_key in params: secosw_fixed = params[secosw_key].fixed sesinw_fixed = params[sesinw_key].fixed if secosw_fixed != sesinw_fixed: raise ValueError(f"Parameters {secosw_key} and {sesinw_key} must both be fixed or both be free") # Check ecosw/esinw coupling ecosw_key = f"ecosw_{planet_letter}" esinw_key = f"esinw_{planet_letter}" if ecosw_key in params and esinw_key in params: ecosw_fixed = params[ecosw_key].fixed esinw_fixed = params[esinw_key].fixed if ecosw_fixed != esinw_fixed: raise ValueError(f"Parameters {ecosw_key} and {esinw_key} must both be fixed or both be free")
[docs] def _set_priors_with_validation(self, new_priors: dict[str, Callable[[float], float]]) -> None: """Set priors with validation. Supports partial updates. Can be current or default parameterisation.""" # Create merged priors dict (in case user is only updating some priors, not all) merged_priors_dict = dict(self._priors) # get existing priors merged_priors_dict.update(new_priors) # overwrite with newer functions, if supplied provided_prior_param_names = set(merged_priors_dict.keys()) # There are two possibilities for priors: # 1. The prior has been given for the parameter, in the current parameterisation # (this can also include if the user is fitting in the default parameterisation) # 2. The prior has been given for the Default parameterisation's equivalent parameter instead # (e.g. e & w instead of secosw & sesinw, or tp instead of tc) # If not, then prior isn't given for either the Current or Default parameterisation, raise an Exception validated_priors = {} missing_priors = [] conflicts = [] # in the current parameterisation, which (free) parameters do we expect priors for? current_parameterisation_free_param_names = set(self.free_params_names) for free_param_name in current_parameterisation_free_param_names: if free_param_name in provided_prior_param_names: # Prior was provided for the param in the current parameterisation validated_priors[free_param_name] = merged_priors_dict[free_param_name] # Check if user ALSO provided equivalent default priors (conflict!) 
default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name) if default_parameterisation_equivalent_free_param_names: for equiv_param in default_parameterisation_equivalent_free_param_names: if equiv_param in provided_prior_param_names: conflicts.append((free_param_name, equiv_param)) else: # We haven't been provided the prior for the free parameter in the current parameterisation # So let's check if we were given the prior for the equivalent parameter in the default parameterisation instead default_parameterisation_equivalent_free_param_names = self._get_default_parameterisation_equivalent_free_param_name(free_param_name) # remember that one parameter in current parameterisation (e.g. secosw) might map to more than one equivalent in default parameterisation (e.g. both e & w) if default_parameterisation_equivalent_free_param_names and all(eq in provided_prior_param_names for eq in default_parameterisation_equivalent_free_param_names): # Found all required default equivalents for equiv in default_parameterisation_equivalent_free_param_names: validated_priors[equiv] = merged_priors_dict[equiv] else: # Missing prior for a free parameter in both the current parameterisation, and its equivalent in the default parameterisation if default_parameterisation_equivalent_free_param_names: missing_priors.append(f"{free_param_name} (or equivalent {default_parameterisation_equivalent_free_param_names})") else: missing_priors.append(free_param_name) # Check for conflicts after processing all parameters if conflicts: conflict_strs = [f"{current} vs {default}" for current, default in conflicts] raise ValueError(f"Conflicting priors provided for both current and default parameterisations: {', '.join(conflict_strs)}. 
Please provide priors for either the current parameterisation OR the equivalent default parameterisation, but not both.") if missing_priors: raise ValueError(f"Missing priors for parameters: {missing_priors}") # Check for unexpected priors - only allow priors that were validated above expected_prior_param_names = set(validated_priors.keys()) unexpected_prior_param_names = provided_prior_param_names - expected_prior_param_names if unexpected_prior_param_names: raise ValueError( f"Unexpected priors supplied for parameters: {unexpected_prior_param_names}. " f"Priors expected only for parameters: {expected_prior_param_names}" ) # Check parameter values work with priors self._check_params_values_against_priors(validated_priors, current_parameterisation_free_param_names) # Update the priors with the new values self._priors.update(new_priors)
[docs] def _get_default_parameterisation_equivalent_free_param_name(self, free_param: str) -> Optional[list[str]]: """Get the names of the default parameterisation equivalent parameter(s), for a single free parameter from the current parameterisation. Note this can be more than one: e.g. if you have secosw, this affects both e & w in the default parameterisation Whereas Tc just maps to Tp alone. Returns ------- list[str] | None - list[str]: equivalent parameter(s) in the default parameterisation - None: no mapping needed / no alternative priors to look for Raises ------ ValueError If `free_param` is not a recognised planet, instrument, or trend parameter. """ # No underscore (expected to be a system trend parameter) if '_' not in free_param: if free_param in ['gd', 'gdd']: # These are the same in all parameterisations return None else: raise ValueError(f"Unknown free parameter: {free_param}") # Contains underscore: Planetary or instrument parameters (with underscore before either planet letter or instrument name) # e.g. 
P_b, or Tc_c, or jit_HARPS else: base_param, suffix = free_param.split('_', 1) # split only on first underscore (some instrument names may have underscores too) # Planetary parameters: suffix is a planet letter if suffix in self.planet_letters: planet_letter = suffix if base_param in ['secosw', 'sesinw']: # Both secosw and sesinw map to e,w equivalents partner_param = 'sesinw' if base_param == 'secosw' else 'secosw' partner_key = f"{partner_param}_{planet_letter}" if partner_key in self.free_params_names: return [f"e_{planet_letter}", f"w_{planet_letter}"] elif base_param in ['ecosw', 'esinw']: # Both ecosw and esinw map to e,w equivalents partner_param = 'esinw' if base_param == 'ecosw' else 'ecosw' partner_key = f"{partner_param}_{planet_letter}" if partner_key in self.free_params_names: return [f"e_{planet_letter}", f"w_{planet_letter}"] elif base_param == 'Tc': # Tc can use Tp equivalent return [f"Tp_{planet_letter}"] elif base_param in ['P', 'K', 'e', 'w', 'Tp']: # These are default parameterisation parameters anyway return None else: # Suffix is a valid planet letter, but base parameter is unrecognised, so raise an error raise ValueError(f"Free parameter {free_param} has known planet letter {planet_letter} but unrecognised base parameter {base_param}.") # Instrument parameters: suffix is an instrument name elif suffix in self.unique_instruments: # The only instrument parameters are g and jit if base_param in ['g', 'jit']: # Per-instrument parameter (e.g., g_HARPS, jit_HIRES) # These are the same in all parameterisations return None else: raise ValueError(f"Free parameter {free_param} has known instrument name {suffix} but unrecognised base parameter {base_param} (expected 'g' or 'jit' only)") # Unknown: Suffix is present, but not a planet letter or instrument, so raise an error else: raise ValueError(f"Free parameter {free_param} has unrecognised suffix {suffix}, expected one of planet letters {self.planet_letters} or instrument names 
{self.unique_instruments}.")
[docs] def _check_params_values_against_priors(self, validated_priors: dict[str, Callable[[float], float]], current_free_param_names: list[str]) -> None: """Check parameter values against priors (including if Prior is for the Default parameterisation equivalent parameter).""" for prior_param_name, prior_function in validated_priors.items(): if prior_param_name in current_free_param_names: # This prior is in current parameterisation - check directly param_value = self.params[prior_param_name].value log_prior_probability = prior_function(param_value) if not np.isfinite(log_prior_probability): raise ValueError(f"Initial value {param_value} of parameter {prior_param_name} is invalid for prior {prior_function}.") else: # This prior is in default parameterisation - need to convert parameter value # Get the current parameter value and convert to default default_param_value = self._convert_single_param_to_default(prior_param_name) log_prior_probability = prior_function(default_param_value) if not np.isfinite(log_prior_probability): raise ValueError(f"Initial value {default_param_value} of parameter {prior_param_name} (in default parameterisation) is invalid for prior {prior_function}.")
[docs] def _convert_single_param_to_default(self, default_param_name: str) -> float: """Convert a single parameter from current to default parameterisation.""" # Extract planet letter if this is a planetary parameter if '_' in default_param_name: base_param, planet_letter = default_param_name.rsplit('_', 1) if planet_letter in self.planet_letters: # Get all current parameters for this planet (we need all five parameters to do a conversion) planet_params_dict = {} for par_name in self.parameterisation.pars: param_key = f"{par_name}_{planet_letter}" planet_params_dict[par_name] = self.params[param_key].value # Convert all the planetary parameters to the default parameterisation default_planet_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params_dict) # Return just the requested parameter in the default parameterisation return default_planet_params[base_param] # For non-planetary parameters (g, gd, gdd, jit), they're the same in all parameterisations if default_param_name in self.params: return self.params[default_param_name].value raise ValueError(f"Cannot convert parameter {default_param_name} to default parameterisation")
[docs] def _set_hyperpriors_with_validation(self, new_hyperpriors: dict[str, Callable[[float], float]]) -> None: """Set hyperpriors with validation.""" # Create merged hyperpriors dict merged_hyperpriors_dict = dict(self._hyperpriors) merged_hyperpriors_dict.update(new_hyperpriors) provided_hyperprior_param_names = set(merged_hyperpriors_dict.keys()) # Check that hyperpriors are provided for all free hyperparameters free_hyperparam_names = set(self.free_hyperparams_names) missing_hyperpriors = free_hyperparam_names - provided_hyperprior_param_names if missing_hyperpriors: raise ValueError(f"Missing hyperpriors for hyperparameters: {missing_hyperpriors}") # Check for unexpected hyperpriors unexpected_hyperprior_param_names = provided_hyperprior_param_names - free_hyperparam_names if unexpected_hyperprior_param_names: raise ValueError( f"Unexpected hyperpriors supplied for hyperparameters: {unexpected_hyperprior_param_names}. " f"Hyperpriors expected only for hyperparameters: {free_hyperparam_names}" ) # Check hyperparameter values work with hyperpriors self._check_hyperparams_values_against_hyperpriors(merged_hyperpriors_dict, free_hyperparam_names) # Update the hyperpriors with the new values self._hyperpriors.update(new_hyperpriors)
[docs] def _check_hyperparams_values_against_hyperpriors(self, validated_hyperpriors: dict[str, Callable[[float], float]], current_free_hyperparam_names: list[str]) -> None: """Check hyperparameter values against hyperpriors.""" for hyperprior_param_name, hyperprior_function in validated_hyperpriors.items(): hyperparam_value = self.hyperparams[hyperprior_param_name].value log_hyperprior_probability = hyperprior_function(hyperparam_value) if not np.isfinite(log_hyperprior_probability): raise ValueError(f"Initial value {hyperparam_value} of hyperparameter {hyperprior_param_name} is invalid for hyperprior {hyperprior_function}.")
@property def free_params_dict(self) -> Dict[str, Parameter]: """Free parameters as dict.""" free_pars = {} for par in self.params: if self.params[par].fixed is False: free_pars[par] = self.params[par] return free_pars @property def free_params_values(self) -> list[float]: """Values of free parameters as list.""" return [param.value for param in self.free_params_dict.values()] @property def free_params_names(self) -> list[str]: """Names of free parameters as list.""" return list(self.free_params_dict.keys()) @property def fixed_params_dict(self) -> Dict[str, Parameter]: """Fixed parameters as dict, mapping names to Parameter objects.""" fixed_pars = {} for par in self.params: if self.params[par].fixed is True: fixed_pars[par] = self.params[par] return fixed_pars @property def fixed_params_values(self) -> list[float]: """Values of fixed parameters, as list.""" return [param.value for param in self.fixed_params_dict.values()] @property def fixed_params_names(self) -> list[str]: """Names of fixed parameters, as list.""" return list(self.fixed_params_dict.keys()) @property def fixed_params_values_dict(self) -> Dict[str, float]: """Fixed parameters as dict mapping names to just the values.""" return dict(zip(self.fixed_params_names, self.fixed_params_values)) @property def free_hyperparams_dict(self) -> Dict[str, Parameter]: """Free hyperparameters as dict.""" free_hyperpars = {} for hyperpar in self.hyperparams: if self.hyperparams[hyperpar].fixed is False: free_hyperpars[hyperpar] = self.hyperparams[hyperpar] return free_hyperpars @property def free_hyperparams_values(self) -> list[float]: """Values of free hyperparameters as list.""" return [hyperparam.value for hyperparam in self.free_hyperparams_dict.values()] @property def free_hyperparams_names(self) -> list[str]: """Names of free hyperparameters as list.""" return list(self.free_hyperparams_dict.keys()) @property def fixed_hyperparams_dict(self) -> Dict[str, Parameter]: """Fixed hyperparameters as dict, mapping 
names to Parameter objects.""" fixed_hyperpars = {} for hyperpar in self.hyperparams: if self.hyperparams[hyperpar].fixed is True: fixed_hyperpars[hyperpar] = self.hyperparams[hyperpar] return fixed_hyperpars @property def fixed_hyperparams_values_dict(self) -> Dict[str, float]: """Fixed hyperparameters as dict mapping names to just the values.""" return dict(zip(self.fixed_hyperparams_names, self.fixed_hyperparams_values)) @property def fixed_hyperparams_names(self) -> list[str]: """Names of fixed hyperparameters, as list.""" return list(self.fixed_hyperparams_dict.keys()) @property def fixed_hyperparams_values(self) -> list[float]: """Values of fixed hyperparameters, as list.""" return [hyperparam.value for hyperparam in self.fixed_hyperparams_dict.values()]
[docs] def find_map_estimate(self, method: str = "Powell") -> scipy.optimize.OptimizeResult: """Find Maximum A Posteriori (MAP) estimate of parameters and hyperparameters. Parameters ---------- method : str, optional Optimization method to use (default: "Powell") Returns ------- scipy.optimize.OptimizeResult The optimization result containing the MAP estimate Raises ------ Warning If MAP optimization fails to converge """ # Initialize log-posterior object gp_lp = GPLogPosterior( self.planet_letters, self.parameterisation, self.gp_kernel, self.priors, self.hyperpriors, self.fixed_params_values_dict, self.fixed_hyperparams_values_dict, self.free_params_names, self.free_hyperparams_names, self.time, self.vel, self.velerr, self.t0, self.instrument, self.unique_instruments, ) # Combine free params and free hyperparams for initial guess initial_guess = self.free_params_values + self.free_hyperparams_values if len(initial_guess) == 0: raise ValueError( "Cannot run MAP optimisation: no free parameters or hyperparameters to optimise. " "At least one parameter or hyperparameter must be set as free (fixed=False) before calling find_map_estimate()." ) # Perform MAP optimization def negative_log_posterior(*args: float) -> float: return gp_lp._negative_log_probability_for_MAP(*args) map_results = minimize(negative_log_posterior, initial_guess, method=method) if map_results.success is False: print(map_results) warnings.warn("MAP did not succeed. 
Check the initial values of the parameters and hyperparameters, and the prior/hyperprior functions.") # Split results back into params and hyperparams n_params = len(self.free_params_names) param_values = map_results.x[:n_params] hyperparam_values = map_results.x[n_params:] # Print results as dictionary (to show params/hyperparams names too) map_results_dict = dict(zip(self.free_params_names, param_values)) map_hyperresults_dict = dict(zip(self.free_hyperparams_names, hyperparam_values)) print("MAP parameter results:", map_results_dict) print("MAP hyperparameter results:", map_hyperresults_dict) # Return the scipy OptimizeResult object so that user can inspect fully if needed return map_results
[docs] def generate_initial_walker_positions_random(self, nwalkers: int, verbose: bool = False, max_attempts: int = 1000) -> np.ndarray: """Generate random initial walker positions that satisfy priors and are astrophysically valid. Creates random starting positions for MCMC walkers by sampling from appropriate distributions based on each parameter's prior type. Ensures that parameter combinations are astrophysically valid (e.g., eccentricity < 1). Parameters ---------- nwalkers : int Number of MCMC walkers to generate positions for verbose : bool, default False If True, print walker positions during generation max_attempts : int, default 1000 Maximum attempts to generate a valid walker position Returns ------- np.ndarray Array of shape (nwalkers, ndim) where ndim is the number of free parameters + hyperparameters. Each row represents the starting position for one walker in the order of free_params_names + free_hyperparams_names. Raises ------ ValueError If a prior type is not supported for walker generation or if unable to generate valid positions after max_attempts Examples -------- >>> # Generate positions for 40 walkers >>> nwalkers = 10 * len(gpfitter.free_params_names + gpfitter.free_hyperparams_names) >>> initial_positions = gpfitter.generate_initial_walker_positions_random(nwalkers) >>> gpfitter.run_mcmc(initial_positions, nwalkers, max_steps=2000) """ if len(self.free_params_values) + len(self.free_hyperparams_values) == 0: raise ValueError( "Cannot generate walker positions: no free parameters or hyperparameters to sample. " "At least one parameter or hyperparameter must be set as free (fixed=False)." 
) if verbose: print("Free parameters:", self.free_params_names) print("Free hyperparameters:", self.free_hyperparams_names) param_init = [] hyperparam_init = [] for walker_idx in range(nwalkers): attempts = 0 while attempts < max_attempts: param_walker_position = [] hyperparam_walker_position = [] # Generate parameter positions for param_name in self.free_params_names: # Check if we have a direct prior for this parameter # (because user may be fitting in a transformed parameterisation, but gave priors in the default parameterisation instead) if param_name in self.priors: prior = self.priors[param_name] if isinstance(prior, ravest.prior.Normal): param_walker_position.append(np.random.normal(loc=prior.mean, scale=2*prior.std)) elif isinstance(prior, ravest.prior.HalfNormal): param_walker_position.append(np.abs(np.random.normal(loc=0, scale=2*prior.std))) elif isinstance(prior, ravest.prior.Uniform): param_walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper)) elif isinstance(prior, ravest.prior.TruncatedNormal): param_walker_position.append(np.random.uniform(low=prior.lower, high=prior.upper)) elif isinstance(prior, ravest.prior.Beta): param_walker_position.append(np.random.uniform(low=0, high=1)) elif isinstance(prior, ravest.prior.EccentricityUniform): param_walker_position.append(np.random.uniform(low=0, high=prior.upper)) else: raise ValueError(f"Unsupported prior type for walker generation: {type(prior)}") else: # No direct prior for this parameter (this happens if fitting in transformed parameterisation, but prior given in default) # use current value + small perturbation centre_val = self.params[param_name].value # Add small random perturbation (10% of current value + small fixed amount for near-zero values) perturbation = np.random.normal(0, abs(centre_val) * 0.1 + 0.01) param_walker_position.append(centre_val + perturbation) # Generate hyperparameter positions for hyperparam_name in self.free_hyperparams_names: if hyperparam_name in 
self.hyperpriors: hyperprior = self.hyperpriors[hyperparam_name] if isinstance(hyperprior, ravest.prior.Normal): hyperparam_walker_position.append(np.random.normal(loc=hyperprior.mean, scale=hyperprior.std)) elif isinstance(hyperprior, ravest.prior.HalfNormal): hyperparam_walker_position.append(np.abs(np.random.normal(loc=0, scale=hyperprior.std))) elif isinstance(hyperprior, ravest.prior.Uniform): hyperparam_walker_position.append(np.random.uniform(low=hyperprior.lower, high=hyperprior.upper)) elif isinstance(hyperprior, ravest.prior.TruncatedNormal): hyperparam_walker_position.append(np.random.uniform(low=hyperprior.lower, high=hyperprior.upper)) elif isinstance(hyperprior, ravest.prior.Beta): hyperparam_walker_position.append(np.random.uniform(low=hyperprior.a, high=hyperprior.b)) elif isinstance(hyperprior, ravest.prior.EccentricityUniform): hyperparam_walker_position.append(np.random.uniform(low=0, high=hyperprior.upper)) else: raise ValueError(f"Unsupported hyperprior type for walker generation: {type(hyperprior)}") # Check astrophysical validity and prior compliance try: # Convert walker position to full parameter dict (free + fixed) free_params_dict = dict(zip(self.free_params_names, param_walker_position)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity self._validate_astrophysical_validity(all_params_dict) # Convert hyperparameter position to full hyperparameter dict (free + fixed) free_hyperparams_dict = dict(zip(self.free_hyperparams_names, hyperparam_walker_position)) all_hyperparams_dict = self.fixed_hyperparams_values_dict | free_hyperparams_dict # Check hyperparameter validity using GPKernel (internal method for float values) self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict) # Check prior compliance using GPLogPosterior lp = GPLogPosterior( self.planet_letters, self.parameterisation, self.gp_kernel, self.priors, self.hyperpriors, self.fixed_params_values_dict, 
self.fixed_hyperparams_values_dict, self.free_params_names, self.free_hyperparams_names, self.time, self.vel, self.velerr, self.t0, self.instrument, self.unique_instruments, ) # Check the log-prior probability is finite (i.e. proposed initial values are within prior bounds) params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) log_hyperprior = lp.log_hyperprior(free_hyperparams_dict) if not np.isfinite(log_prior): raise ValueError(f"Outside prior bounds (log_prior = {log_prior})") if not np.isfinite(log_hyperprior): raise ValueError(f"Outside hyperprior bounds (log_hyperprior = {log_hyperprior})") # If all validations pass, we have a valid walker position break except ValueError: # Validation failed, try again attempts += 1 continue if attempts >= max_attempts: raise ValueError(f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. " f"Consider relaxing priors/hyperpriors or checking parameter constraints.") if verbose: print(f"Walker {walker_idx} param position: {param_walker_position}") print(f"Walker {walker_idx} hyperparam position: {hyperparam_walker_position}") print(f"(valid after {attempts + 1} attempts)") param_init.append(param_walker_position) hyperparam_init.append(hyperparam_walker_position) param_init = np.array(param_init) hyperparam_init = np.array(hyperparam_init) # Combine parameter and hyperparameter positions into single array if hyperparam_init.size > 0: initial_positions = np.concatenate([param_init, hyperparam_init], axis=1) else: initial_positions = param_init if verbose: print(f"Generated MCMC initial param positions with shape: {param_init.shape}") print(f"Generated MCMC initial hyperparam positions with shape: {hyperparam_init.shape}") print(f"Combined initial positions shape: {initial_positions.shape}") return initial_positions
[docs] def generate_initial_walker_positions_around_point( self, centre: np.ndarray | list, nwalkers: int, scale: float = 1e-4, relative: bool = True, verbose: bool = False, max_attempts: int = 1000 ) -> np.ndarray: """Generate initial walker positions in a ball around a supplied centre point. Creates starting positions for MCMC walkers clustered around a centre point (e.g., MAP estimate). Each walker is generated by adding small random perturbations to the centre values. Validates that both the centre point and all generated walker positions satisfy priors/hyperpriors and are astrophysically valid. Parameters ---------- centre : np.ndarray or list Centre point for walker positions. Must have length equal to the number of free parameters + free hyperparameters and be in the order of free_params_names + free_hyperparams_names. nwalkers : int Number of MCMC walkers to generate positions for scale : float, default 1e-4 Scale of perturbations around centre point relative : bool, default True If True, perturbations scale with parameter values (scale * centre * random). If False, perturbations are absolute (scale * random). verbose : bool, default False If True, print walker positions during generation max_attempts : int, default 1000 Maximum attempts to generate a valid walker position Returns ------- np.ndarray Array of shape (nwalkers, ndim) where ndim is the number of free parameters + free hyperparameters. Each row represents the starting position for one walker in the order of free_params_names + free_hyperparams_names. Raises ------ ValueError If centre has wrong length, if centre point is invalid, or if unable to generate valid positions after max_attempts Examples -------- >>> # Generate walkers around MAP estimate >>> map_result = gpfitter.find_map_estimate() >>> initial_positions = gpfitter.generate_initial_walker_positions_around_point( ... centre=map_result.x, nwalkers=40, scale=1e-4 ... 
) >>> gpfitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000) """ if len(self.free_params_values) + len(self.free_hyperparams_values) == 0: raise ValueError( "Cannot generate walker positions: no free parameters or hyperparameters to sample. " "At least one parameter or hyperparameter must be set as free (fixed=False)." ) centre = np.asarray(centre) expected_length = len(self.free_params_names) + len(self.free_hyperparams_names) if len(centre) != expected_length: raise ValueError( f"Centre must have length {expected_length} " f"({len(self.free_params_names)} free params + {len(self.free_hyperparams_names)} free hyperparams), " f"got {len(centre)}" ) if verbose: print("Free parameters:", self.free_params_names) print("Free hyperparameters:", self.free_hyperparams_names) print(f"Centre values: {centre}") # Split centre into params and hyperparams n_params = len(self.free_params_names) centre_params = centre[:n_params] centre_hyperparams = centre[n_params:] # Validate centre point first try: free_params_dict = dict(zip(self.free_params_names, centre_params)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity self._validate_astrophysical_validity(all_params_dict) free_hyperparams_dict = dict(zip(self.free_hyperparams_names, centre_hyperparams)) all_hyperparams_dict = self.fixed_hyperparams_values_dict | free_hyperparams_dict # Check hyperparameter validity self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict) # Check prior/hyperprior compliance lp = GPLogPosterior( self.planet_letters, self.parameterisation, self.gp_kernel, self.priors, self.hyperpriors, self.fixed_params_values_dict, self.fixed_hyperparams_values_dict, self.free_params_names, self.free_hyperparams_names, self.time, self.vel, self.velerr, self.t0, self.instrument, self.unique_instruments, ) params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) log_hyperprior = 
lp.log_hyperprior(free_hyperparams_dict) if not np.isfinite(log_prior): raise ValueError(f"Centre point outside prior bounds (log_prior = {log_prior})") if not np.isfinite(log_hyperprior): raise ValueError(f"Centre point outside hyperprior bounds (log_hyperprior = {log_hyperprior})") if verbose: print(f"Centre point validated (log_prior = {log_prior}, log_hyperprior = {log_hyperprior})") except ValueError as e: raise ValueError(f"Supplied centre point is not valid: {e}") # Generate walker positions around centre param_init = [] hyperparam_init = [] if verbose and relative and np.any(centre == 0.0): all_names = self.free_params_names + self.free_hyperparams_names zero_names = [all_names[i] for i in range(len(centre)) if centre[i] == 0.0] print(f"Note: centre value is exactly 0.0 for {zero_names}; " f"using absolute perturbation (scale={scale}) for these parameters.") for walker_idx in range(nwalkers): attempts = 0 while attempts < max_attempts: # Generate perturbation random_vals = np.random.randn(len(centre)) if relative: # Relative perturbation: scales with parameter values. # When a centre value is exactly 0.0, the relative # perturbation (scale * randn * |0|) is always zero, # producing identical walker values in that dimension. # This causes emcee to reject the walkers as linearly # dependent (condition number check). Fall back to # absolute perturbation for those parameters. 
perturbation = np.empty(len(centre)) for i in range(len(centre)): if centre[i] == 0.0: perturbation[i] = scale * random_vals[i] else: perturbation[i] = scale * random_vals[i] * np.abs(centre[i]) else: # Absolute perturbation: same scale for all parameters perturbation = scale * random_vals walker_position = centre + perturbation walker_params = walker_position[:n_params] walker_hyperparams = walker_position[n_params:] # Validate this walker position try: free_params_dict = dict(zip(self.free_params_names, walker_params)) all_params_dict = self.fixed_params_values_dict | free_params_dict # Check astrophysical validity self._validate_astrophysical_validity(all_params_dict) free_hyperparams_dict = dict(zip(self.free_hyperparams_names, walker_hyperparams)) all_hyperparams_dict = self.fixed_hyperparams_values_dict | free_hyperparams_dict # Check hyperparameter validity self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict) # Check prior/hyperprior compliance params_for_prior = lp._convert_params_for_prior_evaluation(free_params_dict) log_prior = lp.log_prior(params_for_prior) log_hyperprior = lp.log_hyperprior(free_hyperparams_dict) if not np.isfinite(log_prior): raise ValueError(f"Outside prior bounds (log_prior = {log_prior})") if not np.isfinite(log_hyperprior): raise ValueError(f"Outside hyperprior bounds (log_hyperprior = {log_hyperprior})") # If validation passes, we have a valid walker position break except ValueError: # Validation failed, try again attempts += 1 continue if attempts >= max_attempts: raise ValueError( f"Could not generate astrophysically valid walker {walker_idx} after {max_attempts} attempts. " f"Consider using a larger scale parameter or checking that the centre point is not too close to prior/physical boundaries." 
) if verbose: print(f"Walker {walker_idx} param position: {walker_params}") print(f"Walker {walker_idx} hyperparam position: {walker_hyperparams}") print(f"(valid after {attempts + 1} attempts)") param_init.append(walker_params) hyperparam_init.append(walker_hyperparams) param_init = np.array(param_init) hyperparam_init = np.array(hyperparam_init) # Combine parameter and hyperparameter positions into single array if hyperparam_init.size > 0: initial_positions = np.concatenate([param_init, hyperparam_init], axis=1) else: initial_positions = param_init if verbose: print(f"Generated MCMC initial param positions with shape: {param_init.shape}") print(f"Generated MCMC initial hyperparam positions with shape: {hyperparam_init.shape}") print(f"Combined initial positions shape: {initial_positions.shape}") return initial_positions
def generate_initial_walker_positions_from_map(
    self,
    map_result: scipy.optimize.OptimizeResult,
    nwalkers: int,
    scale: float = 1e-4,
    relative: bool = True,
    verbose: bool = False,
    max_attempts: int = 1000,
) -> np.ndarray:
    """Scatter MCMC walkers in a small ball around a MAP solution.

    Thin convenience wrapper: takes the solution vector ``map_result.x`` from a
    previously computed MAP fit and delegates to
    ``generate_initial_walker_positions_around_point`` with it as the centre.

    Parameters
    ----------
    map_result : scipy.optimize.OptimizeResult
        Result from find_map_estimate(); only its ``x`` attribute is used.
    nwalkers : int
        Number of MCMC walkers to generate positions for
    scale : float, default 1e-4
        Scale of perturbations around MAP values
    relative : bool, default True
        If True, perturbations scale with parameter values.
        If False, perturbations are absolute.
    verbose : bool, default False
        If True, print walker positions during generation
    max_attempts : int, default 1000
        Maximum attempts to generate a valid walker position

    Returns
    -------
    np.ndarray
        Array of shape (nwalkers, ndim), ndim being the number of free
        parameters + free hyperparameters; one starting position per row.

    Raises
    ------
    ValueError
        If nothing is free to sample, or if valid positions cannot be generated.

    Examples
    --------
    >>> # Find MAP then generate walkers around it
    >>> map_result = gpfitter.find_map_estimate()
    >>> initial_positions = gpfitter.generate_initial_walker_positions_from_map(
    ...     map_result=map_result, nwalkers=40
    ... )
    >>> gpfitter.run_mcmc(initial_positions, nwalkers=40, max_steps=2000)
    """
    n_free_total = len(self.free_params_values) + len(self.free_hyperparams_values)
    if not n_free_total:
        raise ValueError(
            "Cannot generate walker positions: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    # The MAP solution vector is expected in free_params_names + free_hyperparams_names
    # order, which is exactly the ordering the centre argument requires.
    map_solution = map_result.x
    return self.generate_initial_walker_positions_around_point(
        centre=map_solution,
        nwalkers=nwalkers,
        scale=scale,
        relative=relative,
        verbose=verbose,
        max_attempts=max_attempts,
    )
def run_mcmc(
    self,
    initial_positions: np.ndarray,
    nwalkers: int,
    max_steps: int = 5000,
    progress: bool = True,
    multiprocessing: bool = False,
    check_convergence: bool = False,
    convergence_check_interval: int = 1000,
    convergence_check_start: int = 0,
) -> None:
    """Run MCMC sampling from given initial walker positions.

    The emcee sampler is stored on ``self.sampler``. When ``check_convergence``
    is True, the autocorrelation-time estimate from every check is stored in
    ``self.autocorr_history`` keyed by iteration number.

    Parameters
    ----------
    initial_positions : np.ndarray
        Starting positions for all MCMC walkers (any array-like is accepted and
        converted with ``np.asarray``). Shape must be (nwalkers, ndim) where ndim
        is the number of free parameters + hyperparameters. Each row represents
        the starting position for one walker in the order of
        free_params_names + free_hyperparams_names.
    nwalkers : int
        Number of MCMC walkers (must match first dimension of initial_positions
        and be at least 2 * ndim)
    max_steps : int, optional
        Maximum number of MCMC steps to run. If check_convergence=False, runs for
        exactly this many steps. If check_convergence=True, runs up to this many
        steps, stopping early when convergence criteria are met (default: 5000)
    progress : bool, optional
        Whether to show progress bar during MCMC (default: True)
    multiprocessing : bool, optional
        Whether to use multiprocessing for MCMC (default: False)
    check_convergence : bool, optional
        If True, check for convergence and stop early when criteria met.
        Convergence requires: chain length > 50 times max autocorrelation time,
        and autocorrelation time estimate stable to 1 percent.
        If False, run for exactly max_steps (default: False)
    convergence_check_interval : int, optional
        Steps between convergence checks; must be positive when
        check_convergence=True (default: 1000)
    convergence_check_start : int, optional
        Minimum iteration before starting convergence checks. Set this sensibly
        (e.g. 2x burn-in) to avoid inaccurate tau estimates on short chains
        (default: 0)

    Raises
    ------
    ValueError
        If nothing is free to sample, if nwalkers < 2 * ndim, if
        initial_positions does not have shape (nwalkers, ndim), if
        convergence_check_interval is not positive while check_convergence=True,
        or if any walker starts at an invalid position or outside the
        priors/hyperpriors.
    """
    if len(self.free_params_values) + len(self.free_hyperparams_values) == 0:
        raise ValueError(
            "Cannot run MCMC: no free parameters or hyperparameters to sample. "
            "At least one parameter or hyperparameter must be set as free (fixed=False)."
        )

    initial_positions = np.asarray(initial_positions)

    # The sampler is started from one row of initial_positions per walker, so an
    # undersized nwalkers cannot be silently bumped up here: the supplied positions
    # would no longer match the sampler's walker count and emcee would reject them.
    if nwalkers < 2 * self.ndim:
        raise ValueError(
            f"nwalkers should be at least 2 * ndim. You have {nwalkers} walkers and {self.ndim} dimensions. "
            f"Generate initial positions for at least {2 * self.ndim} walkers."
        )
    self.nwalkers = nwalkers

    # Validate walker positions shape
    if initial_positions.shape != (nwalkers, self.ndim):
        raise ValueError(f"initial_positions must have shape ({nwalkers}, {self.ndim}), got {initial_positions.shape}")

    # `sampler.iteration % convergence_check_interval` below needs a positive interval
    if check_convergence and convergence_check_interval <= 0:
        raise ValueError(
            f"convergence_check_interval must be a positive integer, got {convergence_check_interval}"
        )

    # Warn if convergence arguments provided but convergence checking disabled
    if not check_convergence and (convergence_check_interval != 1000 or convergence_check_start != 0):
        logging.warning(
            "Convergence checking arguments provided but check_convergence=False. "
            "These arguments will be ignored. Did you forget to set check_convergence=True?"
        )

    # Initialize log-posterior object for MCMC sampling
    gp_lp = GPLogPosterior(
        self.planet_letters,
        self.parameterisation,
        self.gp_kernel,
        self.priors,
        self.hyperpriors,
        self.fixed_params_values_dict,
        self.fixed_hyperparams_values_dict,
        self.free_params_names,
        self.free_hyperparams_names,
        self.time,
        self.vel,
        self.velerr,
        self.t0,
        self.instrument,
        self.unique_instruments,
    )

    # Validate every walker position for astrophysical validity and prior compliance
    # (we don't want to start any chains in invalid parameter space)
    n_params = len(self.free_params_names)
    for i, walker_position in enumerate(initial_positions):
        # Split walker position into parameters and hyperparameters
        param_position = walker_position[:n_params]
        hyperparam_position = walker_position[n_params:] if n_params < len(walker_position) else []

        walker_params_dict = dict(zip(self.free_params_names, param_position))
        walker_hyperparams_dict = dict(zip(self.free_hyperparams_names, hyperparam_position))

        all_params_dict = self.fixed_params_values_dict | walker_params_dict
        all_hyperparams_dict = self.fixed_hyperparams_values_dict | walker_hyperparams_dict

        # Check astrophysical validity
        try:
            self._validate_astrophysical_validity(all_params_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid astrophysical parameters: {e}") from e

        # Check hyperparameter validity using GPKernel (internal method for float values)
        try:
            self.gp_kernel._validate_hyperparams_values(all_hyperparams_dict)
        except ValueError as e:
            raise ValueError(f"Walker {i} has invalid hyperparameters: {e}") from e

        # Check prior compliance
        params_for_prior = gp_lp._convert_params_for_prior_evaluation(walker_params_dict)
        log_prior = gp_lp.log_prior(params_for_prior)
        if not np.isfinite(log_prior):
            raise ValueError(f"Walker {i} is outside prior bounds (log_prior = {log_prior})")

        # Check hyperprior compliance
        log_hyperprior = gp_lp.log_hyperprior(walker_hyperparams_dict)
        if not np.isfinite(log_hyperprior):
            raise ValueError(f"Walker {i} is outside hyperprior bounds (log_hyperprior = {log_hyperprior})")

    # Combine parameter names for sampler (needs it as one argument)
    all_param_names = self.free_params_names + self.free_hyperparams_names
    # TODO: parameter_names argument does slightly impact performance - but not sure if it can be avoided, we do need the names
    # and I'm not sure constructing the dictionary later ourselves manually is any quicker than passing parameter_names argument

    # Use 'spawn' instead of 'fork' to avoid issues on some Linux platforms
    pool = mp.get_context("spawn").Pool() if multiprocessing else None
    try:
        sampler = emcee.EnsembleSampler(
            self.nwalkers, self.ndim, gp_lp.log_probability, parameter_names=all_param_names, pool=pool
        )

        if not check_convergence:
            # Fixed-length mode - run for exactly max_steps
            logging.info(f"Starting MCMC for {max_steps} steps...")
            sampler.run_mcmc(initial_state=initial_positions, nsteps=max_steps, progress=progress)
            logging.info("...MCMC done.")
        else:
            # Convergence checking - run up to max_steps, stopping early if converged
            logging.info(f"Starting MCMC with convergence checks. (Maximum {max_steps} steps, checking convergence every {convergence_check_interval} steps after iteration {convergence_check_start})...")

            # Initialize autocorrelation history storage
            self.autocorr_history = {}
            old_tau = np.inf

            for _ in sampler.sample(initial_state=initial_positions, iterations=max_steps, progress=progress):
                # Only check at specified intervals
                if sampler.iteration % convergence_check_interval != 0:
                    continue

                # Don't check before we have reached convergence_check_start
                if sampler.iteration < convergence_check_start:
                    continue

                # Get autocorrelation time estimate (tol=0: always return an estimate)
                tau = sampler.get_autocorr_time(tol=0)

                # Store autocorrelation history for plotting/diagnostics later
                self.autocorr_history[sampler.iteration] = tau.copy()

                logging.info(f"Convergence check: Step {sampler.iteration}: mean(tau)={np.mean(tau):.1f}, max(tau)={np.max(tau):.1f}")

                # Check convergence criteria
                check_chain_length = np.all(sampler.iteration > 50 * tau)  # Chain length > 50 * tau
                check_stable_tau = np.all(np.abs(old_tau - tau) / tau < 0.01)  # Tau stable to 1 percent

                if check_chain_length and check_stable_tau:
                    logging.info(f"Converged at iteration {sampler.iteration}")
                    break

                logging.info(f"Not yet converged (N/50>tau check: {check_chain_length}, tau stability check: {check_stable_tau})")
                # Warn if approaching max steps without convergence
                if sampler.iteration > 0.8 * max_steps:
                    logging.warning(f"Approaching max iterations ({max_steps}) without convergence! (max tau={np.max(tau):.1f}, tau stability change={np.abs(old_tau - tau) / tau})")

                # Update old tau for next check
                old_tau = tau

            # Final log
            final_steps = sampler.iteration
            logging.info(f"MCMC complete: {final_steps} steps total")
    finally:
        # Always release the worker processes, even if sampling raised or was interrupted
        if pool is not None:
            pool.close()
            pool.join()

    self.sampler = sampler
def get_samples_np(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Return a contiguous numpy array of MCMC samples.

    Samples can be discarded from the start and/or the end of the array. You can
    also thin (take only every n-th sample), and you can flatten the array so
    that each walker's chain is merged into one chain.

    This is the foundational method for accessing MCMC samples. All the other
    get_samples methods build on this.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain
        (must be >= 0, default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain
        (must be >= 0, default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (must be >= 1, default: 1)
    flat : bool, optional
        Whether to flatten each walker's chain into one chain. (default: False)
        If True, return flattened array with shape (nsteps_after_discard_thin * nwalkers, ndim)
        If False, return unflattened array with shape (nsteps_after_discard_thin, nwalkers, ndim)

    Returns
    -------
    np.ndarray
        Contiguous array of MCMC samples. Shape depends on `flat` parameter:
        - flat=False: (nsteps_after_discard_thin, nwalkers, ndim)
        - flat=True: (nsteps_after_discard_thin * nwalkers, ndim)

    Raises
    ------
    ValueError
        If discard_start/discard_end are negative, thin is < 1, or the requested
        discards/thinning leave no samples.

    Notes
    -----
    We enforce np.ascontiguousarray() on the return, because np.reshape() does
    not guarantee a contiguous array in memory.
    """
    # Negative discards or a non-positive thin would silently produce a wrong slice
    # (e.g. discard_start=-1 selects only the last step), so reject them.
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start and discard_end must be non-negative, "
            f"got discard_start={discard_start}, discard_end={discard_end}"
        )
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}")

    # Get the full chain from emcee without any processing
    full_samples = self.sampler.get_chain(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end.
    start_idx = discard_start + thin - 1
    end_idx = full_samples.shape[0] - discard_end  # discard_end == 0 keeps the full chain

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    samples = full_samples[start_idx:end_idx:thin]

    # Flatten if requested (after discarding) - flatten steps and walkers into single dimension
    if flat:
        # (steps, walkers, ndim) -> (steps*walkers, ndim)
        nsteps, nwalkers, ndim = samples.shape
        samples = samples.reshape(nsteps * nwalkers, ndim)

    return np.ascontiguousarray(samples)
def get_samples_df(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> pd.DataFrame:
    """Return the flattened MCMC samples as a pandas DataFrame.

    One row per sample (all walkers merged), one column per free parameter or
    free hyperparameter. Delegates the discarding/thinning to get_samples_np().

    Parameters
    ----------
    discard_start : int, optional
        Number of steps dropped from the beginning of each chain (default: 0)
    discard_end : int, optional
        Number of steps dropped from the end of each chain (default: 0)
    thin : int, optional
        Keep only every `thin`-th step (default: 1)

    Returns
    -------
    pd.DataFrame
        Frame of shape (nsteps_after_discard_thin * nwalkers, ndim) whose columns
        are free_params_names followed by free_hyperparams_names.
    """
    sample_matrix = self.get_samples_np(
        discard_start=discard_start,
        discard_end=discard_end,
        thin=thin,
        flat=True,
    )
    column_labels = self.free_params_names + self.free_hyperparams_names
    return pd.DataFrame(sample_matrix, columns=column_labels)
def get_samples_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, np.ndarray]:
    """Return a dict of flattened MCMC samples.

    Each parameter and hyperparameter gets a 1D (flattened) C-contiguous array of
    all its samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    dict
        Dictionary mapping parameter and hyperparameter names to 1D arrays of
        samples. Each array has shape (nsteps_after_discard_thin * nwalkers,)

    Examples
    --------
    >>> samples_dict = gpfitter.get_samples_dict(discard_start=1000)
    >>> K_b_samples = samples_dict['K_b']  # All samples for parameter K for planet b
    >>> gp_amp_samples = samples_dict['gp_amp']  # All samples for GP amplitude
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    param_names = self.free_params_names + self.free_hyperparams_names

    # Column slices flat_samples[:, i] of a C-ordered (nsamples, ndim) array are
    # strided views, not contiguous. Transpose-and-copy once so that each row of
    # `columns` (one per parameter) is a contiguous 1D array.
    columns = np.ascontiguousarray(flat_samples.T)
    return {name: columns[i] for i, name in enumerate(param_names)}
def get_sampler_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, flat: bool = False) -> np.ndarray:
    """Returns the log probability at each step of the sampler.

    Uses the same discard/thin slicing as get_samples_np(), so the returned
    values line up one-to-one with the samples that method returns for the
    same arguments.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain
        (must be >= 0, default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain
        (must be >= 0, default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (must be >= 1, default: 1)
    flat : bool, optional
        If True, return flattened array shape (nsteps_after_discard_thin * nwalkers)
        If False, return unflattened array shape (nsteps_after_discard_thin, nwalkers)
        (default: False)

    Returns
    -------
    np.ndarray
        Contiguous array of log probabilities of the function at each sample.

    Raises
    ------
    ValueError
        If discard_start/discard_end are negative, thin is < 1, or the requested
        discards/thinning leave no samples.
    """
    # Negative discards or a non-positive thin would silently produce a wrong slice
    if discard_start < 0 or discard_end < 0:
        raise ValueError(
            f"discard_start and discard_end must be non-negative, "
            f"got discard_start={discard_start}, discard_end={discard_end}"
        )
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}")

    # Get the full log prob chain from emcee without any processing
    full_lnprob = self.sampler.get_log_prob(discard=0, thin=1, flat=False)

    # Match emcee's slicing logic: [discard + thin - 1 : end : thin]
    # But adapted - we also allow for discarding from the end.
    start_idx = discard_start + thin - 1
    end_idx = full_lnprob.shape[0] - discard_end  # discard_end == 0 keeps the full chain

    # Check the start and end points are valid
    if start_idx >= end_idx:
        raise ValueError(f"Invalid parameters: start_idx ({start_idx}) >= end_idx ({end_idx}). "
                         f"Try reducing discard_start ({discard_start}), discard_end ({discard_end}), or thin ({thin}).")

    lnprob = full_lnprob[start_idx:end_idx:thin]

    # Flatten if requested (after discarding) - flatten steps and walkers into single dimension
    if flat:
        # (steps, walkers) -> (steps*walkers,)
        nsteps, nwalkers = lnprob.shape
        lnprob = lnprob.reshape(nsteps * nwalkers)

    return np.ascontiguousarray(lnprob)
def get_mcmc_posterior_dict(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> dict:
    """Merge fixed values and free-parameter MCMC samples into one dict.

    Fixed parameters and fixed hyperparameters appear as their single values;
    free parameters and free hyperparameters appear as 1D sample arrays. This is
    convenient for derived-quantity functions (e.g. calculate_mpsini) that need
    every parameter and should propagate the posterior uncertainty of the free
    ones.

    Parameters
    ----------
    discard_start : int, optional
        Number of steps dropped from the beginning of each chain (default: 0)
    discard_end : int, optional
        Number of steps dropped from the end of each chain (default: 0)
    thin : int, optional
        Keep only every `thin`-th step (default: 1)

    Returns
    -------
    dict
        Fixed parameters, then fixed hyperparameters, then free
        parameters/hyperparameters (each free entry a 1D array of shape
        (nsteps_after_discard_thin * nwalkers,)). On a name clash, later
        groups override earlier ones.
    """
    merged = dict(self.fixed_params_values_dict)
    merged.update(self.fixed_hyperparams_values_dict)
    merged.update(self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin))
    return merged
def calculate_log_likelihood(self, params_hyperparams_dict: Dict[str, float]) -> float:
    """Evaluate the GP log-likelihood at the given parameter/hyperparameter values.

    Only the likelihood is returned -- no (log-)prior or hyperprior terms are
    added -- which is what the AICc and BIC calculations need.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Values for every parameter and hyperparameter, free and fixed alike.
        Keys not belonging to any known parameter/hyperparameter are ignored;
        a missing required key raises ``KeyError``.

    Returns
    -------
    float
        The log-likelihood value
    """
    likelihood = GPLogLikelihood(
        time=self.time,
        vel=self.vel,
        velerr=self.velerr,
        t0=self.t0,
        instrument=self.instrument,
        unique_instruments=self.unique_instruments,
        planet_letters=self.planet_letters,
        parameterisation=self.parameterisation,
        gp_kernel=self.gp_kernel,
    )

    def _pick(names):
        # Pull only the requested names out of the combined dict
        return {name: params_hyperparams_dict[name] for name in names}

    params = _pick(self.free_params_names + self.fixed_params_names)
    hyperparams = _pick(self.free_hyperparams_names + self.fixed_hyperparams_names)
    return likelihood(params=params, hyperparams=hyperparams)
[docs] def build_params_dict(self, free_params_hyperparams: np.ndarray | list | Dict[str, float]) -> Dict[str, float]: """Build complete parameter dictionary by combining free and fixed parameters and hyperparameters. Takes free parameter and hyperparameter values from various sources (MAP results, MCMC samples, or custom values) and combines them with the fixed parameter and hyperparameter values to create a complete dictionary suitable for calculating log-likelihood, chi2, AICc, and BIC. Parameters ---------- free_params_hyperparams : list, np.ndarray, or dict Free parameter and hyperparameter values from any source: - list/array: values in order of self.free_params_names + self.free_hyperparams_names - dict: mapping of free param/hyperparam names to values Returns ------- Dict[str, float] Complete parameters and hyperparameters dict with both free and fixed values Examples -------- >>> # From MAP optimization result >>> map_result = gpfitter.find_map_estimate() >>> params = gpfitter.build_params_dict(map_result.x) >>> aicc = gpfitter.calculate_aicc(params) >>> >>> # From best MCMC sample >>> best_sample = gpfitter.get_sample_with_best_lnprob(discard_start=1000) >>> params = gpfitter.build_params_dict(best_sample) >>> bic = gpfitter.calculate_bic(params) >>> >>> # From custom array (params then hyperparams, in order of names) >>> custom_values = [5.0, 50.0, 0.1, 0.0, 2450000.0, # params ... 
10.0, 5.0, 0.5, 30.0] # hyperparams >>> params = gpfitter.build_params_dict(custom_values) >>> log_like = gpfitter.calculate_log_likelihood(params) """ if isinstance(free_params_hyperparams, dict): # Validate that all expected free parameters and hyperparameters are present expected_params = set(self.free_params_names) expected_hyperparams = set(self.free_hyperparams_names) expected_names = expected_params | expected_hyperparams provided_names = set(free_params_hyperparams.keys()) missing = expected_names - provided_names if missing: raise ValueError(f"Missing required free parameters/hyperparameters: {missing}") extra = provided_names - expected_names if extra: raise ValueError(f"Unexpected parameters/hyperparameters provided: {extra}") return self.fixed_params_values_dict | self.fixed_hyperparams_values_dict | free_params_hyperparams else: # Validate that array/list has correct length expected_length = len(self.free_params_names) + len(self.free_hyperparams_names) if len(free_params_hyperparams) != expected_length: raise ValueError( f"Expected {expected_length} free parameter and hyperparameter values " f"but got {len(free_params_hyperparams)} " f"(expecting values for {self.free_params_names} + {self.free_hyperparams_names})" ) all_free_names = self.free_params_names + self.free_hyperparams_names free_dict = dict(zip(all_free_names, free_params_hyperparams)) return self.fixed_params_values_dict | self.fixed_hyperparams_values_dict | free_dict
@staticmethod
@jax.jit
def _compute_gp_chi2(
    kernel: kernels.Kernel,
    time_array: jnp.ndarray,
    verr_squared_array: jnp.ndarray,
    residuals: jnp.ndarray,
) -> float:
    """JIT-compiled GP chi-squared: r^T @ K^(-1) @ r.

    K is the full covariance matrix, i.e. the kernel evaluated at the
    observation times plus a diagonal of observational variances.

    Parameters
    ----------
    kernel : tinygp.kernels.Kernel
        GP kernel
    time_array : jnp.ndarray
        Observation times
    verr_squared_array : jnp.ndarray
        Observational variances (error^2 + jitter^2), placed on the diagonal of K
    residuals : jnp.ndarray
        Data - mean_model

    Returns
    -------
    float
        Chi-squared value incorporating the full covariance structure (returned
        as a 0-d JAX array; callers convert with float()).

    Notes
    -----
    Rather than forming and inverting K directly (mathematically equivalent but
    much slower)::

        K = kernel(time_array, time_array) + jnp.diag(verr_squared_array)
        chi2 = residuals @ jnp.linalg.inv(K) @ residuals

    we rely on the GP's Cholesky factorisation K = L @ L^T. Then
    alpha = L^(-1) @ residuals (via the private tinygp helper _get_alpha), and
    chi^2 = alpha^T @ alpha. No mean is passed to GaussianProcess here, so the
    residuals are presumably whitened directly (assumes tinygp's default zero
    mean -- confirm if the tinygp version changes).
    """
    process = GaussianProcess(kernel=kernel, X=time_array, diag=verr_squared_array)
    # Whitened residuals: L^(-1) @ residuals
    whitened = process._get_alpha(residuals)
    return jnp.dot(whitened, whitened)
def calculate_chi2(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Calculate chi-squared for given parameter and hyperparameter values.

    For GP models, this calculates:

    .. math::

        \chi^2 = \mathbf{r}^T \mathbf{K}^{-1} \mathbf{r}

    where :math:`\mathbf{r}` is the residual vector (data minus mean model) and
    :math:`\mathbf{K}` is the full covariance matrix: the GP kernel plus a
    diagonal of per-point observational variances
    :math:`\sigma_i^2 + \mathrm{jit}_{\mathrm{inst}(i)}^2`. This properly accounts
    for correlated noise structure.

    Uses GPLogLikelihood._calculate_mean_model to avoid code duplication.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and
        free). Must contain a ``jit_<instrument>`` entry for every instrument in
        ``self.unique_instruments``.

    Returns
    -------
    float
        Chi-squared value
    """
    # Create GPLogLikelihood instance to reuse mean model calculation
    gp_ll = GPLogLikelihood(
        time=self.time,
        vel=self.vel,
        velerr=self.velerr,
        t0=self.t0,
        instrument=self.instrument,
        unique_instruments=self.unique_instruments,
        planet_letters=self.planet_letters,
        parameterisation=self.parameterisation,
        gp_kernel=self.gp_kernel,
    )

    # Extract parameter and hyperparameter dicts
    all_param_names = self.free_params_names + self.fixed_params_names
    params = {name: params_hyperparams_dict[name] for name in all_param_names}
    all_hyperparam_names = self.free_hyperparams_names + self.fixed_hyperparams_names
    hyperparams = {name: params_hyperparams_dict[name] for name in all_hyperparam_names}

    # Residuals against the deterministic (Keplerian + trend/offset) mean model
    mean_model = gp_ll._calculate_mean_model(params)
    residuals = gp_ll.jax_vel - mean_model

    # Build GP kernel with hyperparameters
    kernel = self.gp_kernel.build_kernel(hyperparams)

    # Per-point observational variance: velerr^2 + jit_<inst>^2 for that point's instrument.
    # Work in float explicitly: np.zeros_like(self.velerr) would inherit an integer dtype
    # if velerr were integer-valued, silently truncating the jitter contribution.
    velerr = np.asarray(self.velerr, dtype=float)
    velerr_jitter_squared = np.zeros_like(velerr)
    for inst in self.unique_instruments:
        mask = self.instrument == inst
        jit = params_hyperparams_dict[f"jit_{inst}"]
        velerr_jitter_squared[mask] = velerr[mask] ** 2 + jit ** 2

    # Calculate chi^2 = r^T K^(-1) r using full covariance matrix
    return float(self._compute_gp_chi2(kernel, gp_ll.jax_time, velerr_jitter_squared, residuals))
def calculate_aicc(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Calculate corrected Akaike Information Criterion (AICc).

    .. math::
        \text{AICc} = 2k - 2\ln\mathcal{L} + \frac{2k^2 + 2k}{n - k - 1}

    where :math:`k` is the number of free parameters and hyperparameters,
    :math:`n` is the number of observations, and :math:`\mathcal{L}` is the
    likelihood. Converges to AIC for large :math:`n`.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free)

    Returns
    -------
    float
        AICc value

    Raises
    ------
    ValueError
        If ``n <= k + 1``. The small-sample correction term is undefined
        there (its denominator is zero or negative).
    """
    k = self.ndim
    n = len(self.time)
    denominator = n - k - 1
    # Check before the (potentially expensive) likelihood evaluation. Without
    # this, n == k + 1 raises ZeroDivisionError and n < k + 1 silently
    # returns a meaningless negative correction.
    if denominator <= 0:
        raise ValueError(
            f"AICc is undefined when the number of observations (n={n}) does not "
            f"exceed the number of free parameters and hyperparameters plus one (k + 1 = {k + 1})"
        )
    log_like = self.calculate_log_likelihood(params_hyperparams_dict)
    aic = 2 * k - 2 * log_like  # traditional AIC
    correction = (2 * k**2 + 2 * k) / denominator  # small-sample correction
    return aic + correction
def calculate_bic(self, params_hyperparams_dict: Dict[str, float]) -> float:
    r"""Calculate Bayesian Information Criterion (BIC) for given parameters and hyperparameters.

    .. math::
        \text{BIC} = k \ln n - 2 \ln \mathcal{L}

    where :math:`k` is the number of free parameters and hyperparameters,
    :math:`n` is the number of observations, and :math:`\mathcal{L}` is the
    likelihood.

    Parameters
    ----------
    params_hyperparams_dict : dict
        Dictionary of all parameter and hyperparameter values (both fixed and free)

    Returns
    -------
    float
        BIC value
    """
    log_likelihood = self.calculate_log_likelihood(params_hyperparams_dict)
    # Complexity penalty: k * ln(n)
    penalty = self.ndim * np.log(len(self.time))
    return penalty - 2 * log_likelihood
def get_sample_with_best_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1) -> Dict[str, float]:
    """Get parameter and hyperparameter values from the MCMC sample with the highest log probability.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)

    Returns
    -------
    Dict[str, float]
        Mapping of each free parameter and free hyperparameter name to its value
        in the best sample
    """
    # The samples and log probabilities must be selected identically so their indices line up
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    flat_lnprob = self.get_sampler_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)

    i_best = np.argmax(flat_lnprob)
    print(f"Best sample found with log probability {flat_lnprob[i_best]:.6f} at index {i_best} of samples (with discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    # Sample columns are ordered as free params followed by free hyperparams
    column_names = [*self.free_params_names, *self.free_hyperparams_names]
    return dict(zip(column_names, flat_samples[i_best]))
def plot_autocorr_estimates(
    self,
    params: list[str] | None = None,
    hyperparams: list[str] | None = None,
    plot_mean: bool = False,
    show_legend: bool = True,
    title: str | None = "Autocorrelation Time Estimates",
    xlabel: str | None = "Step number",
    ylabel: str | None = r"Autocorrelation time $\tau$",
    save: bool = False,
    fname: str = "autocorr_plot.png",
    dpi: int = 100
) -> None:
    r"""Plot autocorrelation time estimates from an adaptive MCMC run.

    Shows how the autocorrelation time evolved during the MCMC run, together
    with the convergence threshold line (N / 50). Only available if run_mcmc
    was called with check_convergence=True.

    Parameters
    ----------
    params : list[str] or None, optional
        Free parameter names to plot. If None, all free parameters are plotted (default: None)
    hyperparams : list[str] or None, optional
        Free hyperparameter names to plot. If None, all free hyperparameters are plotted (default: None)
    plot_mean : bool, optional
        If True, plot the mean tau across all free parameters/hyperparameters instead
        of individual taus. Overrides params and hyperparams (default: False)
    show_legend : bool, optional
        Whether to show the legend (default: True)
    title : str or None, optional
        Plot title (default: "Autocorrelation Time Estimates"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label (default: "Step number"). Set to None or "" to skip.
    ylabel : str or None, optional
        Y-axis label (default: r"Autocorrelation time $\tau$"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "autocorr_plot.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)

    Raises
    ------
    ValueError
        If no autocorrelation history is available (run_mcmc was not called
        with check_convergence=True, or has not been called yet)
    """
    if not hasattr(self, "autocorr_history") or not len(self.autocorr_history):
        raise ValueError(
            "No autocorrelation history available. "
            "Please run run_mcmc() with check_convergence=True first."
        )

    # Keys are the step numbers at which tau was estimated; each value holds one
    # tau per free parameter followed by one per free hyperparameter.
    steps = np.array(list(self.autocorr_history.keys()))
    taus = np.array(list(self.autocorr_history.values()))
    last_step = np.max(steps)

    fig, ax = plt.subplots(1, figsize=(10, 6))
    if title:
        fig.suptitle(title)

    # Convergence threshold: tau should sit below N / 50
    ax.plot([0, last_step], [0, last_step / 50], "--k", linewidth=2, label="N/50")

    if plot_mean:
        ax.plot(steps, np.mean(taus, axis=1), linewidth=2, label="Mean τ")
    else:
        # (column in taus, name) pairs, params first then hyperparams
        n_free_params = len(self.free_params_names)
        selected = []

        if params is None:
            selected.extend(enumerate(self.free_params_names))
        else:
            for name in params:
                if name not in self.free_params_names:
                    logging.warning(f"Parameter '{name}' not found in free parameters, skipping")
                    continue
                selected.append((self.free_params_names.index(name), name))

        if hyperparams is None:
            selected.extend(
                (n_free_params + offset, name)
                for offset, name in enumerate(self.free_hyperparams_names)
            )
        else:
            for name in hyperparams:
                if name not in self.free_hyperparams_names:
                    logging.warning(f"Hyperparameter '{name}' not found in free hyperparameters, skipping")
                    continue
                selected.append((n_free_params + self.free_hyperparams_names.index(name), name))

        for column, name in selected:
            ax.plot(steps, taus[:, column], alpha=0.7, label=param_key_to_latex(name))

    ax.set_xlim(0, steps.max())
    ax.set_ylim(bottom=0)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if show_legend:
        ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_chains(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, truths: list | None = None,
                title: str | None = "Chains plot",
                xlabel: str | None = "Step number",
                save: bool = False, fname: str = "chains_plot.png", dpi: int = 100) -> None:
    """Plot MCMC chains for all free parameters and hyperparameters.

    Displays how each free parameter and hyperparameter evolves across MCMC
    steps for all walkers. This is useful for diagnosing convergence, burn-in,
    and mixing of the MCMC chains.

    Each parameter/hyperparameter gets its own subplot showing every walker
    trace. For GP fitting, this includes both the planetary/trend parameters
    and the GP kernel hyperparameters.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    truths : list, optional
        True parameter/hyperparameter values to overplot as horizontal lines.
        Must match the number of free parameters + hyperparameters. Use None
        for entries without a known truth value (default: None)
    title : str or None, optional
        Plot title (default: "Chains plot"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label (default: "Step number"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "chains_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)

    Raises
    ------
    ValueError
        If `truths` is given and its length does not equal the number of free
        parameters and hyperparameters.
    """
    # Validate inputs and fetch the samples BEFORE creating the figure, so a
    # failure does not leave an orphaned open matplotlib figure behind.
    if truths is not None and len(truths) != self.ndim:
        raise ValueError(f"Length of truths ({len(truths)}) must match number of free parameters and hyperparameters ({self.ndim})")

    samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
    param_names = self.free_params_names + self.free_hyperparams_names

    # Scale figure height to maintain consistent subplot size
    subplot_height_inches = 1.25
    fig, axes = plt.subplots(self.ndim, figsize=(10, self.ndim * subplot_height_inches), sharex=True, constrained_layout=True)
    if title:
        fig.suptitle(title)
    # plt.subplots returns a bare Axes (not an array) when there is one row
    if self.ndim == 1:
        axes = [axes]

    for i in range(self.ndim):
        ax = axes[i]
        # samples[:, :, i] is (steps, walkers): one line per walker
        ax.plot(samples[:, :, i], "k", alpha=0.3)
        ax.set_xlim(0, len(samples))
        ax.set_ylabel(param_key_to_latex(param_names[i]))
        if truths is not None and truths[i] is not None:
            ax.axhline(truths[i], color="tab:blue")

    fig.align_ylabels(axes)
    if xlabel:
        axes[-1].set_xlabel(xlabel)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_lnprob(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1,
                title: str | None = "Log Probability Traces",
                xlabel: str | None = "Step number",
                ylabel: str | None = "Log probability",
                save: bool = False, fname: str = "lnprob_plot.png", dpi: int = 100) -> None:
    """Plot log probability traces for all walkers.

    This is useful for diagnosing MCMC convergence and for spotting
    problematic walkers or parameters. Use `discard_start` and `discard_end`
    to zoom in on specific steps of the chains.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    title : str or None, optional
        Plot title (default: "Log Probability Traces"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label (default: "Step number"). Set to None or "" to skip.
    ylabel : str or None, optional
        Y-axis label (default: "Log probability"). Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "lnprob_plot.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    """
    fig, ax = plt.subplots(1, figsize=(10, 6))
    if title:
        fig.suptitle(title)

    lnprob_grid = self.get_sampler_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=False)
    # Rows are steps, columns are walkers
    n_steps, _n_walkers = lnprob_grid.shape

    # One faint black trace per walker
    for walker_trace in lnprob_grid.T:
        ax.plot(walker_trace, "k", alpha=0.3)
    ax.set_xlim(0, n_steps)

    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_corner(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1,
                plot_datapoints: bool = False, truths: list[float] | None = None,
                title: str | None = "Corner plots",
                save: bool = False, fname: str = "corner_plot.png", dpi: int = 100) -> None:
    """Create a corner plot of MCMC samples.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    plot_datapoints : bool, optional
        Show individual data points in addition to contours (default: False)
    truths : list of float, optional
        True values to overplot as vertical/horizontal lines (default: None).
        If provided, must follow the order of the free parameters followed by
        the free hyperparameters.
    title : str or None, optional
        Figure title (default: "Corner plots"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "corner_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    flat_samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    param_labels = [param_key_to_latex(n) for n in self.free_params_names + self.free_hyperparams_names]

    fig = corner.corner(
        flat_samples,
        labels=param_labels,
        show_titles=True,
        plot_datapoints=plot_datapoints,
        # 15.85 / 50 / 84.15 percentiles: median and 1-sigma equivalent interval
        quantiles=[0.1585, 0.5, 0.8415],
        truths=truths,
    )
    # Previously hard-coded to "Corner plots", which ignored the `title` argument
    if title:
        fig.suptitle(title)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def _plot_rv(self, params_hyperparams: Dict[str, float], title: str = "RV Model",
             ylabel_main: str | None = "Radial velocity [m/s]",
             xlabel: str | None = "Time [days]",
             ylabel_residuals: str | None = "Residuals [m/s]",
             save: bool = False, fname: str = "rv_plot.png", dpi: int = 100) -> None:
    """Helper function to plot RV model with given parameters.

    For GP fitting, this plots both the mean model (planets + trend) and the
    GP component. The GP component is predicted at smooth time points using
    the posterior conditioned on the observed data residuals, and is shown
    with a 1-sigma uncertainty band.

    Parameters
    ----------
    params_hyperparams : dict
        Dictionary of parameter and hyperparameter values (both free and
        fixed), including the per-instrument ``g_<inst>`` and ``jit_<inst>``
        entries
    title : str, optional
        Plot title (default: "RV Model"). Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "rv_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    # Step 1: mean model (planets + trend, no gamma since that is per-instrument)
    # at smooth times (for plotting) and at observed times (for GP conditioning)
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    rv_mean_smooth = np.zeros(len(tsmooth))
    rv_mean_obs = np.zeros(len(self.time))
    for letter in self.planet_letters:
        planet_params = {par_name: params_hyperparams[f"{par_name}_{letter}"]
                         for par_name in self.parameterisation.pars}
        planet = ravest.model.Planet(letter, self.parameterisation, planet_params)
        rv_mean_smooth += planet.radial_velocity(tsmooth)
        rv_mean_obs += planet.radial_velocity(self.time)

    # System trend (gd, gdd only - no gamma)
    trend_params = {"gd": params_hyperparams["gd"], "gdd": params_hyperparams["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    rv_trend_smooth = trend.radial_velocity(tsmooth)
    rv_mean_smooth += rv_trend_smooth
    rv_mean_obs += trend.radial_velocity(self.time)

    # Step 2: per-instrument gamma correction and noise diagonal (sigma^2 + jit^2).
    # Float arrays are allocated explicitly, so integer-typed input data cannot
    # break the in-place subtraction or truncate the variances.
    vel_corrected = np.array(self.vel, dtype=float)
    jit2_verr2 = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        vel_corrected[mask] -= params_hyperparams[f"g_{inst}"]
        jit = params_hyperparams[f"jit_{inst}"]
        jit2_verr2[mask] = self.velerr[mask]**2 + jit**2
    # Jitter-inflated error bars for plotting
    velerr_with_jit = np.sqrt(jit2_verr2)

    # Step 3: GP conditioned on the gamma-corrected residuals about the mean model
    hyperparams = {hp: params_hyperparams[hp] for hp in self.gp_kernel.expected_hyperparams}
    kernel = self.gp_kernel.build_kernel(hyperparams)
    residuals_obs = vel_corrected - rv_mean_obs
    gp_obs = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(jit2_verr2))

    _, gp_cond_smooth = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(tsmooth))
    rv_gp_mean_smooth = np.asarray(gp_cond_smooth.mean)
    rv_gp_std_smooth = np.sqrt(np.asarray(gp_cond_smooth.variance))

    _, gp_cond_obs = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(self.time))
    rv_gp_mean_obs = np.asarray(gp_cond_obs.mean)

    # Step 4: generate the plot
    # The total model is the mean model plus the GP prediction
    rv_total_smooth = rv_mean_smooth + rv_gp_mean_smooth
    # Residuals: gamma-corrected data - (mean model + GP)
    residuals = vel_corrected - (rv_mean_obs + rv_gp_mean_obs)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Observed data (gamma-corrected, coloured by instrument)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=6, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(self.time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None", alpha=0.5, zorder=5)

    # Mean model component
    ax1.plot(tsmooth, rv_mean_smooth, label="Mean Model (Planets + Trend)", color="red", zorder=3)
    # GP component. The trend is added so that the planet and GP signals share
    # the same baseline.
    ax1.plot(tsmooth, rv_gp_mean_smooth + rv_trend_smooth, color="blue", label='GP Component', zorder=2)
    # Total model
    ax1.plot(tsmooth, rv_total_smooth, label='Total Model (Planet + Trend + GP)', color="black", zorder=4)
    # 1-sigma GP uncertainty band around the total model
    ax1.fill_between(tsmooth, rv_total_smooth - rv_gp_std_smooth, rv_total_smooth + rv_gp_std_smooth,
                     color='darkgrey', zorder=1)

    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    legend = ax1.legend(loc="upper right")
    legend.set_zorder(7)
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')  # Remove x-axis labels from top plot
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals subplot (coloured by instrument)
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=self.velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=6)
        ax2.errorbar(self.time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None", alpha=0.5, zorder=5)
    ax2.axhline(0, color="black", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric y-limits that contain every (jitter-inflated) error bar.
    # The old |r + e| under-estimated the extent |r| + e for negative residuals
    # and clipped their error bars.
    max_abs_residual = np.max(np.abs(residuals) + velerr_with_jit)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)  # Add ticks on shared border
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def _plot_phase(self, planet_letter: str, params_hyperparams: Dict[str, float], title: str | None = None,
                ylabel_main: str | None = "Radial velocity [m/s]",
                xlabel: str | None = "Orbital phase",
                ylabel_residuals: str | None = "Residuals [m/s]",
                save: bool = False, fname: str = "phase_plot.png", dpi: int = 100) -> None:
    """Helper function to plot phase-folded RV model for a single planet with given parameters.

    For GP fitting, the GP component cannot easily be separated per planet.
    The approach is therefore to show the target planet's signal after
    subtracting the other planets, the trend, and the GP conditional mean
    from the gamma-corrected data.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    params_hyperparams : dict
        Dictionary of parameter and hyperparameter values (both free and
        fixed), including the per-instrument ``g_<inst>`` and ``jit_<inst>``
        entries
    title : str or None, optional
        Plot title. None (the default) uses f"Planet {planet_letter} Phase Plot";
        pass "" to skip the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip.
    save : bool, optional
        Save the plot (default: False)
    fname : str, optional
        Filename to save (default: "phase_plot.png")
    dpi : int, optional
        Resolution for saving (default: 100)
    """
    if title is None:
        title = f"Planet {planet_letter} Phase Plot"

    # Smooth time grid (1% padding either side) for the model curve
    _tmin, _tmax = self.time.min(), self.time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, 1000)

    # Per-instrument gamma correction and noise diagonal (sigma^2 + jit^2).
    # Float arrays are allocated explicitly, so integer-typed input data cannot
    # break the in-place subtraction or truncate the variances.
    vel_corrected = np.array(self.vel, dtype=float)
    jit2_verr2 = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        mask = (self.instrument == inst)
        vel_corrected[mask] -= params_hyperparams[f"g_{inst}"]
        jit = params_hyperparams[f"jit_{inst}"]
        jit2_verr2[mask] = self.velerr[mask]**2 + jit**2
    # Jitter-inflated error bars for plotting
    velerr_with_jit = np.sqrt(jit2_verr2)

    # Parameters of the planet being plotted
    pars = self.parameterisation.pars
    planet_params = {par: params_hyperparams[f"{par}_{planet_letter}"] for par in pars}
    P = params_hyperparams[f"P_{planet_letter}"]

    # Time of conjunction, used as the phase zero-point
    if "Tc" in pars:
        Tc = planet_params["Tc"]
    elif "e" in pars and "w" in pars and "Tp" in pars:
        Tc = self.parameterisation.convert_tp_to_tc(planet_params["Tp"], P, planet_params["e"], planet_params["w"])
    else:
        # Fall back to converting via the default parameterisation
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], P, default_params["e"], default_params["w"])

    # Phase fold both the data times and the smooth times
    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # RV contribution from this planet only
    planet = ravest.model.Planet(planet_letter, self.parameterisation, planet_params)
    planet_rv_obs = planet.radial_velocity(self.time)
    planet_rv_smooth_sorted = planet.radial_velocity(tsmooth)[smooth_inds]

    # Everything else to remove from the data: other planets + system trend
    other_rv_obs = np.zeros(len(self.time))
    for other_letter in self.planet_letters:
        if other_letter == planet_letter:
            continue
        other_params = {par: params_hyperparams[f"{par}_{other_letter}"] for par in pars}
        other_planet = ravest.model.Planet(other_letter, self.parameterisation, other_params)
        other_rv_obs += other_planet.radial_velocity(self.time)

    # Trend (gd, gdd only - no gamma)
    trend_params = {"gd": params_hyperparams["gd"], "gdd": params_hyperparams["gdd"]}
    trend = ravest.model.Trend(params=trend_params, t0=self.t0)
    other_rv_obs += trend.radial_velocity(self.time)

    # The GP mean model is all of the planets + the trend
    rv_mean_obs = other_rv_obs + planet_rv_obs

    # GP conditioned on the gamma-corrected residuals, predicted at the observed times
    hyperparams = {hp: params_hyperparams[hp] for hp in self.gp_kernel.expected_hyperparams}
    kernel = self.gp_kernel.build_kernel(hyperparams)
    residuals_obs = vel_corrected - rv_mean_obs
    gp_obs = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(jit2_verr2))
    _, gp_cond = gp_obs.condition(y=jnp.array(residuals_obs), X_test=jnp.array(self.time))
    gp_mean_obs = np.array(gp_cond.mean)

    # Data shown: gamma-corrected data - (other planets + trend + GP mean)
    data_minus_others_obs = vel_corrected - (other_rv_obs + gp_mean_obs)
    residuals = data_minus_others_obs - planet_rv_obs

    # Reorder everything into phase order
    data_minus_others_obs_sorted = data_minus_others_obs[inds]
    residuals_sorted = residuals[inds]
    verr_sorted = self.velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]
    instrument_sorted = self.instrument[inds]

    # Figure: main phase plot + residuals
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    # Per-instrument colours
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main phase plot (coloured by instrument)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_obs_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_obs_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    # Phase-folded model for this planet
    ax1.plot(tsmooth_fold_sorted, planet_rv_smooth_sorted, label="Planet Model", color="black", zorder=2)

    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # x-ticks every 0.25
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    ax1.legend(loc="upper right")
    if title:
        ax1.set_title(title)
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Annotate with planet info
    K_value = params_hyperparams[f"K_{planet_letter}"]
    P_label = param_key_to_latex(f"P_{planet_letter}")
    K_label = param_key_to_latex(f"K_{planet_letter}")
    s = f"Planet {planet_letter}\n{P_label}={P:.2f} d\n{K_label}={K_value:.2f} m/s"
    ax1.annotate(s, xy=(0, 1), xycoords="axes fraction", xytext=(+0.5, -0.5), textcoords="offset fontsize", va="top")

    # Residuals plot (phase-folded, coloured by instrument)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # x-ticks every 0.25
    # Residual y-limits are deliberately left to matplotlib autoscaling here
    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        plt.savefig(fname=fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_posterior_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1,
                      show_CI: bool = True,
                      title: str | None = "Posterior predictions (with GP)",
                      ylabel_main: str | None = "Radial velocity [m/s]",
                      xlabel: str | None = "Time [days]",
                      ylabel_residuals: str | None = "Residuals [m/s]",
                      save: bool = False, fname: str = "posterior_rv.png", dpi: int = 100,
                      n_smooth: int = 1000) -> None:
    """Plot the posterior GP RV model with uncertainty bands from MCMC samples.

    Calculates RV model predictions for each MCMC sample, then plots the median
    with an optional 68.3% CI (15.85th-84.15th percentile) uncertainty band.
    Shows both the full model (planetary signals + trend + GP) and the
    residuals vs data.

    For each sample, the GP is conditioned on that sample's own residuals
    (data minus that sample's per-instrument gamma, planets and trend), with
    that sample's per-instrument jitter added in quadrature to the velocity
    errors on the GP diagonal. The plotted data are offset by each instrument's
    median gamma, and the faded error bars add each instrument's median jitter
    in quadrature.

    Parameters
    ----------
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main RV plot (default: "Posterior predictions (with GP)").
        Set to None or "" to skip.
    ylabel_main : str or None, optional
        Y-axis label for main RV plot (default: "Radial velocity [m/s]").
        Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Time [days]").
        Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]").
        Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_rv.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    n_smooth : int, optional
        Number of points in smooth time grid for plotting model curves (default: 1000).
        Reduce for faster plotting, increase for smoother curves.

    Raises
    ------
    ValueError
        If there are no free-parameter samples to plot.
    """
    time = np.asarray(self.time, dtype=float)
    vel = np.asarray(self.vel, dtype=float)
    velerr = np.asarray(self.velerr, dtype=float)
    instrument = np.asarray(self.instrument)
    n_obs = len(time)

    # Smooth time grid, padded by 1% of the observing baseline on each side
    _tmin, _tmax = time.min(), time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, n_smooth)

    # Samples for free parameters/hyperparameters, combined with the fixed values
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    if not samples_dict:
        raise ValueError("No free parameter samples available to plot.")
    params_hyperparams = samples_dict | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict

    n_samples = len(next(iter(samples_dict.values())))
    print(f"Processing {n_samples} effective samples (after discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    def _sample_value(value, i):
        # Free values are arrays with one entry per sample; fixed values are scalars.
        return value[i] if isinstance(value, np.ndarray) else value

    def _sample_column(value):
        # Shape a free (per-sample) value as (n_samples, 1) so it broadcasts over observations.
        return value.reshape(-1, 1) if isinstance(value, np.ndarray) else value

    def _median_value(value):
        return np.median(value) if isinstance(value, np.ndarray) else value

    inst_masks = {inst: (instrument == inst) for inst in self.unique_instruments}

    # Per-instrument gamma offsets and white-noise variances: per sample (for GP
    # conditioning) and as medians (for the plotted data and error bars).
    gamma_matrix_obs = np.zeros((n_samples, n_obs))
    gamma_median_obs = np.zeros(n_obs)
    noise_var_matrix_obs = np.zeros((n_samples, n_obs))
    velerr_with_jit = np.zeros(n_obs)
    for inst, mask in inst_masks.items():
        g_value = params_hyperparams[f"g_{inst}"]
        jit_value = params_hyperparams[f"jit_{inst}"]
        verr2 = velerr[mask] ** 2
        gamma_matrix_obs[:, mask] = _sample_column(g_value)
        gamma_median_obs[mask] = _median_value(g_value)
        noise_var_matrix_obs[:, mask] = verr2 + _sample_column(jit_value) ** 2
        velerr_with_jit[mask] = np.sqrt(verr2 + _median_value(jit_value) ** 2)

    # Data with each instrument's median gamma removed (this is what gets plotted)
    vel_corrected = vel - gamma_median_obs

    # All planets + trend for every sample, at observed and smooth times
    rv_all_planets_trend_matrix_obs = np.zeros((n_samples, n_obs))
    rv_all_planets_trend_matrix_smooth = np.zeros((n_samples, len(tsmooth)))
    for planet_letter in self.planet_letters:
        print(f"Calculating planet {planet_letter} RV at {n_obs} observed times...")
        rv_all_planets_trend_matrix_obs += self.calculate_rv_planet_from_samples(planet_letter, self.time, discard_start, discard_end, thin)
        print(f"Calculating planet {planet_letter} RV at {len(tsmooth)} smooth times...")
        rv_all_planets_trend_matrix_smooth += self.calculate_rv_planet_from_samples(planet_letter, tsmooth, discard_start, discard_end, thin)

    print(f"Calculating trend RV at {n_obs} observed times...")
    rv_all_planets_trend_matrix_obs += self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)
    print(f"Calculating trend RV at {len(tsmooth)} smooth times...")
    rv_all_planets_trend_matrix_smooth += self.calculate_rv_trend_from_samples(tsmooth, discard_start, discard_end, thin)

    # GP predictive mean per sample, conditioned on that sample's own residuals
    # (using that sample's gamma, so every model component comes from one draw).
    residuals_matrix_obs = vel - gamma_matrix_obs - rv_all_planets_trend_matrix_obs
    gp_mean_matrix_obs = np.zeros_like(residuals_matrix_obs)
    gp_mean_matrix_smooth = np.zeros_like(rv_all_planets_trend_matrix_smooth)
    X_obs = jnp.asarray(time)
    X_smooth = jnp.asarray(tsmooth)
    for i in tqdm(range(n_samples), desc="Computing GP predictions"):
        sample_hyperparams = {hp: _sample_value(params_hyperparams[hp], i)
                              for hp in self.gp_kernel.expected_hyperparams}
        kernel = self.gp_kernel.build_kernel(sample_hyperparams)
        gp = GaussianProcess(kernel=kernel, X=X_obs, diag=jnp.asarray(noise_var_matrix_obs[i]))
        y_i = jnp.asarray(residuals_matrix_obs[i])
        _, gp_cond_obs = gp.condition(y=y_i, X_test=X_obs)
        gp_mean_matrix_obs[i] = np.asarray(gp_cond_obs.mean)
        _, gp_cond_smooth = gp.condition(y=y_i, X_test=X_smooth)
        gp_mean_matrix_smooth[i] = np.asarray(gp_cond_smooth.mean)

    # Full model (planets + trend + GP): median and 68.3% credible interval
    percentiles = [15.85, 50, 84.15]
    rv_percentiles_smooth = np.percentile(rv_all_planets_trend_matrix_smooth + gp_mean_matrix_smooth, percentiles, axis=0)
    rv_percentiles_obs = np.percentile(rv_all_planets_trend_matrix_obs + gp_mean_matrix_obs, percentiles, axis=0)

    # Residuals of the gamma-corrected data about the median model
    residuals = vel_corrected - rv_percentiles_obs[1]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    # Main RV panel: gamma-corrected data coloured by instrument
    for inst, mask in inst_masks.items():
        ax1.errorbar(time[mask], vel_corrected[mask], yerr=velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4, label=inst)
        # Jitter extension (faded, no label)
        ax1.errorbar(time[mask], vel_corrected[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)

    ax1.plot(tsmooth, rv_percentiles_smooth[1], label="Model (Mean + GP)", color="black", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth, rv_percentiles_smooth[0], rv_percentiles_smooth[2],
                         color="tab:gray", alpha=0.3, edgecolor="none", label="68.3% CI", zorder=1)

    ax1.set_xlim(tsmooth[0], tsmooth[-1])
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    # Residuals panel, coloured by instrument
    for inst, mask in inst_masks.items():
        ax2.errorbar(time[mask], residuals[mask], yerr=velerr[mask],
                     marker=".", color=inst_colors[inst], ecolor=inst_colors[inst],
                     linestyle="None", markersize=8, zorder=4)
        ax2.errorbar(time[mask], residuals[mask], yerr=velerr_with_jit[mask],
                     marker="None", ecolor=inst_colors[inst], linestyle="None",
                     alpha=0.5, zorder=3)
    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    ax2.set_xlim(tsmooth[0], tsmooth[-1])

    # Symmetric y-limits that contain every point *including* its full error bar
    # (|r| + err, not |r + err|, which under-counts negative residuals).
    max_abs_residual = np.max(np.abs(residuals) + velerr_with_jit)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
def plot_posterior_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0,
                         thin: int = 1, show_CI: bool = True, title: str | None = None,
                         ylabel_main: str | None = "Radial velocity [m/s]",
                         xlabel: str | None = "Orbital phase",
                         ylabel_residuals: str | None = "Residuals [m/s]",
                         save: bool = False, fname: str = "posterior_phase.png", dpi: int = 100,
                         n_smooth: int = 50) -> None:
    """Plot phase-folded GP RV model with uncertainty bands from MCMC samples.

    Shows the target planet's signal, phase-folded on the median period and
    median time of conjunction. The plotted data have the trend, the other
    planets and the GP predictive mean removed. They are not included in the
    plotted data. The planet model curve is the median (and optional 68.3%
    credible interval) of this planet's RV over the MCMC samples.

    Procedure:

    1. For every sample, condition the GP on that sample's residuals
       (data minus that sample's per-instrument gamma, all planets and trend),
       with that sample's per-instrument jitter added in quadrature to the
       velocity errors on the GP diagonal.
    2. Per sample, sum the other planets, the trend and the GP mean. Take the
       median over samples and subtract it from the data. Before that, each
       instrument's median gamma is removed from the data.
    3. The residuals are those data minus the median of this planet's RV.

    Parameters
    ----------
    planet_letter : str
        Letter identifying the planet to plot (e.g., 'b', 'c', 'd')
    discard_start : int, optional
        Discard the first `discard_start` steps from the start of the chain (default: 0)
    discard_end : int, optional
        Discard the last `discard_end` steps from the end of the chain (default: 0)
    thin : int, optional
        Use only every `thin` steps from the chain (default: 1)
    show_CI : bool, optional
        Show 68.3% credible interval band (default: True)
    title : str or None, optional
        Title for the main phase plot. None (the default) uses
        "Posterior Phase Plot (with GP) - Planet {planet_letter}"; "" skips the title.
    ylabel_main : str or None, optional
        Y-axis label for main phase plot (default: "Radial velocity [m/s]").
        Set to None or "" to skip.
    xlabel : str or None, optional
        X-axis label for residuals plot (default: "Orbital phase").
        Set to None or "" to skip.
    ylabel_residuals : str or None, optional
        Y-axis label for residuals plot (default: "Residuals [m/s]").
        Set to None or "" to skip.
    save : bool, optional
        Save the plot to path `fname` (default: False)
    fname : str, optional
        The path to save the plot to (default: "posterior_phase.png")
    dpi : int, optional
        The dpi to save the image at (default: 100)
    n_smooth : int, optional
        Number of points in smooth time grid for plotting model curves (default: 50).
        Reduce for faster plotting, increase for smoother curves.

    Raises
    ------
    ValueError
        If there are no free-parameter samples to plot.
    """
    time = np.asarray(self.time, dtype=float)
    vel = np.asarray(self.vel, dtype=float)
    velerr = np.asarray(self.velerr, dtype=float)
    instrument = np.asarray(self.instrument)
    n_obs = len(time)

    # Samples for free parameters/hyperparameters, combined with the fixed values
    samples_dict = self.get_samples_dict(discard_start=discard_start, discard_end=discard_end, thin=thin)
    if not samples_dict:
        raise ValueError("No free parameter samples available to plot.")
    params_hyperparams = samples_dict | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict

    n_samples = len(next(iter(samples_dict.values())))
    print(f"Processing {n_samples} effective samples (after discard_start={discard_start}, discard_end={discard_end}, thin={thin})")

    # Smooth time grid, padded by 1% of the observing baseline on each side
    _tmin, _tmax = time.min(), time.max()
    _trange = _tmax - _tmin
    tsmooth = np.linspace(_tmin - 0.01 * _trange, _tmax + 0.01 * _trange, n_smooth)

    def _sample_value(value, i):
        # Free values are arrays with one entry per sample; fixed values are scalars.
        return value[i] if isinstance(value, np.ndarray) else value

    def _sample_column(value):
        # Shape a free (per-sample) value as (n_samples, 1) so it broadcasts over observations.
        return value.reshape(-1, 1) if isinstance(value, np.ndarray) else value

    def _median_value(value):
        return np.median(value) if isinstance(value, np.ndarray) else value

    inst_masks = {inst: (instrument == inst) for inst in self.unique_instruments}

    # Per-instrument gamma offsets and white-noise variances: per sample (for GP
    # conditioning) and as medians (for the plotted data and error bars).
    gamma_matrix_obs = np.zeros((n_samples, n_obs))
    gamma_median_obs = np.zeros(n_obs)
    noise_var_matrix_obs = np.zeros((n_samples, n_obs))
    velerr_with_jit = np.zeros(n_obs)
    for inst, mask in inst_masks.items():
        g_value = params_hyperparams[f"g_{inst}"]
        jit_value = params_hyperparams[f"jit_{inst}"]
        verr2 = velerr[mask] ** 2
        gamma_matrix_obs[:, mask] = _sample_column(g_value)
        gamma_median_obs[mask] = _median_value(g_value)
        noise_var_matrix_obs[:, mask] = verr2 + _sample_column(jit_value) ** 2
        velerr_with_jit[mask] = np.sqrt(verr2 + _median_value(jit_value) ** 2)

    vel_corrected = vel - gamma_median_obs

    # Period and time of conjunction used for folding (medians over samples)
    _P = params_hyperparams[f'P_{planet_letter}']
    if "Tc" in self.parameterisation.pars:
        _Tc = params_hyperparams[f"Tc_{planet_letter}"]
    else:
        # Fall back to the default parameterisation (it has P, e, w and Tp, so Tc can always be derived)
        planet_params = {par: params_hyperparams[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}
        default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)
        _Tc = self.parameterisation.convert_tp_to_tc(default_params["Tp"], _P, default_params["e"], default_params["w"])
    Tc = np.median(_Tc)
    P = np.median(_P)

    t_fold_sorted, inds = ravest.model.fold_time_series(self.time, P, Tc)
    tsmooth_fold_sorted, smooth_inds = ravest.model.fold_time_series(tsmooth, P, Tc)

    # This planet's RV at observed and smooth times (n_samples rows each)
    print(f"Calculating planet {planet_letter} RV at {n_obs} observed times...")
    rv_this_planet_matrix_obs = self.calculate_rv_planet_from_samples(planet_letter, self.time, discard_start, discard_end, thin)
    print(f"Calculating planet {planet_letter} RV at {len(tsmooth)} smooth times...")
    rv_this_planet_matrix_smooth = self.calculate_rv_planet_from_samples(planet_letter, tsmooth, discard_start, discard_end, thin)

    # All other planets' RVs at observed times
    rv_other_planets_matrix_obs = np.zeros((rv_this_planet_matrix_obs.shape[0], n_obs))
    for other_letter in self.planet_letters:
        if other_letter != planet_letter:
            print(f"Calculating planet {other_letter} RV at {n_obs} observed times...")
            rv_other_planets_matrix_obs += self.calculate_rv_planet_from_samples(other_letter, self.time, discard_start, discard_end, thin)

    print(f"Calculating trend RV at {n_obs} observed times...")
    rv_trend_matrix_obs = self.calculate_rv_trend_from_samples(self.time, discard_start, discard_end, thin)

    # GP predictive mean per sample, conditioned on that sample's own residuals
    # (using that sample's gamma, so every model component comes from one draw).
    rv_all_planets_trend_matrix_obs = rv_this_planet_matrix_obs + rv_other_planets_matrix_obs + rv_trend_matrix_obs
    residuals_matrix_obs = vel - gamma_matrix_obs - rv_all_planets_trend_matrix_obs
    gp_mean_matrix_obs = np.zeros_like(residuals_matrix_obs)
    X_obs = jnp.asarray(time)
    for i in tqdm(range(n_samples), desc="Computing GP predictions"):
        sample_hyperparams = {hp: _sample_value(params_hyperparams[hp], i)
                              for hp in self.gp_kernel.expected_hyperparams}
        kernel = self.gp_kernel.build_kernel(sample_hyperparams)
        gp = GaussianProcess(kernel=kernel, X=X_obs, diag=jnp.asarray(noise_var_matrix_obs[i]))
        _, gp_cond = gp.condition(y=jnp.asarray(residuals_matrix_obs[i]), X_test=X_obs)
        gp_mean_matrix_obs[i] = np.asarray(gp_cond.mean)

    # Median over samples of (other planets + trend + GP), removed from the data
    rv_all_other_median_obs = np.median(rv_other_planets_matrix_obs + rv_trend_matrix_obs + gp_mean_matrix_obs, axis=0)
    data_minus_others = vel_corrected - rv_all_other_median_obs

    # Residuals about the median of this planet's RV at observed times
    residuals = data_minus_others - np.median(rv_this_planet_matrix_obs, axis=0)

    # Put everything in folded-phase order
    data_minus_others_folded = data_minus_others[inds]
    residuals_sorted = residuals[inds]
    instrument_sorted = instrument[inds]
    verr_sorted = velerr[inds]
    velerr_with_jit_sorted = velerr_with_jit[inds]

    # Planet model curve: median and 68.3% credible interval over samples
    rv_planet_smooth_percs = np.percentile(rv_this_planet_matrix_smooth, [15.85, 50, 84.15], axis=0)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0})

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    inst_colors = {inst: colors[i % len(colors)] for i, inst in enumerate(self.unique_instruments)}

    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_folded[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4, label=inst)
        # Error bars with median jitter added in quadrature (faded, no label)
        ax1.errorbar(t_fold_sorted[mask], data_minus_others_folded[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    ax1.plot(tsmooth_fold_sorted, rv_planet_smooth_percs[1][smooth_inds], color="black", label="Planet Model", zorder=2)
    if show_CI:
        ax1.fill_between(tsmooth_fold_sorted, rv_planet_smooth_percs[0][smooth_inds],
                         rv_planet_smooth_percs[2][smooth_inds], color="tab:gray", alpha=0.3,
                         edgecolor="none", label="68.3% CI", zorder=1)

    ax2.axhline(0, color="k", linestyle="--", zorder=2)
    for inst in self.unique_instruments:
        mask = (instrument_sorted == inst)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=verr_sorted[mask],
                     marker=".", linestyle="None", color=inst_colors[inst], ecolor=inst_colors[inst],
                     markersize=8, zorder=4)
        ax2.errorbar(t_fold_sorted[mask], residuals_sorted[mask], yerr=velerr_with_jit_sorted[mask],
                     marker="None", linestyle="None", ecolor=inst_colors[inst], alpha=0.5, zorder=3)

    ax1.set_xlim(-0.5, 0.5)
    ax1.xaxis.set_major_locator(MultipleLocator(0.25))  # x-ticks every 0.25 in phase
    if ylabel_main:
        ax1.set_ylabel(ylabel_main)
    if title is None:
        ax1.set_title(f"Posterior Phase Plot (with GP) - Planet {planet_letter}")
    elif title:
        ax1.set_title(title)
    ax1.legend(loc="upper right")
    ax1.tick_params(axis='x', labelbottom=False, bottom=True, top=False, direction='in')
    ax1.tick_params(axis='y', direction='in')
    ax1.yaxis.set_major_locator(AutoLocator())
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax1.tick_params(axis='y', which='minor', direction='in', length=3)

    ax2.set_xlim(-0.5, 0.5)
    ax2.xaxis.set_major_locator(MultipleLocator(0.25))  # x-ticks every 0.25 in phase

    # Symmetric y-limits that contain every point *including* its full error bar
    # (|r| + err, not |r + err|, which under-counts negative residuals).
    max_abs_residual = np.max(np.abs(residuals_sorted) + velerr_with_jit_sorted)
    ax2.set_ylim(-max_abs_residual * 1.1, max_abs_residual * 1.1)

    if xlabel:
        ax2.set_xlabel(xlabel)
    if ylabel_residuals:
        ax2.set_ylabel(ylabel_residuals)
    ax2.tick_params(axis='x', direction='in')
    ax2.tick_params(axis='y', direction='in')
    ax2.tick_params(axis='x', top=True, labeltop=False)
    ax2.yaxis.set_major_locator(AutoLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.tick_params(axis='y', which='minor', direction='in', length=3)

    if save:
        fig.savefig(fname, dpi=dpi)
        print(f"Saved {fname}")
    plt.show()
[docs] def plot_MAP_rv(self, map_result: scipy.optimize.OptimizeResult, title: str | None = "MAP RV (with GP)", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_rv.png", dpi: int = 100) -> None: """Plot the MAP RV model. Uses the Maximum A Posteriori (MAP) parameter estimates to plot the GP model including both mean model (planets + trend) and GP component. Parameters ---------- map_result : scipy.optimize.OptimizeResult Result from find_map_estimate() title : str or None, optional Plot title (default: "MAP RV (with GP)"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "MAP_rv.png") dpi : int, optional Resolution for saving (default: 100) """ # Get MAP parameter and hyperparameter values from the optimization result # map_result.x contains both free parameters and free hyperparameters all_free_names = self.free_params_names + self.free_hyperparams_names map_params_and_hyperparams = dict(zip(all_free_names, map_result.x)) # Combine with fixed parameters and fixed hyperparameters all_params = map_params_and_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict # Use helper function to create the plot self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_MAP_phase(self, planet_letter: str, map_result: scipy.optimize.OptimizeResult, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "MAP_phase.png", dpi: int = 100) -> None: """Plot the MAP phase model. Uses the Maximum A Posteriori (MAP) parameter estimates to plot the phase-folded GP model for a specific planet, including both mean model (planets + trend) and GP component. Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') map_result : scipy.optimize.OptimizeResult Result from find_map_estimate() title : str or None, optional Plot title (default: f"MAP Phase Plot (with GP) - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "MAP_phase.png") dpi : int, optional Resolution for saving (default: 100) """ # Get MAP parameter and hyperparameter values from the optimization result # map_result.x contains both free parameters and free hyperparameters all_free_names = self.free_params_names + self.free_hyperparams_names map_params_and_hyperparams = dict(zip(all_free_names, map_result.x)) # Combine with fixed parameters and fixed hyperparameters all_params = map_params_and_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict # Set default title if not provided if title is None: title = f"MAP Phase Plot (with GP) - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_custom_rv(self, params_hyperparams: dict, title: str | None = "Custom GP RV Plot", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_rv.png", dpi: int = 100) -> None: """Plot GP radial velocity data and model using custom parameter and hyperparameter values. Allows plotting with arbitrary parameter and hyperparameter values for exploring parameter space or comparing theoretical models. The GP component will be conditioned on the residuals from the mean model (planets + trend). Parameters ---------- params_hyperparams : dict Dictionary of parameter and hyperparameter values to use for plotting. Keys should match parameter and hyperparameter names, values should be floats. Must include all required parameters and hyperparameters for the current parameterisation and GP kernel. title : str or None, optional Plot title (default: "Custom GP RV Plot"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "custom_rv.png") dpi : int, optional Resolution for saving (default: 100) Examples -------- >>> # Plot with custom values (must include all required parameters + hyperparameters) >>> gpfitter.plot_custom_rv({"P_b": 4.25, "K_b": 55.0, "e_b": 0.1, ... "w_b": 1.57, "Tc_b": 2456325.5, ... "g": -10.2, "gd": 0.0, "gdd": 0.0, "jit": 2.0, ... "gp_amp": 15.0, "gp_lambda_e": 50.0, ... 
"gp_lambda_p": 0.5, "gp_period": 25.0}) """ # Validate that all required parameters and hyperparameters are present expected_params = set(self.free_params_names + list(self.fixed_params_names)) expected_hyperparams = set(self.free_hyperparams_names + list(self.fixed_hyperparams_names)) expected_all = expected_params | expected_hyperparams provided_params = set(params_hyperparams.keys()) missing_params = expected_all - provided_params if missing_params: raise ValueError(f"Missing required parameters/hyperparameters: {missing_params}") # Use helper function to create the plot self._plot_rv(params_hyperparams, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_custom_phase(self, planet_letter: str, params_hyperparams: dict, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "custom_phase.png", dpi: int = 100) -> None: """Plot GP phase-folded radial velocity data and model using custom parameter and hyperparameter values. Allows plotting phase-folded data with arbitrary parameter and hyperparameter values for exploring parameter space or comparing theoretical models. The GP component will be conditioned on the residuals from the mean model (planets + trend). Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') params_hyperparams : dict Dictionary of parameter and hyperparameter values to use for plotting. Keys should match parameter and hyperparameter names, values should be floats. Must include all required parameters and hyperparameters for the current parameterisation and GP kernel. title : str or None, optional Plot title (default: f"Custom GP Phase Plot - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "custom_phase.png") dpi : int, optional Resolution for saving (default: 100) Examples -------- >>> # Plot phase curve with custom values >>> gpfitter.plot_custom_phase("b", {"P_b": 4.25, "K_b": 55.0, "e_b": 0.1, ... "w_b": 1.57, "Tc_b": 2456325.5, ... "g": -10.2, "gd": 0.0, "gdd": 0.0, "jit": 2.0, ... 
"gp_amp": 15.0, "gp_lambda_e": 50.0, ... "gp_lambda_p": 0.5, "gp_period": 25.0}) """ # Validate that all required parameters and hyperparameters are present expected_params = set(self.free_params_names + list(self.fixed_params_names)) expected_hyperparams = set(self.free_hyperparams_names + list(self.fixed_hyperparams_names)) expected_all = expected_params | expected_hyperparams provided_params = set(params_hyperparams.keys()) missing_params = expected_all - provided_params if missing_params: raise ValueError(f"Missing required parameters/hyperparameters: {missing_params}") # Set default title if not provided if title is None: title = f"Custom GP Phase Plot - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, params_hyperparams, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_best_sample_rv(self, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = "Best Sample RV Plot (with GP)", ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Time [days]", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_rv.png", dpi: int = 100) -> None: """Plot radial velocity data and model using parameter and hyperparameter values from the MCMC sample with highest log probability. This is useful for comparing with plot_MAP_rv() to diagnose potential issues with MAP convergence or MCMC mixing. The two plots should be very similar if both MAP and MCMC are working correctly. Parameters ---------- discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) title : str or None, optional Title for the main RV plot (default: "Best Sample RV Plot (with GP)"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main RV plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Time [days]"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "best_sample_rv.png") dpi : int, optional Resolution for saving (default: 100) """ # Get parameter and hyperparameter values from best sample best_sample_params_hyperparams = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin) # Combine with fixed parameters and fixed hyperparameters all_params = best_sample_params_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict # Use helper function to create the plot self._plot_rv(all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
[docs] def plot_best_sample_phase(self, planet_letter: str, discard_start: int = 0, discard_end: int = 0, thin: int = 1, title: str | None = None, ylabel_main: str | None = "Radial velocity [m/s]", xlabel: str | None = "Orbital phase", ylabel_residuals: str | None = "Residuals [m/s]", save: bool = False, fname: str = "best_sample_phase.png", dpi: int = 100) -> None: """Plot phase-folded radial velocity data and model using parameter and hyperparameter values from the MCMC sample with highest log probability. This is useful for comparing with plot_MAP_phase() to diagnose potential issues with MAP convergence or MCMC mixing. The two plots should be very similar if both MAP and MCMC are working correctly. Parameters ---------- planet_letter : str Letter identifying the planet to plot (e.g., 'b', 'c', 'd') discard_start : int, optional Discard the first `discard_start` steps from the start of the chain (default: 0) discard_end : int, optional Discard the last `discard_end` steps from the end of the chain (default: 0) thin : int, optional Use only every `thin` steps from the chain (default: 1) title : str or None, optional Title for the main phase plot (default: "Best Sample Phase Plot (with GP) - Planet {planet_letter}"). Set to None or "" to skip. ylabel_main : str or None, optional Y-axis label for main phase plot (default: "Radial velocity [m/s]"). Set to None or "" to skip. xlabel : str or None, optional X-axis label for residuals plot (default: "Orbital phase"). Set to None or "" to skip. ylabel_residuals : str or None, optional Y-axis label for residuals plot (default: "Residuals [m/s]"). Set to None or "" to skip. 
save : bool, optional Save the plot (default: False) fname : str, optional Filename to save (default: "best_sample_phase.png") dpi : int, optional Resolution for saving (default: 100) """ # Get parameter and hyperparameter values from best sample best_sample_params_hyperparams = self.get_sample_with_best_lnprob(discard_start=discard_start, discard_end=discard_end, thin=thin) # Combine with fixed parameters and fixed hyperparameters all_params = best_sample_params_hyperparams | self.fixed_params_values_dict | self.fixed_hyperparams_values_dict # Set default title if not provided if title is None: title = f"Best Sample Phase Plot (with GP) - Planet {planet_letter}" # Use helper function to create the plot self._plot_phase(planet_letter, all_params, title=title, ylabel_main=ylabel_main, xlabel=xlabel, ylabel_residuals=ylabel_residuals, save=save, fname=fname, dpi=dpi)
def calculate_rv_planet_from_samples(self, planet_letter: str, times: np.ndarray, discard_start: int = 0,
                                     discard_end: int = 0, thin: int = 1, progress: bool = True) -> np.ndarray:
    """Evaluate one planet's RV curve for every (flattened) MCMC sample.

    Each row is computed from a single posterior sample, so correlations between
    parameters are preserved. This is not the same as evaluating the curve at the
    median parameters, which need not correspond to any actual posterior sample.

    Parameters
    ----------
    planet_letter : str
        Planet letter (e.g., 'b', 'c')
    times : np.ndarray
        Times at which to evaluate the RV
    discard_start : int, optional
        Steps dropped from the start of the chain (default: 0)
    discard_end : int, optional
        Steps dropped from the end of the chain (default: 0)
    thin : int, optional
        Keep every Nth sample (default: 1)
    progress : bool, optional
        Show a tqdm progress bar (default: True)

    Returns
    -------
    np.ndarray
        Array of shape (n_samples, len(times)); row i is the planet RV for sample i
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end,
                                     thin=thin, flat=True)
    n_rows = len(flat_chain)
    rv_per_sample = np.zeros((n_rows, len(times)))
    label = f"Calculating planet {planet_letter} RV from samples"
    for row in tqdm(range(n_rows), total=n_rows, disable=not progress, desc=label):
        # Expand the raw sample vector into the full params (+ hyperparams) dictionary
        sample_params = self.build_params_dict(flat_chain[row])
        rv_per_sample[row, :] = self.calculate_rv_planet_custom(planet_letter, times, sample_params)
    return rv_per_sample
def calculate_rv_trend_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0,
                                    thin: int = 1, progress: bool = True) -> np.ndarray:
    """Evaluate the systemic trend RV for every (flattened) MCMC sample.

    Each row is computed from a single posterior sample, so correlations between
    parameters are preserved. This is not the same as evaluating the trend at the
    median parameters, which need not correspond to any actual posterior sample.

    Parameters
    ----------
    times : np.ndarray
        Times at which to evaluate the RV
    discard_start : int, optional
        Steps dropped from the start of the chain (default: 0)
    discard_end : int, optional
        Steps dropped from the end of the chain (default: 0)
    thin : int, optional
        Keep every Nth sample (default: 1)
    progress : bool, optional
        Show a tqdm progress bar (default: True)

    Returns
    -------
    np.ndarray
        Array of shape (n_samples, len(times)); row i is the trend RV for sample i
    """
    flat_chain = self.get_samples_np(discard_start=discard_start, discard_end=discard_end,
                                     thin=thin, flat=True)
    n_rows = len(flat_chain)
    rv_per_sample = np.zeros((n_rows, len(times)))
    for row in tqdm(range(n_rows), total=n_rows, disable=not progress,
                    desc="Calculating trend RV from samples"):
        # Expand the raw sample vector into the full params (+ hyperparams) dictionary
        sample_params = self.build_params_dict(flat_chain[row])
        rv_per_sample[row, :] = self.calculate_rv_trend_custom(times, sample_params)
    return rv_per_sample
def calculate_rv_gp_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0,
                                 thin: int = 1, progress: bool = True) -> np.ndarray:
    """Calculate GP component for each MCMC sample.

    This calculates the GP mean at the specified times for each MCMC sample,
    preserving parameter and hyperparameter correlations from the posterior.
    For every sample, calculate_rv_gp_custom conditions the GP on that sample's
    residuals at the observed times. The residuals are the observed RVs minus that
    sample's per-instrument gammas, trend and planetary signals. The conditioned
    mean is then evaluated at the requested `times`.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate GP at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bar (default: True)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - GP component for each sample
    """
    samples = self.get_samples_np(discard_start=discard_start, discard_end=discard_end, thin=thin, flat=True)
    gp_components = np.zeros((len(samples), len(times)))

    # NOTE: the mean model (trend + planets) at the data times is computed per sample
    # inside calculate_rv_gp_custom. Precomputing it here for all samples was pure
    # dead work (the result was never used) and ignored the `progress` flag.
    iterator = tqdm(enumerate(samples), total=len(samples), disable=not progress,
                    desc="Calculating GP from samples")
    for i, combined_sample in iterator:
        # Build complete params dict for this sample (includes hyperparams)
        params = self.build_params_dict(combined_sample)
        gp_components[i, :] = self.calculate_rv_gp_custom(times, params)

    return gp_components
def calculate_rv_total_from_samples(self, times: np.ndarray, discard_start: int = 0, discard_end: int = 0,
                                    thin: int = 1, progress: bool = True) -> np.ndarray:
    """Calculate total RV (planets + trend + GP) for each MCMC sample.

    This calculates RV_total(params_i) for each MCMC sample i, preserving parameter
    correlations. This differs from using median parameters, which may not represent
    actual samples from the posterior.

    Parameters
    ----------
    times : np.ndarray
        Time points to calculate RV at
    discard_start : int, optional
        Discard first N steps (default: 0)
    discard_end : int, optional
        Discard last N steps (default: 0)
    thin : int, optional
        Use every Nth sample (default: 1)
    progress : bool, optional
        Show progress bars for every component calculation (default: True)

    Returns
    -------
    np.ndarray
        Shape (n_samples, len(times)) - Total RV for each sample
    """
    # Forward the same chain selection AND progress flag to every component, so that
    # progress=False silences all progress bars, not just the GP one.
    chain_kwargs = dict(discard_start=discard_start, discard_end=discard_end, thin=thin, progress=progress)

    # Trend + planets at the requested times
    total_rvs = self.calculate_rv_trend_from_samples(times, **chain_kwargs)
    for planet_letter in self.planet_letters:
        total_rvs += self.calculate_rv_planet_from_samples(planet_letter, times, **chain_kwargs)

    # GP component (each sample's GP is conditioned on that sample's residuals at the data times)
    total_rvs += self.calculate_rv_gp_from_samples(times=times, **chain_kwargs)

    return total_rvs
def calculate_rv_planet_custom(self, planet_letter: str, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate one planet's RV curve for a single, user-supplied parameter set.

    Parameters
    ----------
    planet_letter : str
        Letter of the planet to evaluate (e.g., 'b', 'c')
    times : np.ndarray
        Times at which to evaluate the RV
    params : dict[str, float]
        Full parameter dictionary (free + fixed). For every name in
        `self.parameterisation.pars` it must contain the key "{name}_{planet_letter}".
        Can be produced with build_params_dict().

    Returns
    -------
    np.ndarray
        The planet's RV at `times`

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> planet_rv = gpfitter.calculate_rv_planet_custom('b', times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = gpfitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = gpfitter.build_params_dict(best_params)
    >>> planet_rv = gpfitter.calculate_rv_planet_custom('b', times, params)
    """
    # Strip the "_<letter>" suffix so the Planet receives bare parameter names
    suffix = f"_{planet_letter}"
    bare_params = {name: params[f"{name}{suffix}"] for name in self.parameterisation.pars}
    this_planet = ravest.model.Planet(planet_letter, self.parameterisation, bare_params)
    return this_planet.radial_velocity(times)
def calculate_rv_trend_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate the systemic trend RV for a single, user-supplied parameter set.

    Only the linear ("gd") and quadratic ("gdd") trend terms are used, referenced to
    `self.t0`. Per-instrument gamma offsets are not included here.

    Parameters
    ----------
    times : np.ndarray
        Times at which to evaluate the RV
    params : dict[str, float]
        Full parameter dictionary (free + fixed); must contain "gd" and "gdd".
        Can be produced with build_params_dict().

    Returns
    -------
    np.ndarray
        Trend RV at `times`

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> trend_rv = gpfitter.calculate_rv_trend_custom(times, params)
    """
    trend_terms = {key: params[key] for key in ("gd", "gdd")}
    system_trend = ravest.model.Trend(params=trend_terms, t0=self.t0)
    return system_trend.radial_velocity(times)
def calculate_rv_gp_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Predict the GP component at `times` for a single parameter/hyperparameter set.

    The GP's noise diagonal is velerr**2 + jit_<inst>**2 for each observation. The GP
    is conditioned on the residuals at the observed times, i.e. the observed RVs minus
    the per-instrument gamma offsets, the systemic trend and all planetary signals.
    The conditioned GP's mean is returned at the requested `times`.

    Parameters
    ----------
    times : np.ndarray
        Times at which to predict the GP
    params : dict[str, float]
        Combined parameter and hyperparameter dictionary. It must hold the trend,
        planet, "jit_<inst>" and "g_<inst>" parameters as well as every kernel
        hyperparameter. Can be produced with build_params_dict().

    Returns
    -------
    np.ndarray
        GP RV values at `times`

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> gp_rv = gpfitter.calculate_rv_gp_custom(times, params)
    >>>
    >>> # Using best lnprob sample
    >>> best_params = gpfitter.get_sample_with_best_lnprob(discard_start=1000)
    >>> params = gpfitter.build_params_dict(best_params)
    >>> gp_rv = gpfitter.calculate_rv_gp_custom(times, params)
    """
    # Pick out just the kernel hyperparameters and build the kernel from them
    kernel_hyperparams = {name: params[name] for name in self.gp_kernel.expected_hyperparams}
    kernel = self.gp_kernel.build_kernel(kernel_hyperparams)

    # Per-observation white-noise variance: measurement error plus that instrument's jitter
    noise_variance = np.zeros(len(self.time))
    for inst in self.unique_instruments:
        in_inst = self.instrument == inst
        noise_variance[in_inst] = self.velerr[in_inst] ** 2 + params[f"jit_{inst}"] ** 2
    gp = GaussianProcess(kernel=kernel, X=jnp.array(self.time), diag=jnp.array(noise_variance))

    # Keplerian mean (trend + every planet) at the observed times
    model_at_obs = self.calculate_rv_trend_custom(self.time, params)
    for letter in self.planet_letters:
        model_at_obs += self.calculate_rv_planet_custom(letter, self.time, params)

    # Remove each instrument's gamma offset from a copy of the observed velocities
    vel_minus_gamma = self.vel.copy()
    for inst in self.unique_instruments:
        vel_minus_gamma[self.instrument == inst] -= params[f"g_{inst}"]

    residuals_at_obs = vel_minus_gamma - model_at_obs
    _log_prob, conditioned_gp = gp.condition(y=jnp.array(residuals_at_obs), X_test=jnp.array(times))
    return np.array(conditioned_gp.mean)
def calculate_rv_total_custom(self, times: np.ndarray, params: dict[str, float]) -> np.ndarray:
    """Evaluate the full RV model (trend + all planets + GP) for one parameter set.

    Handy for inspecting the model at hand-picked values, such as the best-lnprob
    sample, posterior medians or experimental guesses.

    Parameters
    ----------
    times : np.ndarray
        Times at which to evaluate the RV
    params : dict[str, float]
        Combined parameter and hyperparameter dictionary, containing the standard
        parameters (trend/planets/jit) and the GP kernel hyperparameters.
        Can be produced with build_params_dict().

    Returns
    -------
    np.ndarray
        Sum of trend, every planet and the GP at `times`

    Examples
    --------
    >>> # Using MAP result
    >>> map_result = gpfitter.find_map_estimate()
    >>> params = gpfitter.build_params_dict(map_result.x)
    >>> total_rv = gpfitter.calculate_rv_total_custom(times, params)
    """
    rv = self.calculate_rv_trend_custom(times, params)
    # Planets are evaluated lazily, one at a time, and accumulated in order
    planet_contributions = (
        self.calculate_rv_planet_custom(letter, times, params) for letter in self.planet_letters
    )
    for contribution in planet_contributions:
        rv += contribution
    rv += self.calculate_rv_gp_custom(times, params)
    return rv
class GPLogPosterior:
    """Log posterior probability for GP MCMC sampling.

    Combines GP log likelihood and log priors for both parameters and hyperparameters.
    """

    def __init__(
        self,
        planet_letters: list[str],
        parameterisation: Parameterisation,
        gp_kernel: GPKernel,
        priors: dict[str, Callable[[float], float]],
        hyperpriors: dict[str, Callable[[float], float]],
        fixed_params: dict[str, float],
        fixed_hyperparams: dict[str, float],
        free_params_names: list[str],
        free_hyperparams_names: list[str],
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        t0: float,
        instrument: np.ndarray,
        unique_instruments: list[str],
    ) -> None:
        """Initialize the GPLogPosterior object.

        Parameters
        ----------
        planet_letters : list[str]
            List of single-character planet identifiers.
        parameterisation : Parameterisation
            The orbital parameterisation to use.
        gp_kernel : GPKernel
            The Gaussian Process kernel to use.
        priors : dict[str, Callable[[float], float]]
            Dictionary mapping parameter names to their prior probability functions.
        hyperpriors : dict[str, Callable[[float], float]]
            Dictionary mapping hyperparameter names to their prior probability functions.
        fixed_params : dict[str, float]
            Dictionary of fixed parameter values.
        fixed_hyperparams : dict[str, float]
            Dictionary of fixed hyperparameter values.
        free_params_names : list[str]
            List of free parameter names to sample.
        free_hyperparams_names : list[str]
            List of free hyperparameter names to sample.
        time : np.ndarray
            Time of each observation [days].
        vel : np.ndarray
            Radial velocity at each time [m/s].
        velerr : np.ndarray
            Uncertainty on the radial velocity at each time [m/s].
        t0 : float
            Reference time for the trend [days].
        instrument : np.ndarray
            Instrument label for each observation.
        unique_instruments : list[str]
            List of unique instrument names.
        """
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.gp_kernel = gp_kernel
        self.priors = priors
        self.hyperpriors = hyperpriors
        self.fixed_params = fixed_params
        self.fixed_hyperparams = fixed_hyperparams
        self.free_params_names = free_params_names
        self.free_hyperparams_names = free_hyperparams_names
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.t0 = t0
        self.instrument = instrument
        self.unique_instruments = unique_instruments

        # Create the GP log-likelihood object once; it precomputes data-dependent arrays
        self.gp_log_likelihood = GPLogLikelihood(
            time=self.time,
            vel=self.vel,
            velerr=self.velerr,
            t0=self.t0,
            instrument=self.instrument,
            unique_instruments=self.unique_instruments,
            planet_letters=self.planet_letters,
            parameterisation=self.parameterisation,
            gp_kernel=self.gp_kernel,
        )

        # Create LogPrior objects for parameters and hyperparameters
        self.log_prior = LogPrior(self.priors)
        self.log_hyperprior = LogPrior(self.hyperpriors)

    def _convert_params_for_prior_evaluation(self, free_params_dict: dict[str, float]) -> Dict[str, float]:
        """Convert free parameters for prior evaluation if needed.

        Parameters
        ----------
        free_params_dict : dict
            Free parameters in current parameterisation

        Returns
        -------
        dict
            Parameters with names/values converted for prior evaluation
        """
        # Three cases:
        # Case 1: User is fitting in transformed parameterisation, and priors are in the same transformed parameterisation
        # Case 2: User is fitting in default parameterisation, and priors are also in default parameterisation
        # Case 3: User is fitting in transformed parameterisation, but priors are in default parameterisation

        # Simple detection: do prior keys match our current free parameter names?
        prior_keys = set(self.priors.keys())
        free_param_keys = set(self.free_params_names)

        if prior_keys == free_param_keys:
            # No conversion needed (Cases 1 & 2)
            return free_params_dict

        # Conversion needed (Case 3) - convert to default parameterisation equivalents.
        # Start with the free parameters whose names already have a prior.
        params_for_prior = {key: value for key, value in free_params_dict.items() if key in prior_keys}
        all_params = self.fixed_params | free_params_dict

        # Convert each planet's parameters
        for planet_letter in self.planet_letters:
            # Get current planet parameters (free and fixed)
            planet_params = {par: all_params[f"{par}_{planet_letter}"] for par in self.parameterisation.pars}

            # Convert to default parameterisation
            default_params = self.parameterisation.convert_pars_to_default_parameterisation(planet_params)

            # Add the converted parameter values for priors that need them
            for default_par, value in default_params.items():
                default_param_key = f"{default_par}_{planet_letter}"
                if default_param_key in prior_keys:  # Only add if we have a prior for it
                    params_for_prior[default_param_key] = value

        return params_for_prior

    def log_probability(self, combined_params_hyperparams: Dict[str, float]) -> float:
        """Calculate log posterior probability for given free parameters and hyperparameters.

        Parameters
        ----------
        combined_params_hyperparams : Dict[str, float]
            Combined dictionary of free parameters and hyperparameters

        Returns
        -------
        float
            Log posterior probability (log likelihood + log prior + log hyperprior),
            or -inf if any check fails or the likelihood is not finite (including NaN).
        """
        # Split the combined dictionary into parameters and hyperparameters
        free_params_dict = {name: combined_params_hyperparams[name] for name in self.free_params_names}
        free_hyperparams_dict = {name: combined_params_hyperparams[name] for name in self.free_hyperparams_names}

        # Merge free values with fixed ones once; reused for checks and for the likelihood
        all_params = self.fixed_params | free_params_dict
        all_hyperparams = self.fixed_hyperparams | free_hyperparams_dict

        # Fast fail for invalid jitter (before expensive prior/likelihood calculations)
        # We have to check jitter specifically because all other params will ultimately
        # get checked/raise Exceptions when they are used to calculate an RV.
        # Jitter doesn't directly contribute to calculated RV, so needs to be checked manually.
        for inst in self.unique_instruments:
            if all_params[f"jit_{inst}"] < 0:
                return -np.inf

        # Fast fail for invalid GP hyperparameters
        # This is a check for unphysical values, not for whether they are within the hyperpriors
        try:
            self.gp_kernel._validate_hyperparams_values(all_hyperparams)
        except ValueError:
            return -np.inf

        # Evaluate priors on the free parameters. If any parameters are outside priors
        # (i.e. priors are infinite), then fail fast by returning -inf early (before expensive likelihood calc).
        # We attempt to convert free parameters (if needed) for prior evaluation.
        # This is for if the user is fitting in a transformed parameterisation,
        # but defining their priors in the default parameterisation.
        try:
            params_for_prior = self._convert_params_for_prior_evaluation(free_params_dict)
            lp = self.log_prior(params_for_prior)
        except ValueError:
            # Invalid parameter conversion (e.g., unphysical eccentricity)
            return -np.inf
        if not np.isfinite(lp):
            return -np.inf

        # Evaluate hyperpriors on the free hyperparameters - fail fast if any are outside priors
        lhp = self.log_hyperprior(free_hyperparams_dict)
        if not np.isfinite(lhp):
            return -np.inf

        # Calculate GP log-likelihood with all parameters and hyperparameters
        ll = self.gp_log_likelihood(params=all_params, hyperparams=all_hyperparams)

        # A numerical failure inside the GP (e.g. an ill-conditioned covariance) can yield NaN.
        # emcee raises ValueError on NaN log-probabilities, so treat any non-finite
        # likelihood as an impossible point, consistent with every other failure path above.
        if not np.isfinite(ll):
            return -np.inf

        # Return combined log-posterior (log-likelihood + log-prior + log-hyperprior)
        return ll + lp + lhp

    def _negative_log_probability_for_MAP(self, combined_free_params_hyperparams_vals: list[float]) -> float:
        """Negative log posterior for MAP, taking a flat list of values instead of a dict.

        scipy.optimize.minimize works with a flat vector of values, so this maps
        the values back onto the free parameter names (first) and free
        hyperparameter names (after) and calls log_probability(). The order of the
        values is assumed, not checked, to match free_params_names followed by
        free_hyperparams_names.

        Parameters
        ----------
        combined_free_params_hyperparams_vals : list
            Combined list of free parameter and free hyperparameter values

        Returns
        -------
        float
            Negative log posterior, or 1e30 if the log posterior is not finite
        """
        # Split the list back into params values and hyperparams values
        n_params = len(self.free_params_names)
        params_values = combined_free_params_hyperparams_vals[:n_params]
        hyperparams_values = combined_free_params_hyperparams_vals[n_params:]

        # Create combined dict from the names and values
        # (Assumes the order of names matches the order of values)
        params_dict = dict(zip(self.free_params_names, params_values))
        hyperparams_dict = dict(zip(self.free_hyperparams_names, hyperparams_values))
        combined_dict = params_dict | hyperparams_dict

        # Calculate *negative* log_probability (MAP minimises, MCMC maximises)
        neg_logprob = -self.log_probability(combined_dict)

        # Handle -inf log_probability to prevent scipy RuntimeWarnings during optimisation
        # scipy's optimizer can't handle -inf values in arithmetic operations
        # (This does mean there is a non-zero chance we could end up returning a solution that doesn't satisfy the prior functions)
        if not np.isfinite(neg_logprob):
            return 1e30  # Very large finite number instead of +inf
        return neg_logprob
class GPLogLikelihood:
    """GP log likelihood of radial velocity data.

    The GP mean function is the Keplerian model: the sum of all planets, the
    systemic trend and the per-instrument gamma offsets. The covariance comes from
    the kernel hyperparameters, and the diagonal is velerr**2 + jitter**2 for each
    observation.
    """

    def __init__(
        self,
        time: np.ndarray,
        vel: np.ndarray,
        velerr: np.ndarray,
        t0: float,
        instrument: np.ndarray,
        unique_instruments: list[str],
        planet_letters: list[str],
        parameterisation: Parameterisation,
        gp_kernel: GPKernel,
    ) -> None:
        # Observed data and model configuration, stored as given
        self.time = time
        self.vel = vel
        self.velerr = velerr
        self.t0 = t0
        self.instrument = instrument
        self.unique_instruments = unique_instruments
        self.planet_letters = planet_letters
        self.parameterisation = parameterisation
        self.gp_kernel = gp_kernel

        # JAX copies of the observed data, as consumed by tinygp
        self.jax_time, self.jax_vel, self.jax_velerr = (
            jnp.array(arr) for arr in (self.time, self.vel, self.velerr)
        )

        # For each observation, the integer position of its instrument in unique_instruments:
        #   unique_instruments = ["HARPS", "ESPRESSO"]
        #   instrument         = ["HARPS", "HARPS", "ESPRESSO", "HARPS", ...]
        #   _instrument_indices = [0, 0, 1, 0, ...]
        # Indexing a length-K per-instrument array with this gives a length-N per-observation array.
        position_of = {name: idx for idx, name in enumerate(self.unique_instruments)}
        self._instrument_indices = jnp.array([position_of[name] for name in self.instrument])

        # Dictionary keys for gamma and jitter never change, so build them once up front
        self._gamma_keys = [f"g_{inst}" for inst in self.unique_instruments]
        self._jitter_keys = [f"jit_{inst}" for inst in self.unique_instruments]

        # Squared measurement errors are constant for the lifetime of this object
        self._velerr_sq = self.jax_velerr ** 2

    def _calculate_mean_model(self, params: Dict[str, float]) -> jnp.ndarray:
        """Return the Keplerian mean model at the observation times.

        Sums every planet's RV, the systemic trend, and each observation's
        instrument gamma offset.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values

        Returns
        -------
        jnp.ndarray
            Mean model RV values at observation times, or the float -inf if a
            Planet rejects its parameters with a ValueError
        """
        model = jnp.zeros(len(self.time))

        for letter in self.planet_letters:
            # Strip the "_<letter>" suffix so the Planet receives bare parameter names
            planet_kwargs = {p: params[f"{p}_{letter}"] for p in self.parameterisation.pars}
            try:
                planet_rv = ravest.model.Planet(letter, self.parameterisation, planet_kwargs).radial_velocity(self.time)
            except ValueError:
                # Invalid planet parameters: signal failure to the caller
                return -np.inf
            model += planet_rv

        # Systemic trend only; gamma is handled per instrument below
        system_trend = ravest.model.Trend(params={"gd": params["gd"], "gdd": params["gdd"]}, t0=self.t0)
        model += jnp.array(system_trend.radial_velocity(self.time))

        # Expand the K per-instrument gammas to N per-observation offsets via fancy indexing
        gammas = jnp.array([params[key] for key in self._gamma_keys])
        return model + gammas[self._instrument_indices]

    @staticmethod
    @jax.jit
    def _compute_gp_log_likelihood(
        kernel: kernels.Kernel,
        time_array: jnp.ndarray,
        vel_array: jnp.ndarray,
        verr_squared_array: jnp.ndarray,
        mean_model: jnp.ndarray
    ) -> float:
        """JIT-compiled log probability of the residuals under the GP.

        Builds a GaussianProcess with the given kernel and noise variances
        and evaluates the log probability of (vel_array - mean_model).
        """
        residuals = vel_array - mean_model
        process = GaussianProcess(kernel=kernel, X=time_array, diag=verr_squared_array)
        return process.log_probability(y=residuals)

    def __call__(self, params: Dict[str, float], hyperparams: Dict[str, float]) -> float:
        """Calculate the GP log likelihood for given parameters and hyperparameters.

        Parameters
        ----------
        params : Dict[str, float]
            Dictionary of all parameter values
        hyperparams : Dict[str, float]
            Dictionary of all hyperparameter values

        Returns
        -------
        float
            Log likelihood value (-inf if the mean model could not be computed)
        """
        mean_model = self._calculate_mean_model(params)

        # Skip the expensive GP step entirely if the mean model failed
        if not jnp.all(jnp.isfinite(mean_model)):
            return -np.inf

        kernel = self.gp_kernel.build_kernel(hyperparams)

        # Expand the K per-instrument jitters to N per-observation values.
        # tinygp's diag takes variances, so add squared jitter to squared errors (no sqrt).
        jitter_at_obs = jnp.array([params[key] for key in self._jitter_keys])[self._instrument_indices]
        noise_variance = self._velerr_sq + jitter_at_obs ** 2

        return self._compute_gp_log_likelihood(
            kernel=kernel,
            time_array=self.jax_time,
            vel_array=self.jax_vel,
            verr_squared_array=noise_variance,
            mean_model=mean_model,
        )