# Source code for uq_physicell.bo.plots

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# My local modules
from .utils import normalize_params_df, get_observed_qoi
from ..database.bo_db import load_structure

def plot_parameter_space(df_samples: pd.DataFrame, df_param_space: pd.DataFrame, params: dict = None, real_value: dict = None, axis=None):
    """
    Plot the parameter space from the samples DataFrame.

    Each sample is drawn at its normalized parameter value (x axis) for every
    parameter name (y axis), coloured by SampleID. Optional reference points
    (real values and labelled parameter sets) are overlaid with distinct
    markers and listed in a separate "Ref. Points" legend.

    Args:
        df_samples: DataFrame containing the samples. Expected in the same long
            layout as the reference frames built below: columns 'ParamName',
            'ParamValue' and 'SampleID'.
        df_param_space: DataFrame defining the search space for each parameter;
            forwarded to ``normalize_params_df``.
        params: Dictionary mapping a legend label to a dictionary
            ``{parameter name: value}`` describing a parameter set to highlight
            (e.g. the best parameters found) (optional).
        real_value: Dictionary ``{parameter name: value}`` with the real
            parameter values to plot (optional).
        axis: Matplotlib axis to plot on (optional).

    Returns:
        ``(fig, ax)`` if ``axis`` is None; otherwise None (the plot is drawn
        in the given axis).
    """
    # Normalize the parameter values with respect to the search space.
    df_plot = normalize_params_df(df_samples, df_param_space)

    if axis is None:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        ax = axis

    sns.scatterplot(df_plot, y='ParamName', x='ParamValue', hue='SampleID', legend=True, ax=ax)
    # Capture the SampleID legend entries now, before the labelled reference
    # points below are added to the axis.
    sample_handles, sample_labels = ax.get_legend_handles_labels()

    # Reference points get their own legend.
    special_handles = []
    special_labels = []

    if real_value:
        real_value_df = pd.DataFrame(list(real_value.items()), columns=['ParamName', 'ParamValue'])
        real_value_df['SampleID'] = 0  # placeholder SampleID so the frame matches the samples layout
        real_value_norm = normalize_params_df(real_value_df, df_param_space)
        real_scatter = ax.scatter(real_value_norm['ParamValue'], real_value_norm['ParamName'],
                                  color='blue', label='Real Value', marker='*', s=100, zorder=5)
        special_handles.append(real_scatter)
        special_labels.append('Real Value')

    if params:
        for label, param_values in params.items():
            param_df = pd.DataFrame(list(param_values.items()), columns=['ParamName', 'ParamValue'])
            param_df['SampleID'] = label  # the legend label doubles as SampleID
            param_norm = normalize_params_df(param_df, df_param_space)
            param_scatter = ax.scatter(param_norm['ParamValue'], param_norm['ParamName'],
                                       marker='x', s=100, zorder=5)
            special_handles.append(param_scatter)
            special_labels.append(label)

    # Build the legends from whatever is available.
    if sample_handles and special_handles:
        # ax.legend() replaces the current legend, so the first one is
        # re-attached as a plain artist to keep both visible.
        sample_legend = ax.legend(sample_handles, sample_labels, loc='upper right', bbox_to_anchor=(1.2, 1),
                                  fontsize='small', title='Sample ID')
        ax.legend(special_handles, special_labels, loc='upper right', bbox_to_anchor=(1.2, 0.6),
                  fontsize='small', title='Ref. Points')
        ax.add_artist(sample_legend)
    elif sample_handles:
        ax.legend(sample_handles, sample_labels, loc='upper right', bbox_to_anchor=(1.2, 1),
                  fontsize='small', title='Sample ID')
    elif special_handles:
        ax.legend(special_handles, special_labels, loc='upper right', bbox_to_anchor=(1.2, 1),
                  fontsize='small', title='Ref. Points')

    ax.set_title('Parameter Space')
    ax.set_xlabel('Normalized Parameter Value')
    ax.set_ylabel('')

    if axis is None:
        plt.tight_layout()
        return fig, ax
    return None
def plot_parameter_space_db(db_file: str, params: dict = None, real_value: dict = None, axis=None):
    """
    Plot the parameter space stored in a database file.

    Loads the samples and the parameter search space with ``load_structure``
    and delegates the drawing to ``plot_parameter_space``.

    Args:
        db_file: Path to the database file.
        params: Dictionary mapping a legend label to a ``{parameter name: value}``
            dictionary to highlight (optional); forwarded unchanged.
        real_value: Dictionary ``{parameter name: value}`` with real parameter
            values to plot (optional); forwarded unchanged.
        axis: Matplotlib axis to plot on (optional).

    Returns:
        The return value of ``plot_parameter_space`` (a figure and axis when
        ``axis`` is None).
    """
    # Only the parameter-space and samples tables are needed here.
    _, df_param_space, _, _, df_samples, _ = load_structure(db_file)
    return plot_parameter_space(
        df_samples,
        df_param_space,
        params=params,
        real_value=real_value,
        axis=axis,
    )
def plot_parameter_vs_fitness(df_samples: pd.DataFrame, df_output: pd.DataFrame, parameter_name: str, qoi_name: str, samples_id=None, axis=None):
    """
    Plot the values of one parameter against the fitness of one QoI.

    Args:
        df_samples: DataFrame containing the samples, with the columns
            'SampleID', 'ParamName' and 'ParamValue'.
        df_output: DataFrame containing the output of the analysis, with the
            columns 'SampleID' and 'ObjFunc' (per row, a mapping keyed by QoI
            name). Every sample of ``parameter_name`` must have an output row.
        parameter_name: Name of the parameter to plot.
        qoi_name: Name of the QoI to plot against the parameter.
        samples_id: List of sample IDs to highlight (optional). Highlighted
            samples are coloured by SampleID; all others are drawn in gray.
        axis: Matplotlib axis to plot on (optional).

    Returns:
        ``(fig, ax)`` if ``axis`` is None; otherwise None (the plot is drawn
        in the given axis).
    """
    # Rows of the requested parameter, sorted by parameter value.
    df_sorted_params = (
        df_samples[df_samples['ParamName'] == parameter_name]
        .sort_values(by='ParamValue')
        .reset_index(drop=True)
    )

    # Output rows in the same order as df_sorted_params (one per SampleID).
    df_sorted_fitness = df_output.set_index('SampleID').loc[df_sorted_params['SampleID']].reset_index()
    # Replace each per-sample mapping by the fitness of the requested QoI
    # (single pass instead of one boolean mask per sample).
    df_sorted_fitness['ObjFunc'] = [obj_func[qoi_name] for obj_func in df_sorted_fitness['ObjFunc']]

    if axis is None:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.set_title(f"Parameter vs Objective: {parameter_name} vs {qoi_name}")
    else:
        ax = axis

    if samples_id:
        fitness_cols = df_sorted_fitness[['SampleID', 'ObjFunc']]
        is_selected = df_sorted_params['SampleID'].isin(samples_id)

        # Non-selected samples in gray, underneath the highlighted ones.
        df_non_selected = df_sorted_params[~is_selected].merge(fitness_cols, on='SampleID', how='left')
        ax.scatter(df_non_selected['ParamValue'], df_non_selected['ObjFunc'], marker='o', c='gray', zorder=1)

        # Selected samples, one colour per SampleID.
        df_selected = df_sorted_params[is_selected].merge(fitness_cols, on='SampleID', how='left')
        sns.scatterplot(df_selected, x='ParamValue', y='ObjFunc', hue='SampleID', ax=ax,
                        marker='X', zorder=2, palette="deep", s=100)
    else:
        # Both frames share the same row order, so values pair up positionally.
        ax.scatter(df_sorted_params['ParamValue'], df_sorted_fitness['ObjFunc'], marker='o', c='gray', zorder=1)

    ax.set_xlabel(parameter_name)
    ax.set_ylabel(f"Fitness({qoi_name})")

    if axis is None:
        plt.tight_layout()
        return fig, ax
    return None
def plot_pareto_front(df_output: pd.DataFrame, qoi_name1: str, qoi_name2: str, samples_id=None, axis=None, plot_std=False):
    """
    Plot the fitness of two QoIs against each other, highlighting the Pareto front.

    Args:
        df_output: DataFrame containing the output of the analysis, with the
            columns 'SampleID' and 'ObjFunc' (per row, a mapping keyed by QoI
            name) and, when ``plot_std`` is True, 'Noise_Std' (same layout).
            The DataFrame is not modified.
        qoi_name1: Name of the QoI to plot on the x axis.
        qoi_name2: Name of the QoI to plot on the y axis.
        samples_id: List of sample IDs to highlight as non-dominated (optional).
            When None or empty, every sample is drawn as dominated (gray).
        axis: Matplotlib axis to plot on (optional).
        plot_std: If True, draw error bars with the noise standard deviation
            of each highlighted sample.

    Returns:
        ``(fig, ax)`` if ``axis`` is None; otherwise None (the plot is drawn
        in the given axis).
    """
    if axis is None:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.set_title(f"Pareto Front: {qoi_name1} vs {qoi_name2}")
    else:
        ax = axis

    # Work on a copy so the caller's DataFrame does not gain helper columns.
    df_plot = df_output.copy()
    obj_funcs = df_plot['ObjFunc']
    df_plot[qoi_name1] = [obj[qoi_name1] for obj in obj_funcs]
    df_plot[qoi_name2] = [obj[qoi_name2] for obj in obj_funcs]
    if plot_std:
        noise_stds = df_plot['Noise_Std']
        df_plot[f"{qoi_name1}_std"] = [std[qoi_name1] for std in noise_stds]
        df_plot[f"{qoi_name2}_std"] = [std[qoi_name2] for std in noise_stds]

    # isin() rejects None, so treat "no selection" as an empty selection.
    selected = [] if samples_id is None else samples_id
    is_nondominated = df_plot['SampleID'].isin(selected)
    df_nondominated = df_plot[is_nondominated]
    df_dominated = df_plot[~is_nondominated]

    # Fix the hue order explicitly (sorted for numeric IDs, unique order
    # otherwise) and reuse it for the error bars, so each error bar gets
    # exactly the colour of its point.
    hue_order = list(df_nondominated['SampleID'].drop_duplicates())
    if pd.api.types.is_numeric_dtype(df_nondominated['SampleID']):
        hue_order = sorted(hue_order)
    palette = sns.color_palette("deep", n_colors=len(hue_order))
    color_map = dict(zip(hue_order, palette))

    if not df_nondominated.empty:
        sns.scatterplot(df_nondominated, x=qoi_name1, y=qoi_name2, hue='SampleID', hue_order=hue_order,
                        palette=palette, ax=ax, marker='X', zorder=2, s=100)
    sns.scatterplot(df_dominated, x=qoi_name1, y=qoi_name2, ax=ax, marker='o', zorder=1,
                    color='gray', s=100, label='Dominated', legend=False)

    if plot_std:
        for sample_id, x_val, y_val, x_err, y_err in zip(
            df_nondominated['SampleID'],
            df_nondominated[qoi_name1],
            df_nondominated[qoi_name2],
            df_nondominated[f"{qoi_name1}_std"],
            df_nondominated[f"{qoi_name2}_std"],
        ):
            ax.errorbar(x_val, y_val, xerr=x_err, yerr=y_err, fmt='none',
                        ecolor=color_map[sample_id], capsize=5, alpha=0.5, zorder=1)

    # Reference point: the fitness of the observed data is (1, 1) after normalization.
    ax.axvline(x=1, color='red', linestyle='--', alpha=0.3)
    ax.axhline(y=1, color='red', linestyle='--', alpha=0.3)
    ax.scatter(1, 1, color='red', marker='*', s=200, label='Obs. Data', zorder=3)

    # Combine all legend entries; per-sample hue labels get a readable prefix.
    handles, labels = ax.get_legend_handles_labels()
    new_labels = [label if label in ('Dominated', 'Obs. Data') else f'Sample ID: {label}' for label in labels]
    # Empty gray proxy artist providing a generic 'Non-Dominated' legend entry.
    nondominated_proxy = ax.scatter([], [], marker='X', s=100, color='gray', edgecolors='gray',
                                    linewidth=0.2, label='Non-Dominated', zorder=2)
    ax.legend([nondominated_proxy] + handles, ['Non-Dominated'] + new_labels)

    ax.set_xlabel(qoi_name1)
    ax.set_ylabel(qoi_name2)

    if axis is None:
        plt.tight_layout()
        return fig, ax
    return None
def plot_parameter_vs_fitness_db(db_file: str, parameter_name: str, qoi_name: str, axis=None, samples_id=None):
    """
    Plot one parameter against the fitness of one QoI, reading data from a database file.

    Args:
        db_file: Path to the database file.
        parameter_name: Name of the parameter to plot.
        qoi_name: Name of the QoI to plot against the parameter.
        axis: Matplotlib axis to plot on (optional).
        samples_id: List of sample IDs to highlight (optional); forwarded to
            ``plot_parameter_vs_fitness``.

    Returns:
        The return value of ``plot_parameter_vs_fitness`` (a figure and axis
        when ``axis`` is None).
    """
    # Only the samples and output tables are needed here.
    _, _, _, _, df_samples, df_output = load_structure(db_file)
    # Pass by keyword: the 5th positional parameter of plot_parameter_vs_fitness
    # is samples_id, so passing axis positionally would land in the wrong slot.
    return plot_parameter_vs_fitness(df_samples, df_output, parameter_name, qoi_name,
                                     samples_id=samples_id, axis=axis)
def plot_qoi_param(df_ObsData: pd.DataFrame, df_output: pd.DataFrame, samples_id: list, x_var: str, y_var: str, axis=None, swarmplot=False, plot_residuals=False):
    """
    Plot the simulated QoI of selected samples against the observed QoI.

    Args:
        df_ObsData: Observed QoI DataFrame; must contain the columns ``x_var``
            and ``y_var``.
        df_output: Output DataFrame with the columns 'SampleID', 'ObjFunc',
            'Noise_Std' and 'Data'. Per sample, 'Data' maps a replicate id to
            data convertible to a two-column DataFrame ``(x_var, y_var)``.
        samples_id: List of Sample IDs to plot; each must be present in
            ``df_output``.
        x_var: Variable to plot on the x-axis.
        y_var: Variable to plot on the y-axis.
        axis: Matplotlib axis to plot on (optional).
        swarmplot: When the observed ``y_var`` has at most one distinct value,
            draw the replicates as a swarmplot instead of a scatterplot.
        plot_residuals: Plot simulated minus observed values instead of raw
            values.

    Returns:
        ``(fig, ax)`` if ``axis`` is None; otherwise None (the plot is drawn
        in the given axis).
    """
    selected_outputs = df_output[df_output['SampleID'].isin(samples_id)]
    print(f"Sample ID: {samples_id}")
    if not selected_outputs.empty:
        # Only the first selected sample's values are printed.
        print(f"Objective Function Values:\n{selected_outputs['ObjFunc'].values[0]}")
        print(f"Noise of Objective Function:\n{selected_outputs['Noise_Std'].values[0]}")

    if axis is None:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        ax = axis

    # Several distinct observed y values -> curve; otherwise a single constant.
    obs_is_series = df_ObsData[y_var].nunique() > 1
    if not obs_is_series:
        obs_value = df_ObsData[y_var].dropna().values[0]

    if not plot_residuals:
        if obs_is_series:
            sns.lineplot(df_ObsData, x=x_var, y=y_var, color='red', label='Observed QoI', linewidth=3, ax=ax)
        else:
            ax.axhline(y=obs_value, color='red', label='Observed QoI', linewidth=3)
    else:
        # Zero line: a perfect fit has zero residual.
        ax.axhline(y=0.0, color='k', linestyle='--', label=None)
        if obs_is_series:
            obs_by_x = df_ObsData.set_index(x_var)[y_var]

    # Collect one frame per (sample, replicate) and concatenate once at the end.
    frames = []
    for sample_id in samples_id:
        dic_data = df_output[df_output['SampleID'] == sample_id]['Data'].values[0]
        for rep_id, output in dic_data.items():
            df_data = pd.DataFrame(output, columns=[x_var, y_var])
            if plot_residuals:
                if obs_is_series:
                    # Align the observations to the simulated x values; x values
                    # missing from the observations yield NaN residuals.
                    df_data_indexed = df_data.set_index(x_var)[y_var]
                    obs_data_indexed = obs_by_x.reindex(df_data_indexed.index)
                    df_data[y_var] = (df_data_indexed - obs_data_indexed).values
                else:
                    df_data[y_var] = df_data[y_var] - obs_value
            df_data['SampleID'] = sample_id
            df_data['replicateID'] = rep_id
            frames.append(df_data)

    all_df_data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    # Replicates share one legend entry per sample (hue on the formatted SampleID).
    if not all_df_data.empty:
        # Added only when there is data: an empty frame has no 'SampleID' column.
        all_df_data['SampleID_formatted'] = all_df_data['SampleID'].apply(lambda x: f'SampleID: {x}')
        if obs_is_series:
            sns.lineplot(data=all_df_data, x=x_var, y=y_var, ax=ax, hue='SampleID_formatted',
                         units='replicateID', dashes=(4, 2), estimator=None)
        elif swarmplot:
            sns.swarmplot(data=all_df_data, y=y_var, ax=ax, hue='SampleID_formatted')
        else:
            sns.scatterplot(data=all_df_data, x=x_var, y=y_var, ax=ax, hue='SampleID_formatted', s=50)

    if plot_residuals:
        ax.set_ylabel(f"Residual of {y_var}")
    else:
        ax.set_xlabel(x_var)
        ax.set_ylabel(y_var)
    ax.legend()

    if axis is None:
        plt.tight_layout()
        return fig, ax
    return None
def plot_qoi_param_db(db_file: str, samples_id: list, x_var: str = None, y_var: str = None, axis=None, swarmplot=False, plot_residuals=False):
    """
    Plot the simulated QoI of selected samples against the observed QoI, reading data from a database file.

    Args:
        db_file: Path to the database file.
        samples_id: List of Sample IDs to plot.
        x_var: Variable to plot on the x-axis. Forwarded unchanged to
            ``plot_qoi_param``, which uses it as a column name.
            NOTE(review): None is not handled downstream; pass a real column name.
        y_var: Variable to plot on the y-axis. Forwarded unchanged to
            ``plot_qoi_param``, which uses it as a column name.
            NOTE(review): None is not handled downstream; pass a real column name.
        axis: Matplotlib axis to plot on (optional).
        swarmplot: Forwarded to ``plot_qoi_param`` (default False).
        plot_residuals: Forwarded to ``plot_qoi_param`` (default False).

    Returns:
        The return value of ``plot_qoi_param`` (a figure and axis when
        ``axis`` is None).
    """
    # Only metadata, QoI definitions and outputs are needed here.
    df_metadata, _, df_qois, _, _, df_output = load_structure(db_file)
    # The observed data are read from the path recorded in the metadata table.
    df_ObsData = get_observed_qoi(df_metadata['ObsData_Path'].values[0], df_qois)
    return plot_qoi_param(df_ObsData, df_output, samples_id, x_var, y_var, axis=axis,
                          swarmplot=swarmplot, plot_residuals=plot_residuals)