Source code for uq_physicell.model_analysis.ma_context

import numpy as np
import os, sys, io
import logging
import pickle
import traceback
import signal
import threading
from contextlib import redirect_stdout

# Check if the required libraries are available - futures and mpi4py
try:
    from concurrent.futures import ProcessPoolExecutor
    # Collect results and write to the database
    from concurrent.futures import TimeoutError, as_completed
    futures_available = True
except ImportError:
    futures_available = False
try:
    from mpi4py import MPI
    mpi_available = True
except ImportError:
    mpi_available = False

# My local modules
from uq_physicell import PhysiCell_Model
from .samplers import run_local_sampler, run_global_sampler
from ..utils.model_wrapper import run_replicate, run_replicate_serializable
from ..utils.sumstats import _convert_qoi_function_to_string, recreate_qoi_functions
from ..database.ma_db import create_structure, insert_metadata, insert_param_space, insert_qois, insert_samples, insert_output, check_simulations_db, _disable_wal_mode


[docs] class ModelAnalysisContext: """Context manager for running PhysiCell model analysis simulations. This class manages the configuration, database setup, and execution context for running sensitivity analysis and uncertainty quantification simulations on PhysiCell models. Args: db_path (str): Path to the SQLite database file for storing results. model_config (dict): Dictionary containing PhysiCell model configuration. Must include 'ini_path' and 'struc_name' keys. sampler (str): Name of the sampling method to use (e.g., 'LHS', 'Sobol', 'OAT'). params_info (dict): Dictionary containing parameter definitions with keys for each parameter name and values containing 'ref_value', 'lower_bound', 'upper_bound', and 'perturbation' information. qois_info (dict): Dictionary containing Quantities of Interest definitions. qoi_def (dict): first-class object, that can be used in qoi_functions lambda string, mapped to their name. parallel_method (str, optional): Parallelization method. Options are: 'inter-process' (single node), 'inter-node' (MPI), or 'serial'. Defaults to 'inter-process'. num_workers (int, optional): Number of parallel workers for inter-process execution. Defaults to 1. summary_function (callable, optional): Custom function for summarizing simulation output. Defaults to None. Raises: ImportError: If required parallelization libraries are not available. ValueError: If invalid parallel_method is specified. """ def __init__( self, db_path:str, model_config:dict, sampler:str, params_info:dict, qois_info:dict, qoi_def:dict={}, parallel_method:str='inter-process', num_workers:int=1, summary_function=None, logger: logging.Logger=None): self.db_path = db_path self.params_dict = params_info # Accept model_config as (ini_path, key) tuple or {'ini_path': ..., 'struc_name': ...} dict if isinstance(model_config, (tuple, list)): model_config = {'ini_path': model_config[0], 'struc_name': model_config[1]} # Check free variables before converting to strings — Python's eval() is lazy # and only detects missing names at call time, not at lambda-creation time. # co_freevars reveals closure references that will fail in worker processes. for qoi_name, func in qois_info.items(): if callable(func) and hasattr(func, '__code__'): unresolved = set(func.__code__.co_freevars) - set(qoi_def.keys()) if unresolved: raise ValueError( f"QoI '{qoi_name}': lambda closes over {sorted(unresolved)} which " f"cannot be serialized for multiprocessing. Pass via " f"qoi_def={{name: object}} in ModelAnalysisContext." ) # QoI functions are stored as source strings so they can be pickled across processes self.qois_dict = {key: _convert_qoi_function_to_string(value, key) if not isinstance(value, str) else value for key, value in qois_info.items()} self.qoi_def = qoi_def # Secondary check: verify string-form QoIs can be eval'd in the restricted namespace if self.qois_dict: try: recreate_qoi_functions(self.qois_dict, self.qoi_def) except Exception as e: raise ValueError( f"A QoI function cannot be serialized for multiprocessing. " f"If it closes over an external variable or function, pass it via " f"qoi_def={{name: object}} in ModelAnalysisContext. Details: {e}" ) from None self.parallel_method = parallel_method self.num_workers = num_workers self.summary_function = summary_function self.logger = logger if logger is not None else logging.getLogger(__name__) # Initialize cancellation flag and process tracking self._cancellation_requested = False self.futures = [] self.model = None # Will be set in run_simulations # Initialize metadata for database self.dic_metadata = { 'Sampler': sampler, 'IniFilePath': model_config['ini_path'], 'StrucName': model_config['struc_name'] } # Validation of the selected parallelization method if self.parallel_method == 'inter-node': if not mpi_available: raise ImportError("mpi4py is not available. Please install mpi4py or set parallel_method='inter-process'.") elif self.parallel_method == 'inter-process': if not futures_available: raise ImportError("concurrent.futures is not available. Please install futures or set parallel_method='inter-node'.") elif self.parallel_method == 'serial': self.num_workers = 1 else: raise ValueError("Invalid parallel_method. Use 'inter-node' for MPI, 'inter-process' for futures, or 'serial' for single process.") _GLOBAL_SAMPLERS = frozenset({ 'Sobol', 'Latin hypercube sampling (LHS)', 'Fast', 'Fractional Factorial', 'Finite Difference', }) def _validate_params_for_sampler(self): """Check that params_info contains the fields required by the chosen sampler and warn about fields that will be silently ignored.""" sampler = self.dic_metadata['Sampler'] if sampler in self._GLOBAL_SAMPLERS: for param, props in self.params_dict.items(): missing = [f for f in ('lower_bound', 'upper_bound') if props.get(f) is None] if missing: raise ValueError( f"Parameter '{param}' is missing {missing} required by " f"sampler='{sampler}'. Global samplers use only " f"'lower_bound' and 'upper_bound' for sampling." ) if props.get('perturbation') is not None: self.logger.warning( f"Parameter '{param}': 'perturbation' is not used by " f"sampler='{sampler}' and will be ignored — only " f"'lower_bound' and 'upper_bound' affect global sampling." ) elif sampler == 'OAT': for param, props in self.params_dict.items(): if props.get('ref_value') is None: raise ValueError( f"Parameter '{param}' is missing 'ref_value' required by OAT sampling." ) is_bool = props.get('type') == 'bool' if not is_bool and props.get('perturbation') is None: raise ValueError( f"Parameter '{param}' is missing 'perturbation' required by OAT sampling. " f"Provide a list of percentage values, e.g. [1.0, 5.0, 10.0]." ) for unused in ('lower_bound', 'upper_bound'): if props.get(unused) is not None: self.logger.warning( f"Parameter '{param}': '{unused}' is not used by OAT sampling " f"(only 'ref_value' and 'perturbation' are used) and will be ignored." )
[docs] def generate_samples(self, N: int = None, M: int = 4, seed: int = 42): self._validate_params_for_sampler() if (self.dic_metadata['Sampler'] == 'OAT'): self.dic_samples = run_local_sampler(self.params_dict, self.dic_metadata['Sampler']) elif (self.dic_metadata['Sampler'] != 'User-defined'): self.dic_samples = run_global_sampler(self.params_dict, self.dic_metadata['Sampler'], N, M, seed)
[docs] def set_samples(self, samples): """Set user-defined parameter combinations for the 'User-defined' sampler. Args: samples: dict mapping integer sample IDs to parameter dicts, e.g. {0: {'param1': 1.0, 'param2': 2.0}, 1: {'param1': 1.5, 'param2': 3.0}} or a list of parameter dicts (IDs are assigned automatically starting from 0), e.g. [{'param1': 1.0}, {'param1': 1.5}] """ if isinstance(samples, list): self.dic_samples = {i: s for i, s in enumerate(samples)} else: self.dic_samples = samples
[docs] def run(self): """Run simulations — convenience alias for ``run_simulations(context)``. Allows the fluent pattern:: context.generate_samples(N=8) context.run() """ run_simulations(self)
[docs] def cancelled(self): """Check if cancellation has been requested. Returns: bool: True if cancellation was requested, False otherwise """ return self._cancellation_requested
[docs] def request_cancellation(self): """Request cancellation of all simulations. This sets the internal cancellation flag to True, which will be checked by the simulation process at various points. Returns: bool: Always returns True """ self.logger.info("Cancellation requested") self._cancellation_requested = True # If we have a model instance and it has active processes, terminate them if hasattr(self, 'model') and self.model is not None: if hasattr(self.model, 'terminate_all_simulations'): self.logger.info("Terminating all active simulations...") results = self.model.terminate_all_simulations() for process_id, return_code in results.items(): self.logger.info(f"Process {process_id} terminated with return code {return_code}") # Cancel futures if they exist if self.parallel_method == 'inter-process' and hasattr(self, 'futures'): for future in self.futures: if not future.done() and not future.cancelled(): future.cancel() self.logger.info(f"Cancelled future {future}") return True
[docs] def run_simulations(context: ModelAnalysisContext): """Run PhysiCell simulations based on the provided analysis context. This function executes sensitivity analysis simulations using the specified parallelization method (serial, inter-process, or MPI). It manages database initialization, parameter sampling, simulation execution, and result storage. Args: context (ModelAnalysisContext): The analysis context containing model configuration, sampling parameters, parallelization settings, and database information. Raises: ValueError: If there are issues with PhysiCell model initialization, database operations, or simulation execution. ImportError: If required parallelization libraries are missing. Note: This function handles three execution modes: - Serial: Single-threaded execution for small analyses - Inter-process: Multi-processing on a single node using concurrent.futures - Inter-node: Distributed execution across multiple nodes using MPI """ # Only set up signal handlers if we're in the main thread of the main interpreter if threading.current_thread() is threading.main_thread(): def signal_handler(sig, frame): context.logger.info(f"Received signal {sig}, initiating graceful shutdown") context.request_cancellation() # If it's a keyboard interrupt and we're in the main process, exit if sig == signal.SIGINT and context.parallel_method != 'inter-process': sys.exit(0) # Register signal handlers if context.parallel_method != 'inter-node': # Don't override MPI's own signal handling signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) # Initialize the parallelization method if context.parallel_method == 'inter-node': use_mpi = True use_futures = False comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() elif context.parallel_method == 'inter-process': use_mpi = False use_futures = True rank = 0 else: #context.parallel_method == 'serial': use_mpi = False use_futures = False rank = 0 # Initialize the PhysiCell model - in all ranks to avoid issues with MPI try: PhysiCellModel = PhysiCell_Model(context.dic_metadata['IniFilePath'], context.dic_metadata['StrucName']) # Store the model in the context for cancellation support context.model = PhysiCellModel except Exception as e: context.logger.error(f"Error initializing PhysiCell model: {e}") raise # Initialize or load the database structure if rank == 0: # Capture PhysiCell model info output and log it info_buffer = io.StringIO() with redirect_stdout(info_buffer): PhysiCellModel.info() info_output = info_buffer.getvalue().strip() if info_output: context.logger.info(f"PhysiCell Model Information:\n{info_output}") # Check if the db file already exists try: exist_db, All_Parameters, All_Samples, All_Replicates = check_simulations_db(PhysiCellModel, context.dic_metadata['Sampler'], context.params_dict, context.dic_samples, context.qois_dict, context.db_path) except Exception as e: context.logger.error(f"Error checking existing database {context.db_path}: {e}") raise # Remove the output folder - to avoid overwriting if os.path.exists(PhysiCellModel.output_folder): os.system('rm -rf ' + PhysiCellModel.output_folder) if not exist_db: # Initialize database structure context.logger.info(f"Creating database structure in {context.db_path}") try: create_structure(context.db_path) except Exception as e: context.logger.error(f"Error creating database structure: {e}") raise # Insert metadata context.logger.info(f"Inserting metadata, parameter space, and QoIs into the database") try: insert_metadata(context.db_path, context.dic_metadata['Sampler'], context.dic_metadata['IniFilePath'], context.dic_metadata['StrucName']) insert_param_space(context.db_path, context.params_dict) insert_qois(context.db_path, context.qois_dict) except Exception as e: # Print traceback for debugging traceback.print_exc() context.logger.error(f"Error inserting data into the database: {e}") raise # Populate Samples table context.logger.info(f"Inserting samples into the database") insert_samples(context.db_path, context.dic_samples) else: exist_db = None All_Samples = None All_Replicates = None All_Parameters = None if use_mpi: exist_db = comm.bcast(exist_db, root=0) All_Samples = comm.bcast(All_Samples, root=0) All_Replicates = comm.bcast(All_Replicates, root=0) All_Parameters = comm.bcast(All_Parameters, root=0) # Number of parameters expected in the XML and rules params_xml = [param_name for param_name in PhysiCellModel.XML_parameters_variable.values()] params_rules = [param_name for param_name in PhysiCellModel.parameters_rules_variable.values()] # Generate a three list with size NumSimulations if not exist_db: if rank == 0: context.logger.info(f"Generating {len(context.dic_samples)*PhysiCellModel.numReplicates} simulations") for sampleID in context.dic_samples.keys(): for replicateID in np.arange(PhysiCellModel.numReplicates): All_Parameters.append(context.dic_samples[sampleID]) All_Samples.append(sampleID) All_Replicates.append(replicateID) else: # Three lists with size NumSimulations from check_existing_sa if rank == 0: context.logger.info(f"Generating {len(All_Samples)} simulations") ################################### # Running using concurrent.futures ################################### if use_futures: # Use concurrent.futures for parallel execution with ProcessPoolExecutor(max_workers=context.num_workers) as executor: context.futures = [] # Store futures in context for cancellation support for ind_sim in range(len(All_Samples)): if context.cancelled(): context.logger.info("Simulation cancelled before submitting all jobs.") return ParametersXML = {key: All_Parameters[ind_sim][key] for key in params_xml} if params_xml else np.array([]) ParametersRules = {key: All_Parameters[ind_sim][key] for key in params_rules} if params_rules else np.array([]) model_config = { 'ini_path': context.dic_metadata['IniFilePath'], 'struc_name': context.dic_metadata['StrucName'], } # Submit the job to the executor context.futures.append(executor.submit( run_replicate_serializable, PhysiCellModel_conf=model_config, sample_id=All_Samples[ind_sim], replicate_id=All_Replicates[ind_sim], ParametersXML=ParametersXML, ParametersRules=ParametersRules, qoi_functions=context.qois_dict, qoi_def=context.qoi_def, return_binary_output=True, #drop_columns, custom_summary_function=context.summary_function, )) # Use as_completed with a short timeout to avoid blocking when cancelled remaining_futures = list(context.futures) while remaining_futures and not context.cancelled(): try: # Use a short timeout to check cancellation frequently for future_done in as_completed(remaining_futures, timeout=0.5): remaining_futures.remove(future_done) if context.cancelled(): context.logger.info("Simulation cancelled during result collection.") break if future_done.cancelled(): context.logger.info("Future was cancelled, skipping result collection.") continue try: sample_id, replicate_id, result_data = future_done.result(timeout=0.5) context.logger.info(f"Writing to the database for Sample: {sample_id}, Replicate: {replicate_id}, Result size: {sys.getsizeof(result_data)/1024:.2f} KB") try: insert_output(context.db_path, sample_id, replicate_id, result_data) except Exception as e: context.logger.error(f"Error writing to the database: {e}") except TimeoutError: # Future is not done yet, will be picked up in the next iteration remaining_futures.append(future_done) context.logger.debug("Future not yet complete, will check again.") except Exception as e: context.logger.error(f"Error retrieving future result: {e}") except TimeoutError: # No futures completed within timeout, check cancellation and continue if context.cancelled(): context.logger.info("Simulation cancelled while waiting for futures to complete.") break # If cancellation was requested, exit the loop if context.cancelled(): context.logger.info("Breaking out of future collection loop due to cancellation.") break # If we cancelled, make sure all futures are cancelled if context.cancelled(): context.logger.info("Cancelling any remaining futures.") for future in remaining_futures: if not future.done() and not future.cancelled(): future.cancel() return ################################### # Running using MPI ################################### elif use_mpi: # Split simulations into ranks SplitIndexes = np.array_split(np.arange(len(All_Samples)), size, axis=0) context.logger.info(f"Rank {rank} assigned {len(SplitIndexes[rank])} simulations: indices {SplitIndexes[rank].tolist() if len(SplitIndexes[rank]) > 0 else []}") # Run simulations (MPI) for ind_sim in SplitIndexes[rank]: if context.cancelled(): context.logger.info(f"Rank {rank}: Simulation cancelled.") break ParametersXML = {key: All_Parameters[ind_sim][key] for key in params_xml} if params_xml else np.array([]) ParametersRules = {key: All_Parameters[ind_sim][key] for key in params_rules} if params_rules else np.array([]) try: if context.summary_function: result_data_nonserialized = PhysiCellModel.RunModel( All_Samples[ind_sim], All_Replicates[ind_sim], ParametersXML, ParametersRules, RemoveConfigFile=True, SummaryFunction=context.summary_function) result_data = pickle.dumps(result_data_nonserialized) else: _, _, result_data = run_replicate( PhysiCellModel=PhysiCellModel, sample_id=All_Samples[ind_sim], replicate_id=All_Replicates[ind_sim], ParametersXML=ParametersXML, ParametersRules=ParametersRules, qoi_functions=context.qois_dict, qoi_def=context.qoi_def, #return_binary_output, #drop_columns, #custom_summary_function, ) except Exception as e: context.logger.error( f"Rank {rank}: Error running simulation for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}: {e}" ) continue # Write to database with retry logic context.logger.info(f"Rank {rank} writing to database for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}") try: insert_output(context.db_path, All_Samples[ind_sim], All_Replicates[ind_sim], result_data) context.logger.info(f"Rank {rank} successfully wrote results for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}") except Exception as e: context.logger.error(f"Rank {rank} failed to write results for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}: {e}") raise comm.Barrier() MPI.Finalize() ################################### # Running sequentially ################################### else: context.logger.info(f"Rank {rank} assigned {len(All_Samples)} simulations.") # Run simulations sequentially for ind_sim in range(len(All_Samples)): if context.cancelled(): context.logger.info("Sequential simulation cancelled.") break ParametersXML = {key: All_Parameters[ind_sim][key] for key in params_xml} if params_xml else np.array([]) ParametersRules = {key: All_Parameters[ind_sim][key] for key in params_rules} if params_rules else np.array([]) if context.summary_function: result_data_nonserialized = PhysiCellModel.RunModel( All_Samples[ind_sim], All_Replicates[ind_sim], ParametersXML, ParametersRules, RemoveConfigFile=True, SummaryFunction=context.summary_function) result_data = pickle.dumps(result_data_nonserialized) else: _, _, result_data = run_replicate( PhysiCell_Model=PhysiCellModel, sample_id=All_Samples[ind_sim], replicate_id=All_Replicates[ind_sim], ParametersXML=ParametersXML, ParametersRules=ParametersRules, qoi_functions=context.qois_dict, qoi_def=context.qoi_def, #return_binary_output, #drop_columns, #custom_summary_function, ) # Write to the database directly (no locks or MPI synchronization needed) context.logger.info(f"Rank {rank} writing to the database for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}") try: insert_output(context.db_path, All_Samples[ind_sim], All_Replicates[ind_sim], pickle.dumps(result_data)) context.logger.info(f"Rank {rank} finished writing to the database for Sample: {All_Samples[ind_sim]}, Replicate: {All_Replicates[ind_sim]}") except Exception as e: context.logger.error(f"Error inserting output into the database: {e}") raise if rank == 0: print(f"Simulations completed and results stored in the database: {context.db_path}.") # Disable WAL mode of the database to allow reading results without locks _disable_wal_mode(context.db_path)
if __name__ == "__main__": pass