import os
import sys
import logging
from typing import Union, Optional, Dict, List, Callable, Any
import configparser
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
# pyABC imports
from pyabc import ABCSMC, sampler, LocalTransition, AdaptiveAggregatedDistance, AggregatedDistance, QuantileEpsilon, History, RV, Distribution
from pyabc.populationstrategy import AdaptivePopulationSize
from pyabc.storage import load_dict_from_json
from dask.distributed import Client, get_worker
# UQ PhysiCell imports
from uq_physicell import PhysiCell_Model
from uq_physicell.abc.utils import insert_adaptive_weights_db, insert_metadata_db
from uq_physicell.utils import run_replicate_serializable
from ..utils.sumstats import _convert_qoi_function_to_string
[docs]
class CalibrationContext:
    """
    Context for Approximate Bayesian Computation (ABC) calibration using pyABC.

    This class encapsulates all necessary parameters and configurations for model
    calibration using ABC-SMC, with handling of parallel computation and adaptive
    strategies.

    Args:
        db_path (str): Path to the database file for storing and retrieving calibration results.
        obsData (str or dict): Path to observed data CSV file or dictionary containing observed data.
        obsData_columns (dict): Dictionary mapping QoI names to their corresponding columns in the
            observed data. Only applied when ``obsData`` is a CSV path.
        model_config (dict): Configuration dictionary for the PhysiCell model, including paths and
            structure names ('ini_path' and 'struc_name' are required).
        qoi_functions (dict): Dictionary of functions to compute quantities of interest (QoIs) from
            model outputs. Non-string values are converted to strings.
        distance_functions (dict): Dictionary of distance functions with their weights for comparing
            model outputs to observed data.
        prior (Distribution): Distribution defining the prior distributions for parameters.
        abc_options (dict): Options for ABC-SMC including population parameters, sampling strategies,
            and convergence criteria.
        qoi_def (dict, optional): Mapping from names to first-class objects that can be referenced
            inside the qoi_functions lambda strings. Defaults to an empty dict.
        logger (logging.Logger, optional): Logger instance for logging messages during the
            calibration process. A console logger is created when omitted.
    """

    def __init__(
        self,
        db_path: str,
        obsData: Union[str, dict],
        obsData_columns: dict,
        model_config: dict,
        qoi_functions: dict,
        distance_functions: dict,
        prior: "Distribution",
        abc_options: dict,
        qoi_def: Optional[dict] = None,
        logger: Optional[logging.Logger] = None
    ):
        """Initialize CalibrationContext: store configuration, load observed data, validate."""
        # Core configuration
        self.db_path = db_path
        self.model_config = model_config
        # QoI functions must be strings, because they need to be serializable to be
        # saved in the database and used in the default aggregation function.
        self.qoi_functions = {
            key: value if isinstance(value, str) else _convert_qoi_function_to_string(value, key)
            for key, value in qoi_functions.items()
        }
        # A fresh dict per instance: a `{}` default in the signature would be shared across instances.
        self.qoi_def = {} if qoi_def is None else qoi_def
        self.distance_functions = distance_functions
        self.prior = prior
        self.abc_options = abc_options

        # Setup logger
        if logger is None:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(logging.INFO)
            # Only attach a console handler once, so repeated construction does not duplicate output.
            if not self.logger.handlers:
                console_handler = logging.StreamHandler(sys.stdout)
                console_handler.setLevel(logging.INFO)
                formatter = logging.Formatter(
                    '%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S'
                )
                console_handler.setFormatter(formatter)
                self.logger.addHandler(console_handler)
        else:
            self.logger = logger

        # Load and validate observed data
        if isinstance(obsData, dict):
            self.dic_obsData = obsData
            self.obsData_path = None
        else:  # obsData is a path
            try:
                self.obsData_path = obsData
                self.dic_obsData = pd.read_csv(obsData).to_dict('list')
                # Replace column names according to obsData_columns mapping
                for qoi, column_name in obsData_columns.items():
                    if column_name in self.dic_obsData:
                        self.dic_obsData[qoi] = np.array(self.dic_obsData.pop(column_name), dtype=np.float64)
                    else:
                        raise ValueError(f"Column '{column_name}' not found in observed data.")
                self.logger.debug(f"Successfully loaded observed data from {obsData}")
            except Exception as e:
                self.logger.error(f"Error reading observed data from {obsData}: {e}")
                raise

        # ABC-SMC configuration
        self.max_populations = abc_options.get("max_populations", 20)
        self.max_simulations = abc_options.get("max_simulations", 1000)
        self.population_strategy = abc_options.get("population_strategy", "adaptive")
        self.min_population_size = abc_options.get("min_population_size", 100)
        self.max_population_size = abc_options.get("max_population_size", 500)
        self.epsilon_strategy = abc_options.get("epsilon_strategy", "quantile")
        # Quantile for epsilon threshold (e.g., 0.5 for median distance)
        self.epsilon_alpha = abc_options.get("epsilon_alpha", 0.5)

        # Transition and distance configuration
        self.transition_strategy = abc_options.get("transition_strategy", "multivariate")  # "local" or "multivariate"
        self.adaptive_distance = abc_options.get("adaptive_distance", False)
        self.adaptive_distance_file = abc_options.get(
            "adaptive_distance_file",
            None if not self.adaptive_distance else "adaptive_distance_log.json"
        )
        self.convergence_check_func = abc_options.get("convergence_check_func", None)

        # Sampling configuration
        self.sampler_type = abc_options.get("sampler", "multicore")  # "dask" or "multicore"
        self.cluster_setup_func = abc_options.get("cluster_setup_func", None)  # Function to setup Dask cluster
        # os.cpu_count() may return None; fall back to a single worker in that case.
        self.num_workers = abc_options.get("num_workers") or os.cpu_count() or 1

        # Model configuration
        self.num_replicates = self.model_config.get('numReplicates', None)
        # Read num_replicates from ini file if not provided in model_config
        if self.num_replicates is None:
            configFile = configparser.ConfigParser()
            with open(model_config['ini_path']) as ini_file:
                configFile.read_file(ini_file)
            self.num_replicates = int(configFile[model_config['struc_name']]['numReplicates'])
        self.fixed_params = abc_options.get('fixed_params', {})
        self.summary_function = abc_options.get("summary_function", None)
        self.aggregation_func = abc_options.get("custom_aggregation_func", self._default_aggregation_func)
        self.custom_run_single_replicate_func = abc_options.get("custom_run_single_replicate_func", None)

        # Parameter scaling
        self.log_scale = abc_options.get("log_scale", False)

        # Multiple models support
        self.num_models = abc_options.get("num_models", 1)
        self.model_selection = abc_options.get("model_selection", False)

        # Parallelization setup
        self._setup_parallelization()

        # Validate configuration
        self._validate_configuration()

        self.logger.info(f"🔧 CalibrationContext initialized for ABC-SMC calibration")
        self.logger.info(f"📊 Database: {self.db_path}")
        self.logger.info(f"🎯 QoIs: {list(self.qoi_functions.keys())}")
        self.logger.info(f"🔍 Parameters: {self.prior.get_parameter_names()}")
        self.logger.info(f"⚙️ Sampler: {self.sampler_type} with {self.num_workers} workers")

    def _setup_parallelization(self):
        """Set ``workers_inner`` / ``workers_outer`` from the sampler type.

        Raises:
            ValueError: If ``sampler_type`` is neither 'dask' nor 'multicore'.
        """
        if self.sampler_type == "dask":
            self.workers_inner = None  # Dask handles its own parallelization
            self.workers_outer = self.num_workers
        elif self.sampler_type == "multicore":
            # Nested parallelization: inner threads run replicates, outer processes run parameter sets.
            # Clamp to at least one inner worker so the division below cannot hit zero.
            self.workers_inner = max(1, min(self.num_workers, self.num_replicates))
            self.workers_outer = max(1, self.num_workers // self.workers_inner)
        else:
            raise ValueError(f"Sampler {self.sampler_type} is not supported. Use 'dask' or 'multicore'.")

    def _validate_configuration(self):
        """Validate the configuration parameters.

        Raises:
            ValueError: If a required model_config key is missing, if qoi_functions,
                distance_functions or prior is empty, or if a QoI has no distance function.
        """
        required_model_keys = ['ini_path', 'struc_name']
        for key in required_model_keys:
            if key not in self.model_config:
                raise ValueError(f"Missing required model_config key: {key}")
        if not self.qoi_functions:
            raise ValueError("qoi_functions cannot be empty")
        if not self.distance_functions:
            raise ValueError("distance_functions cannot be empty")
        if not self.prior:
            raise ValueError("prior cannot be empty")
        # Validate QoI consistency
        for qoi in self.qoi_functions:
            if qoi not in self.distance_functions:
                raise ValueError(f"Distance function not defined for QoI: {qoi}")

    def setup_sampler(self, cluster_setup_func=None):
        """Setup the pyABC sampler based on configuration.

        Args:
            cluster_setup_func: Callable returning the cluster passed to the Dask ``Client``.
                Required when ``sampler_type`` is 'dask'.

        Returns:
            The pyABC sampler instance.

        Raises:
            ValueError: If the Dask sampler is selected without ``cluster_setup_func``,
                or if the sampler type is unsupported.
        """
        if self.sampler_type == "dask":
            if cluster_setup_func is None:
                raise ValueError("cluster_setup_func must be provided for cluster mode with Dask")
            cluster = cluster_setup_func()
            my_sampler = sampler.DaskDistributedSampler(Client(cluster))
            self.logger.info(f"Using Dask sampler with {self.num_workers} workers")
        elif self.sampler_type == "multicore":
            my_sampler = sampler.MulticoreParticleParallelSampler(n_procs=self.workers_outer)
            self.logger.info(f"Using nested multicore parallelization: {self.workers_outer} outer processes × {self.workers_inner} inner threads = {self.workers_outer * self.workers_inner} total")
        else:
            raise ValueError(f"Sampler {self.sampler_type} is not supported.")
        return my_sampler

    def setup_population_strategy(self):
        """Setup population size strategy.

        Returns:
            An ``AdaptivePopulationSize`` when ``population_strategy`` is 'adaptive',
            otherwise ``max_population_size`` as a fixed population size.
        """
        if self.population_strategy == "adaptive":
            return AdaptivePopulationSize(
                start_nr_particles=self.max_population_size,
                mean_cv=0.2,
                min_population_size=self.min_population_size,
                max_population_size=self.max_population_size
            )
        # Fixed population size - use max_population_size as the fixed size
        return self.max_population_size

    def setup_distance_function(self, distances_dict):
        """Setup distance function with optional adaptive weighting.

        Args:
            distances_dict (dict): Maps a name to a dict with a 'function' entry and an
                optional 'weight' entry (defaults to 1.0).

        Returns:
            The single distance function when only one QoI is configured; otherwise an
            ``AdaptiveAggregatedDistance`` or ``AggregatedDistance`` over all functions.

        Raises:
            ValueError: If previously stored adaptive weights cannot be loaded.
        """
        qois = list(self.qoi_functions.keys())
        distance_funcs = [distance_func["function"] for distance_func in distances_dict.values()]
        distance_weights = [distance_func.get("weight", 1.0) for distance_func in distances_dict.values()]

        if len(qois) > 1:
            if self.adaptive_distance:
                # If adaptive distance, try to load previous weights if database exists.
                # `_load_adaptive_weights` is not defined in this class; the hasattr guard
                # lets a subclass provide it.
                if os.path.exists(self.db_path) and hasattr(self, '_load_adaptive_weights'):
                    try:
                        distance_weights = self._load_adaptive_weights(self.db_path)
                    except Exception as e:
                        raise ValueError(f"Could not load adaptive weights: {e}") from e
                    # Starting with loaded adaptive weights.
                    distance_func = AdaptiveAggregatedDistance(distance_funcs, initial_weights=distance_weights, log_file=self.adaptive_distance_file)
                else:
                    # Start with adaptive weighting
                    distance_func = AdaptiveAggregatedDistance(distance_funcs, log_file=self.adaptive_distance_file)
                    self.logger.info("Starting with adaptive distance weighting")
            else:
                # Use provided weights or equal weights
                distance_func = AggregatedDistance(distance_funcs, weights=distance_weights)
                self.logger.info(f"Using fixed distance weights: {distance_weights}")
        else:
            distance_func = distance_funcs[0]
            self.logger.info(f"Using single distance function: {distances_dict[qois[0]]['function']}")
        return distance_func

    def setup_transition_function(self):
        """Setup transition function for ABC-SMC.

        Returns:
            ``LocalTransition()`` for the 'local' strategy, otherwise ``None``.
        """
        if self.transition_strategy == "local":
            return LocalTransition()
        return None  # Default - multivariate normal transitions

    def setup_epsilon_function(self):
        """Setup epsilon (tolerance) function for ABC-SMC.

        Raises:
            ValueError: If ``epsilon_strategy`` is not 'quantile'.
        """
        if self.epsilon_strategy == "quantile":
            return QuantileEpsilon(alpha=self.epsilon_alpha)
        raise ValueError(f"Epsilon strategy '{self.epsilon_strategy}' not supported")

    def create_model_wrapper(self, fixed_params_dict, workers_inner=None):
        """Create wrapper function for PhysiCell model evaluation.

        Returns:
            A callable taking the sampled parameters and returning the model result.
        """
        def model_wrapper(pars):
            return self._run_physicell_model(pars, fixed_params_dict, workers_inner)
        return model_wrapper

    def _run_physicell_model(self, pars, fixed_params, workers_inner=None):
        """Run PhysiCell model with given parameters.

        Replicates run in parallel threads when the sampler is 'multicore' and
        ``workers_inner`` is given; otherwise they run sequentially.

        Raises:
            ValueError: Wrapping any exception raised during model evaluation.
        """
        try:
            # Convert parameters from log scale if needed. `_convert_params_to_linear_scale`
            # is not defined in this class, so this only applies when a subclass provides it.
            if self.log_scale and hasattr(self, '_convert_params_to_linear_scale'):
                pars = self._convert_params_to_linear_scale(pars)
            # Choose parallelization strategy based on sampler
            if self.sampler_type == 'multicore' and workers_inner is not None:
                return self._run_replicates_parallel(workers_inner, pars, fixed_params)
            return self._run_physicell_model_sequential(pars, fixed_params)
        except Exception as e:
            self.logger.error(f"Error in model evaluation: {e}")
            raise ValueError(f"Error in model evaluation: {e}") from e

    def _default_aggregation_func(self, replicate_results):
        """Aggregate replicates by averaging all columns per (sampleID, time).

        Args:
            replicate_results (dict): Maps replicate id to a DataFrame containing
                'sampleID' and 'time' columns.

        Returns:
            pandas.DataFrame: Pivot table indexed by ['sampleID', 'time'].

        Raises:
            ValueError: If concatenation or pivoting fails.
        """
        try:
            results_df = pd.concat(list(replicate_results.values()), ignore_index=True)
            # Take the mean of all columns in the same sampleID and time
            return results_df.pivot_table(index=['sampleID', 'time'])
        except Exception as e:
            # dict views are not subscriptable, and building the message must never
            # raise itself and hide the real failure.
            first_result = next(iter(replicate_results.values()), None)
            try:
                sample_ids = first_result['sampleID'].unique()
            except (KeyError, TypeError, AttributeError, IndexError):
                sample_ids = "unknown"
            raise ValueError(f"Error in _default_aggregation_func for sampleID: {sample_ids}") from e

    def _run_physicell_model_sequential(self, pars, fixed_params, sample_id=None, replicate_id=None):
        """Run PhysiCell model replicates one after another.

        Args:
            pars: Mapping of sampled parameter names to values.
            fixed_params (dict): Parameter values that override/complete ``pars``.
            sample_id: Sample identifier; defaults to the worker id.
            replicate_id: When given, only this replicate is run.

        Returns:
            The aggregated result of all replicates when ``replicate_id`` is None;
            otherwise a dict ``{replicate_id: result}`` for the single replicate.

        Raises:
            ValueError: If a model parameter has no value in ``pars`` or ``fixed_params``.
            RuntimeError: If a replicate fails or returns an empty/invalid result.
        """
        # Create the PhysiCell model instance for each worker
        physicell_model = PhysiCell_Model(self.model_config['ini_path'], self.model_config['struc_name'])
        physicell_model.numReplicates = self.num_replicates
        # Configure input/output folders
        if "input_folder" in self.model_config:
            physicell_model.input_folder += self.model_config['input_folder']
        if "output_folder" in self.model_config:
            physicell_model.output_folder += self.model_config['output_folder']
        physicell_model.timeout = 600
        physicell_model.output_summary_Path = None

        # Get parameter names
        params_xml = list(physicell_model.XML_parameters_variable.values())
        params_rules = list(physicell_model.parameters_rules_variable.values())

        # Get worker ID
        if sample_id is None:
            sample_id = self._get_worker_id()

        # Prepare parameters
        dic_pars_xml = {par: (pars[par] if par in pars.keys() else None) for par in params_xml}
        dic_pars_rules = {par: (pars[par] if par in pars.keys() else None) for par in params_rules}
        # Include fixed parameters
        for par, fixed_value in fixed_params.items():
            if par in dic_pars_xml:
                dic_pars_xml[par] = fixed_value
            if par in dic_pars_rules:
                dic_pars_rules[par] = fixed_value
        # Validation
        if None in dic_pars_xml.values() or None in dic_pars_rules.values():
            raise ValueError(f"Some parameters are None: {dic_pars_xml}, {dic_pars_rules}")

        # Decide up front whether to aggregate: the loop below must not overwrite the
        # caller's `replicate_id`, otherwise the "all replicates" case is never detected.
        run_all_replicates = replicate_id is None
        replicate_ids = range(self.num_replicates) if run_all_replicates else [replicate_id]

        # Run replicates
        dic_all_replicates = {}
        for current_replicate in replicate_ids:
            try:
                _, _, result_data = run_replicate_serializable(
                    PhysiCellModel_conf=self.model_config,
                    sample_id=sample_id,
                    replicate_id=current_replicate,
                    ParametersXML=dic_pars_xml,
                    ParametersRules=dic_pars_rules,
                    qoi_functions=self.qoi_functions,
                    qoi_def=self.qoi_def,
                    return_binary_output=False,
                    custom_summary_function=self.summary_function,
                )
            except Exception as e:
                raise RuntimeError(f"Error in RunModel (SampleID: {sample_id}): {e}") from e
            # Check if RunModel returned valid data
            if not hasattr(result_data, 'columns') or len(result_data) == 0:
                raise RuntimeError(f"RunModel returned empty or invalid DataFrame for SampleID: {sample_id}, ReplicateID: {current_replicate}")
            dic_all_replicates[current_replicate] = result_data

        # All replicates done, run aggregation function
        if run_all_replicates:
            return self.aggregation_func(dic_all_replicates)
        return dic_all_replicates

    def _run_replicates_parallel(self, workers_inner, params, fixed_params):
        """Run replicates in parallel using ThreadPoolExecutor and aggregate them.

        Args:
            workers_inner (int): Maximum number of threads.
            params: Sampled parameters forwarded to each replicate run.
            fixed_params (dict): Fixed parameters forwarded to each replicate run.

        Returns:
            The result of ``aggregation_func`` over all replicate results.
        """
        sample_id = self._get_worker_id()
        with ThreadPoolExecutor(max_workers=workers_inner) as executor:
            futures = [
                executor.submit(
                    self._run_physicell_model_sequential,
                    params, fixed_params, sample_id, replicate_id
                )
                for replicate_id in range(self.num_replicates)
            ]
            dic_all_replicates = {}
            for future_done in as_completed(futures):
                # Each future yields {replicate_id: result}; merge them into one dict.
                dic_all_replicates.update(future_done.result(timeout=0.5))
        return self.aggregation_func(dic_all_replicates)

    def _get_worker_id(self):
        """Get worker ID for distributed computing.

        Returns:
            int: The Dask worker name as an integer, or the integer after the last '-'
            in the worker name, or the current process id when neither is available.
        """
        try:
            worker_name = get_worker().name
        except Exception:
            # Not running inside a Dask worker (or Dask unavailable).
            return os.getpid()
        try:
            return int(worker_name)
        except Exception:
            pass
        try:
            return int(worker_name.split("-")[-1])
        except Exception:
            return os.getpid()

    def setup_abc_smc(self, models_list, priors_list, distance_function, population_size, transitions_func, my_sampler, eps_function):
        """Setup the ABC-SMC object from the prepared components."""
        return ABCSMC(
            models=models_list,
            parameter_priors=priors_list,
            distance_function=distance_function,
            population_size=population_size,
            transitions=transitions_func,
            sampler=my_sampler,
            eps=eps_function
        )

    def load_or_create_database(self, abc_smc, abc_id=1):
        """Load existing database or create new one.

        Returns:
            tuple: ``(resumed, n_populations, total_nr_simulations)``; ``(False, 0, 0)``
            when a new database is created.

        Raises:
            ValueError: Re-raised when loading an existing database fails.
        """
        db_file = "sqlite:///" + os.fspath(self.db_path)
        if os.path.exists(self.db_path):
            try:
                abc_smc.load(db_file, abc_id=abc_id)
                self.logger.info(f"Loaded existing database: {self.db_path}")
                return True, abc_smc.history.n_populations, abc_smc.history.total_nr_simulations
            except ValueError as e:
                self.logger.error(f"Error loading database {db_file}: {e}")
                raise
        else:
            abc_smc.new(db_file, observed_sum_stat=self.dic_obsData)
            self.logger.info(f"Created new database: {self.db_path}")
            return False, 0, 0

    def run_calibration(self, abc_smc, resume_db=False, current_populations=0, current_simulations=0):
        """Run the ABC-SMC calibration and store metadata in the database.

        Args:
            abc_smc: The configured ABCSMC object.
            resume_db (bool): Whether an existing database was loaded.
            current_populations (int): Populations already stored when resuming.
            current_simulations (int): Simulations already stored when resuming.
        """
        if resume_db:
            extra_populations = max(0, self.max_populations - current_populations)
            extra_simulations = max(0, self.max_simulations - current_simulations)
            self.logger.info(f"Resuming: extra populations: {extra_populations}, extra simulations: {extra_simulations}")
            if extra_populations > 0 and extra_simulations > 0:
                # ABCSMC.run counts max_nr_populations from the resume point, so only the
                # remaining populations are requested here.
                # NOTE(review): max_total_nr_simulations is kept as the overall budget, assuming
                # pyABC compares it against the history total - confirm.
                abc_smc.run(max_nr_populations=extra_populations, max_total_nr_simulations=self.max_simulations)
            else:
                self.logger.info("No additional calibration needed")
        else:
            self.logger.info(f"Starting calibration: max populations: {self.max_populations}, max simulations: {self.max_simulations}")
            abc_smc.run(max_nr_populations=self.max_populations, max_total_nr_simulations=self.max_simulations)
        # Add metadata to database
        insert_metadata_db(self.db_path, self)
        # Add extra info of adaptive distance to database
        if self.adaptive_distance:
            insert_adaptive_weights_db(self.db_path, dict_distances=self.distance_functions, dict_adaptive_weights=load_dict_from_json(self.adaptive_distance_file))

    def check_convergence(self, abc_smc):
        """Check convergence criteria (placeholder: always returns False)."""
        # This would need the check_convergence_generic function
        # For now, return False to continue until max iterations
        return False

    def include_additional_metadata(self, **metadata):
        """Include additional metadata in the database (placeholder: only logs it)."""
        # This would need the include_additional_data_in_db function
        # For now, just log the metadata
        self.logger.info(f"Additional metadata: {metadata}")
[docs]
def run_abc_calibration(calib_context: CalibrationContext) -> History:
    """
    Execute the complete ABC-SMC calibration process.

    This function orchestrates the entire ABC-SMC workflow using the CalibrationContext,
    including sampler setup, distance function configuration, model wrapper creation,
    and calibration execution with convergence checking.

    Args:
        calib_context (CalibrationContext): The calibration context containing all configuration

    Returns:
        History: The pyABC History object containing calibration results

    Raises:
        Exception: Any error raised during setup or calibration is logged and re-raised.
    """
    logger = calib_context.logger
    try:
        logger.info("🚀 Starting ABC-SMC calibration process")
        logger.info(f"📊 Database: {calib_context.db_path}")
        logger.info(f"🎯 Max populations: {calib_context.max_populations}")
        logger.info(f"🔬 Max simulations: {calib_context.max_simulations}")

        # Setup sampler
        logger.info("⚙️ Setting up sampler...")
        my_sampler = calib_context.setup_sampler(calib_context.cluster_setup_func)

        # Setup population strategy
        logger.info("👥 Setting up population strategy...")
        population_size = calib_context.setup_population_strategy()

        # Setup distance function
        logger.info("📏 Setting up distance function...")
        distance_function = calib_context.setup_distance_function(calib_context.distance_functions)

        # Setup transition function
        logger.info("🔄 Setting up transition function...")
        transitions_func = calib_context.setup_transition_function()

        # Setup epsilon function
        logger.info("🎯 Setting up epsilon function...")
        eps_function = calib_context.setup_epsilon_function()

        # Setup model wrappers
        logger.info("🧬 Setting up model wrappers...")
        # Handle multiple models if specified
        if calib_context.model_selection and calib_context.num_models > 1:
            # This would require additional logic for multiple models
            # For now, use single model approach
            logger.info("⚠️ Multiple model selection not fully implemented, using single model")
        # Create model wrapper
        model_wrapper = calib_context.create_model_wrapper(
            calib_context.fixed_params,
            calib_context.workers_inner
        )
        models_list = [model_wrapper]
        priors_list = [calib_context.prior]

        # Setup ABC-SMC object
        logger.info("🔧 Setting up ABC-SMC object...")
        abc_smc = calib_context.setup_abc_smc(
            models_list=models_list,
            priors_list=priors_list,
            distance_function=distance_function,
            population_size=population_size,
            transitions_func=transitions_func,
            my_sampler=my_sampler,
            eps_function=eps_function
        )

        # Load or create database
        logger.info("💾 Managing database...")
        resume_db, current_populations, current_simulations = calib_context.load_or_create_database(abc_smc)

        # Run calibration
        logger.info("🎲 Starting calibration run...")
        calib_context.run_calibration(abc_smc, resume_db, current_populations, current_simulations)

        # Check convergence and run additional populations if needed.
        # The context does not define a `mode` attribute, so reading it directly raised
        # AttributeError. An explicitly set `mode` is still honoured; otherwise the Dask
        # sampler is treated as cluster mode.
        # NOTE(review): confirm that "cluster mode" is meant to be the Dask sampler.
        run_mode = getattr(calib_context, "mode", None)
        in_cluster_mode = run_mode == 'cluster' or (
            run_mode is None and calib_context.sampler_type == "dask"
        )
        if calib_context.convergence_check_func is not None and in_cluster_mode:
            logger.info("🔍 Checking convergence...")
            while True:
                if calib_context.convergence_check_func(abc_smc.history):
                    logger.info("✅ Convergence achieved!")
                    break
                # If max limits reached, extend by one to run one more population
                if abc_smc.history.total_nr_simulations >= calib_context.max_simulations:
                    calib_context.max_simulations = abc_smc.history.total_nr_simulations + 1
                if abc_smc.history.n_populations >= calib_context.max_populations:
                    calib_context.max_populations = abc_smc.history.n_populations + 1
                logger.info(f"🔄 Continuing calibration: {calib_context.max_simulations} simulations, {calib_context.max_populations} populations")
                # Run one more iteration
                abc_smc.run(
                    max_nr_populations=calib_context.max_populations,
                    max_total_nr_simulations=calib_context.max_simulations
                )
                # Add extra info of adaptive distance to database
                if calib_context.adaptive_distance:
                    insert_adaptive_weights_db(calib_context.db_path, dict_distances=calib_context.distance_functions, dict_adaptive_weights=load_dict_from_json(calib_context.adaptive_distance_file))

        # Remove temporary file and folders
        physicell_model = PhysiCell_Model(calib_context.model_config['ini_path'], calib_context.model_config['struc_name'])
        physicell_model.remove_io_folders()

        # Final results
        final_history = abc_smc.history
        logger.info("✅ ABC-SMC calibration completed successfully!")
        logger.info(f"📈 Final statistics:")
        logger.info(f" - Populations: {final_history.n_populations}")
        logger.info(f" - Total simulations: {final_history.total_nr_simulations}")
        logger.info(f" - Database: {calib_context.db_path}")
        return final_history
    except Exception as e:
        logger.error(f"❌ Error in ABC-SMC calibration: {e}")
        raise