Source code for module: habit.core.habitat_analysis.strategies.base_strategy

"""
Base strategy interface for habitat analysis.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Dict, Any
from pathlib import Path

import pandas as pd

from habit.core.habitat_analysis.config_schemas import ResultColumns

if TYPE_CHECKING:
    from habit.core.habitat_analysis.habitat_analysis import HabitatAnalysis
    from habit.core.habitat_analysis.pipelines.base_pipeline import HabitatPipeline


def _canonical_csv_column_order(df: pd.DataFrame) -> List[str]:
    """
    Return the column order used when writing habitats.csv.

    Standard metadata columns (subject, supervoxel, habitats, count) come
    first, in a fixed order, followed by every remaining column in the order
    it already has in ``df``. This keeps habitats.csv predictable across runs.

    Args:
        df: Results DataFrame whose columns should be ordered

    Returns:
        List of column names: present metadata columns first, then the rest
    """
    # Fixed preferred ordering for the standard metadata columns; any that
    # are absent from df are simply skipped.
    preferred = (
        ResultColumns.SUBJECT,
        ResultColumns.SUPERVOXEL,
        ResultColumns.HABITATS,
        ResultColumns.COUNT,
    )
    leading = [name for name in preferred if name in df.columns]
    leading_set = set(leading)
    # Everything else (feature columns) keeps its existing relative order.
    trailing = [name for name in df.columns if name not in leading_set]
    return leading + trailing


class BaseClusteringStrategy(ABC):
    """
    Abstract base class for habitat analysis strategies.

    Each strategy should implement run() and return a results DataFrame.
    """

    def __init__(self, analysis: "HabitatAnalysis"):
        """
        Initialize the strategy with a HabitatAnalysis instance.

        Args:
            analysis: HabitatAnalysis instance with shared utilities and configuration
        """
        self.analysis = analysis
        # Convenience aliases so subclasses can use self.config / self.logger
        # directly instead of reaching through self.analysis.
        self.config = analysis.config
        self.logger = analysis.logger

    def _get_attributes_to_update(self) -> Dict[str, Any]:
        """
        Get mapping of attribute names to their values that should be updated
        in pipeline steps.

        This method can be overridden by subclasses to customize which
        attributes are updated. By default, it includes:
        - config: Updated to self.config
        - All manager attributes from self.analysis (automatically discovered)

        Returns:
            Dictionary mapping attribute names to their values
        """
        attributes_to_update: Dict[str, Any] = {
            'config': self.config,
        }
        # Automatically discover and update all manager attributes from analysis.
        # This handles current managers (feature_manager, clustering_manager,
        # result_manager) and any future managers that follow the naming convention.
        for attr_name in dir(self.analysis):
            # Match attributes ending with '_manager' (e.g., feature_manager,
            # data_manager); exclude private attributes (starting with '_').
            if attr_name.endswith('_manager') and not attr_name.startswith('_'):
                manager = getattr(self.analysis, attr_name, None)
                if manager is not None:
                    attributes_to_update[attr_name] = manager
        return attributes_to_update

    def _select_pipeline_config(self, pipeline: "HabitatPipeline") -> Any:
        """
        Choose which config should be applied to a loaded pipeline.

        In predict mode, users may provide a minimal YAML without
        FeatureConstruction. In that case, we keep the pipeline's trained config
        (which includes the FeatureConstruction details) and only override
        runtime-safe fields like output paths and logging/plot flags.

        Args:
            pipeline: Loaded HabitatPipeline instance

        Returns:
            Configuration object to apply to the pipeline and its steps
        """
        if (
            self.config.run_mode == 'predict'
            and self.config.FeatureConstruction is None
            and pipeline.config is not None
        ):
            # NOTE: this mutates the loaded pipeline's config object in place.
            pipeline_config = pipeline.config
            # Override only runtime fields that should follow the predict config.
            pipeline_config.out_dir = self.config.out_dir
            pipeline_config.plot_curves = self.config.plot_curves
            pipeline_config.save_results_csv = self.config.save_results_csv
            pipeline_config.save_images = self.config.save_images
            pipeline_config.processes = self.config.processes
            pipeline_config.random_state = self.config.random_state
            pipeline_config.verbose = self.config.verbose
            pipeline_config.debug = self.config.debug
            return pipeline_config
        return self.config

    def _sync_feature_manager(
        self,
        pipeline_feature_manager: Any,
        runtime_feature_manager: Any
    ) -> None:
        """
        Synchronize a loaded pipeline's FeatureManager with current runtime paths.

        We keep the trained FeatureManager instance (it contains the fitted
        feature extraction configuration) and only update data paths and logging
        targets so it can operate on the current dataset and write logs
        consistently.

        Args:
            pipeline_feature_manager: FeatureManager from the loaded pipeline
            runtime_feature_manager: FeatureManager created for the current run
        """
        # Only propagate data paths when the runtime manager actually has both.
        if (
            getattr(runtime_feature_manager, "images_paths", None) is not None
            and getattr(runtime_feature_manager, "mask_paths", None) is not None
        ):
            pipeline_feature_manager.set_data_paths(
                runtime_feature_manager.images_paths,
                runtime_feature_manager.mask_paths
            )
        # NOTE(review): only _log_file_path is guarded; _log_level is read
        # unconditionally — presumably both are always set together. Confirm.
        if hasattr(runtime_feature_manager, "_log_file_path"):
            pipeline_feature_manager.set_logging_info(
                runtime_feature_manager._log_file_path,
                runtime_feature_manager._log_level
            )

    def _update_pipeline_references(self, pipeline: "HabitatPipeline") -> None:
        """
        Update references in loaded pipeline to use current analysis instances.

        This method automatically updates config and manager references in the
        pipeline and all its steps to use the current analysis instances. This
        ensures that config changes (like out_dir, plot_curves) are reflected
        in all steps.

        The method is highly extensible:
        1. Automatically discovers all manager attributes (ending with '_manager')
        2. Can be extended by overriding _get_attributes_to_update() in subclasses
        3. Supports any future managers without code changes (as long as they
           follow naming convention)

        Examples of automatically discovered attributes:
        - feature_manager: Updated to self.analysis.feature_manager
        - clustering_manager: Updated to self.analysis.clustering_manager
        - result_manager: Updated to self.analysis.result_manager
        - data_manager: Updated to self.analysis.data_manager (if added in future)
        - cache_manager: Updated to self.analysis.cache_manager (if added in future)

        Args:
            pipeline: Loaded HabitatPipeline instance to update
        """
        # Update pipeline-level config with predict-safe selection.
        config_to_apply = self._select_pipeline_config(pipeline)
        pipeline.config = config_to_apply
        # Get attributes to update (can be customized by subclasses); force the
        # config entry to the one just selected, which may be the pipeline's own.
        attributes_to_update = self._get_attributes_to_update()
        attributes_to_update['config'] = config_to_apply
        # Update all steps in the pipeline (steps are (name, step) pairs).
        for _, step in pipeline.steps:
            for attr_name, attr_value in attributes_to_update.items():
                if hasattr(step, attr_name):
                    if attr_name == 'feature_manager':
                        # Keep trained FeatureManager; only sync runtime
                        # paths/logging instead of replacing the instance.
                        pipeline_feature_manager = getattr(step, attr_name, None)
                        if pipeline_feature_manager is not None:
                            self._sync_feature_manager(pipeline_feature_manager, attr_value)
                        continue
                    setattr(step, attr_name, attr_value)

    def run(
        self,
        subjects: Optional[List[str]] = None,
        save_results_csv: Optional[bool] = None,
        load_from: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Template method for executing the strategy.

        This method defines the algorithm skeleton. Subclasses can override
        specific steps if needed, but most will only need to implement
        strategy-specific logic in hooks.

        Args:
            subjects: List of subjects to process (None means all subjects)
            save_results_csv: Whether to save results to CSV
                (defaults to config.save_results_csv)
            load_from: Optional path to a saved pipeline. If provided, the
                pipeline is loaded and only transform() is executed.

        Returns:
            Results DataFrame
        """
        # Use config value if parameter not provided, allowing runtime override.
        if save_results_csv is None:
            save_results_csv = self.config.save_results_csv
        subjects = self._prepare_subjects(subjects)
        X = self._build_input(subjects)
        pipeline_path = self._resolve_pipeline_path(load_from)
        # Ensure output directory exists before any step writes to it.
        Path(self.config.out_dir).mkdir(parents=True, exist_ok=True)
        if load_from:
            self._run_predict_mode(pipeline_path, X)
        else:
            self._run_train_mode(pipeline_path, X)
        # Post-process results (hook for strategy-specific logic).
        self._post_process_results()
        # Update ResultManager with new results.
        self.analysis.result_manager.results_df = self.analysis.results_df
        # Save results.
        if save_results_csv:
            self._save_results()
        return self.analysis.results_df

    def _run_predict_mode(self, pipeline_path: Path, X: Dict[str, Dict]) -> None:
        """
        Run pipeline in predict mode (load from file).

        Args:
            pipeline_path: Path to saved pipeline
            X: Input data dict

        Raises:
            FileNotFoundError: If no saved pipeline exists at pipeline_path
        """
        # Local import to avoid a circular dependency at module load time.
        from ..pipelines.base_pipeline import HabitatPipeline
        if self.config.verbose:
            strategy_name = self._get_strategy_name()
            self.logger.info(f"Loading and running {strategy_name} pipeline...")
        if not pipeline_path.exists():
            raise FileNotFoundError(
                f"Saved pipeline not found at {pipeline_path}. "
                "Provide a valid load_from path or run without load_from to train."
            )
        # Load pipeline.
        self.pipeline = HabitatPipeline.load(str(pipeline_path))
        # Update references in loaded pipeline to use current analysis instances.
        self._update_pipeline_references(self.pipeline)
        # Disable image outputs and plots for prediction runs to avoid
        # unnecessary I/O.
        # NOTE(review): only plot_curves is actually disabled here; save_images
        # is left as configured — confirm whether that is intentional.
        self.pipeline.config.plot_curves = False
        # Transform only (no fitting in predict mode).
        self.analysis.results_df = self.pipeline.transform(X)

    def _run_train_mode(self, pipeline_path: Path, X: Dict[str, Dict]) -> None:
        """
        Run pipeline in train mode (build and fit).

        Args:
            pipeline_path: Path to save trained pipeline
            X: Input data dict
        """
        # Local import to avoid a circular dependency at module load time.
        from ..pipelines.pipeline_builder import build_habitat_pipeline
        if self.config.verbose:
            strategy_name = self._get_strategy_name()
            self.logger.info(f"Building and fitting {strategy_name} pipeline...")
        # Build new pipeline from the current configuration and managers.
        self.pipeline = build_habitat_pipeline(
            config=self.config,
            feature_manager=self.analysis.feature_manager,
            clustering_manager=self.analysis.clustering_manager,
            result_manager=self.analysis.result_manager
        )
        # Fit and transform in one pass.
        self.analysis.results_df = self.pipeline.fit_transform(X)
        # Save pipeline so future runs can use predict mode via load_from.
        if self.config.verbose:
            self.logger.info(f"Saving fitted pipeline to {pipeline_path}")
        self.pipeline.save(str(pipeline_path))

    def _get_strategy_name(self) -> str:
        """
        Get human-readable strategy name for logging.

        Returns:
            Strategy name (e.g., "One-Step", "Two-Step", "Direct Pooling")
        """
        # Default implementation: extract from class name.
        class_name = self.__class__.__name__
        if 'OneStep' in class_name:
            return "One-Step"
        elif 'TwoStep' in class_name:
            return "Two-Step"
        elif 'DirectPooling' in class_name or 'Pooling' in class_name:
            return "Direct Pooling"
        # Fallback: strip the conventional 'Strategy' suffix.
        return class_name.replace('Strategy', '')

    def _prepare_subjects(self, subjects: Optional[List[str]]) -> List[str]:
        """
        Normalize subject list and validate it is not empty.

        Args:
            subjects: Optional list of subject IDs

        Returns:
            List of subject IDs

        Raises:
            ValueError: If the resulting subject list is empty
        """
        if subjects is None:
            # Default to every subject known to the analysis instance.
            subjects = list(self.analysis.images_paths.keys())
        if not subjects:
            strategy_name = self._get_strategy_name()
            raise ValueError(f"No subjects provided for {strategy_name} strategy.")
        # Copy defensively so callers' lists are never mutated downstream.
        return list(subjects)

    def _build_input(self, subjects: List[str]) -> Dict[str, Dict]:
        """
        Build input dict for the pipeline.

        Args:
            subjects: List of subject IDs

        Returns:
            Dict of subject_id -> empty dict (pipeline will populate data)
        """
        return {subject: {} for subject in subjects}

    def _resolve_pipeline_path(self, load_from: Optional[str]) -> Path:
        """
        Resolve pipeline path for saving or loading.

        Args:
            load_from: Optional path to a saved pipeline

        Returns:
            Path to pipeline file (load_from if given, otherwise the default
            location under the configured output directory)
        """
        if load_from:
            return Path(load_from)
        return Path(self.config.out_dir) / "habitat_pipeline.pkl"

    def _post_process_results(self) -> None:
        """
        Post-process results after pipeline execution.

        Hook for strategy-specific result processing. Subclasses can override
        this method to add custom logic (e.g., column renaming, validation).
        Default implementation does nothing.
        """
        pass

    def _save_results(self) -> None:
        """
        Save results for the strategy.

        Hook for strategy-specific result saving. Subclasses can override this
        method to customize saving behavior. Default implementation saves CSV
        and habitat images.
        """
        if self.config.verbose:
            self.logger.info("Saving results...")
        # Save results CSV with consistent column order (metadata first, then
        # features).
        csv_path = Path(self.config.out_dir) / "habitats.csv"
        df = self.analysis.results_df
        canonical_order = _canonical_csv_column_order(df)
        df[canonical_order].to_csv(str(csv_path), index=False)
        if self.config.verbose:
            self.logger.info(f"Results saved to {csv_path}")
        # Save habitat images for each subject.
        if self.config.save_images:
            # Sync mask cache from pipeline (populated in main process) to
            # result manager.
            if (
                hasattr(self, "pipeline")
                and hasattr(self.pipeline, "mask_info_cache")
                and self.pipeline.mask_info_cache
            ):
                self.analysis.result_manager.mask_info_cache = self.pipeline.mask_info_cache
            self.analysis.result_manager.save_all_habitat_images(failed_subjects=[])