machine_learning module

Machine Learning module for the HABIT package.

This module aggregates key components for ML workflows, including model training, evaluation, and feature selection.

Workflows

Workflow classes encapsulate the complete machine learning process, including training, validation, and prediction.

Standard Machine Learning Workflow (Train/Test Split). Inherits from BaseWorkflow for consistent infrastructure.

class habit.core.machine_learning.workflows.holdout_workflow.MachineLearningWorkflow(config: MLConfig)[source]

Bases: BaseWorkflow

__init__(config: MLConfig)[source]

Initialize BaseWorkflow.

Parameters:
  • config -- MLConfig Pydantic object or dict (dict will be validated and converted).

  • module_name -- Name of the workflow module.

run_pipeline()[source]

Main entry point to be implemented by subclasses.

Advanced K-Fold Cross-Validation Workflow. Inherits from BaseWorkflow for consistent infrastructure.

class habit.core.machine_learning.workflows.kfold_workflow.MachineLearningKFoldWorkflow(config: MLConfig)[source]

Bases: BaseWorkflow

__init__(config: MLConfig)[source]

Initialize BaseWorkflow.

Parameters:
  • config -- MLConfig Pydantic object or dict (dict will be validated and converted).

  • module_name -- Name of the workflow module.

run_pipeline()[source]

Main entry point to be implemented by subclasses.

Prediction Workflow Module. Handles loading trained models and making predictions on new data.

class habit.core.machine_learning.workflows.prediction_workflow.PredictionWorkflow(config: PredictionConfig, logger: Logger | None = None)[source]

Bases: object

Workflow for running model predictions.

config

Configuration object.

Type:

PredictionConfig

logger

Logger instance.

Type:

logging.Logger

__init__(config: PredictionConfig, logger: Logger | None = None)[source]

Initialize the prediction workflow.

Parameters:
  • config -- Prediction configuration.

  • logger -- Logger instance.

run_pipeline() → None[source]

Run the full prediction pipeline.

class habit.core.machine_learning.base_workflow.BaseWorkflow(config: MLConfig | Dict[str, Any], module_name: str)[source]

Bases: ABC

Abstract Base Class for all Machine Learning Workflows. Handles infrastructure such as logging, data loading, and basic results persistence.

__init__(config: MLConfig | Dict[str, Any], module_name: str)[source]

Initialize BaseWorkflow.

Parameters:
  • config -- MLConfig Pydantic object or dict (dict will be validated and converted).

  • module_name -- Name of the workflow module.

abstract run_pipeline()[source]

Main entry point to be implemented by subclasses.
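BaseWorkflow follows the template-method pattern: the base class owns shared infrastructure, while each subclass supplies run_pipeline(). A minimal self-contained sketch of that pattern (the Sketch* class names are illustrative, not the HABIT implementation):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchBaseWorkflow(ABC):
    """Illustrative stand-in for BaseWorkflow: owns shared infrastructure."""

    def __init__(self, config: Dict[str, Any], module_name: str) -> None:
        self.config = config
        self.module_name = module_name  # used e.g. for logger naming

    @abstractmethod
    def run_pipeline(self) -> Any:
        """Each concrete workflow implements its own pipeline."""


class SketchHoldoutWorkflow(SketchBaseWorkflow):
    def run_pipeline(self) -> str:
        # A real workflow would train, validate, and persist results here.
        return f"{self.module_name}: train/test split run"


wf = SketchHoldoutWorkflow({"model": "logistic"}, "holdout")
print(wf.run_pipeline())
```

Because run_pipeline() is declared abstract, instantiating the base class directly raises TypeError, which enforces the contract at construction time.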

Model Factory

Model Factory. Factory class for creating model instances.

class habit.core.machine_learning.models.factory.ModelFactory[source]

Bases: object

Factory class for creating model instances.

classmethod register(name: str)[source]

Register a model class.

Parameters:

name -- Model name

Returns:

Decorator function

classmethod create_model(model_name: str, config: Dict[str, Any] | None = None) → BaseModel[source]

Create a model instance.

Parameters:
  • model_name -- Name of model to create

  • config -- Configuration dictionary

Returns:

Model instance

Return type:

BaseModel

Raises:

ValueError -- If model name is not registered

classmethod get_available_models() → List[str][source]

Get the list of available model names.

Returns:

List of registered model names

Return type:

List[str]
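The register/create_model pair describes a decorator-based registry. A minimal self-contained sketch of that pattern (SketchModelFactory and DummyModel are invented names for illustration, not the HABIT classes):

```python
from typing import Any, Callable, Dict, List, Optional


class SketchModelFactory:
    """Illustrative decorator-based registry, mirroring the API documented above."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        # Returns a decorator that stores the class under `name`.
        def decorator(model_cls: type) -> type:
            cls._registry[name] = model_cls
            return model_cls
        return decorator

    @classmethod
    def create_model(cls, model_name: str, config: Optional[Dict[str, Any]] = None):
        if model_name not in cls._registry:
            raise ValueError(f"Model '{model_name}' is not registered")
        return cls._registry[model_name](config or {})

    @classmethod
    def get_available_models(cls) -> List[str]:
        return list(cls._registry.keys())


@SketchModelFactory.register("dummy")
class DummyModel:
    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config


model = SketchModelFactory.create_model("dummy", {"C": 1.0})
print(SketchModelFactory.get_available_models())
```

The decorator returns the class unchanged, so registration is a side effect of the class definition; this is why new models become available simply by importing the module that defines them.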

Ensemble Model Wrapper for K-Fold Cross Validation. Allows treating a collection of K-fold models as a single scikit-learn estimator.

class habit.core.machine_learning.models.ensemble.HabitEnsembleModel(estimators: List[Any], voting: str = 'soft')[source]

Bases: BaseEstimator, ClassifierMixin

An ensemble wrapper that aggregates predictions from multiple fitted pipelines. Used primarily to wrap K-fold cross-validation results into a single predict-ready object.

estimators

List of fitted scikit-learn pipelines/models.

Type:

List[Any]

voting

'soft' (average probabilities) or 'hard' (majority vote). Default 'soft'.

Type:

str

__init__(estimators: List[Any], voting: str = 'soft')[source]
fit(X, y=None)[source]

No-op fit method. The estimators passed to __init__ are assumed to be ALREADY fitted, which allows wrapping the K-fold results directly.

predict_proba(X)[source]

Predict class probabilities for X.

Returns:

array-like of shape (n_samples, n_classes)

predict(X)[source]

Predict class labels for X.

property classes_

Delegate the classes_ attribute to the first estimator.
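Both voting modes reduce to simple array operations over the per-fold predictions. A minimal numpy sketch of 'soft' (probability averaging) and 'hard' (majority vote) aggregation, assuming the fold outputs are already collected (illustrative helpers, not the HabitEnsembleModel code):

```python
import numpy as np


def soft_vote(prob_list):
    """Average class probabilities across fitted fold models ('soft' voting)."""
    return np.mean(np.stack(prob_list, axis=0), axis=0)


def hard_vote(label_list):
    """Majority vote over per-fold predicted labels ('hard' voting)."""
    labels = np.stack(label_list, axis=0)  # shape (n_models, n_samples)
    # For each sample (column), pick the most frequent label across models.
    return np.apply_along_axis(lambda col: np.bincount(col).argmax(), 0, labels)


# Three folds predicting probabilities for 2 samples, 2 classes
p = [np.array([[0.9, 0.1], [0.2, 0.8]]),
     np.array([[0.7, 0.3], [0.4, 0.6]]),
     np.array([[0.8, 0.2], [0.3, 0.7]])]

print(soft_vote(p))                  # averaged probabilities: [[0.8, 0.2], [0.3, 0.7]]
print(soft_vote(p).argmax(axis=1))   # final 'soft' labels: [0, 1]
```

Soft voting preserves calibration information from each fold, which is why it is the default here; hard voting only needs labels and is more robust when individual folds produce poorly scaled probabilities.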

set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') → HabitEnsembleModel

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:

sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) -- Metadata routing for sample_weight parameter in score.

Returns:

self -- The updated object.

Return type:

object

Evaluation

Model Evaluation Module. Provides functions for model training, evaluation, and result analysis.

class habit.core.machine_learning.evaluation.model_evaluation.ModelEvaluator(output_dir: str)[source]

Bases: object

__init__(output_dir: str)[source]

Initialize the model evaluator.

Parameters:

output_dir (str) -- Directory where evaluation results and plots will be saved

evaluate(model: Any, X: DataFrame, y: Series, dataset_name: str = 'test') → Dict[str, Any][source]

Evaluate a single model on a single dataset.

Parameters:
  • model (Any) -- Trained model with predict and predict_proba methods

  • X (pd.DataFrame) -- Feature data

  • y (pd.Series) -- Label data

  • dataset_name (str) -- Name of the dataset

Returns:

Dictionary containing evaluation results

Return type:

Dict[str, Any]

plot_curves(model_data: Dict[str, Dict[str, Dict[str, List]]], curve_type: Literal['roc', 'dca', 'calibration', 'pr', 'all'] = 'all', title: str = 'evaluation', output_dir: str | None = None, prefix: str = '', n_bins: int = 10) → Dict[str, str][source]

Plot various evaluation curves using methods from plotting.py.

Parameters:
  • model_data (Dict) -- Dictionary containing model evaluation results. Format: {dataset_name: {model_name: {'y_true': [...], 'y_pred_proba': [...]}}}

  • curve_type (Literal) -- Type of curve to plot; can be 'roc', 'dca', 'calibration', 'pr', or 'all'

  • title (str) -- Keyword for chart title

  • output_dir (Optional[str]) -- Output directory, defaults to self.output_dir

  • prefix (str) -- Prefix for output filenames

  • n_bins (int) -- Number of bins for calibration curve

Returns:

Dictionary containing paths to generated chart files

Return type:

Dict[str, str]

compare_models(test_data: Dict[str, Tuple[List, List]]) → None[source]

Compare the performance of multiple models (using the DeLong test).

Parameters:

test_data (Dict[str, Tuple[List, List]]) -- Test data dictionary; keys are model names, values are (y_true, y_pred_proba) tuples

class habit.core.machine_learning.evaluation.model_evaluation.MultifileEvaluator(output_dir: str)[source]

Bases: object

__init__(output_dir: str) → None[source]

Initialize the multi-file evaluator.

Parameters:

output_dir (str) -- Output directory for charts

read_prediction_files(files_config: List[Dict]) → MultifileEvaluator[source]

Read prediction results from multiple files.

Parameters:

files_config (List[Dict]) -- List of file configurations; each element contains: - path: file path - model_name: model name - subject_id_col: subject ID column name - label_col: true label column name - prob_col: predicted probability column name - pred_col: predicted label column name (optional)

Returns:

The instance itself, for method chaining

Return type:

MultifileEvaluator

save_merged_data(filename: str = 'merged_predictions.csv') → None[source]

Save the merged data to a CSV file.

Parameters:

filename (str) -- Output filename

plot_roc(save_name: str = 'ROC.pdf', title: str = 'evaluation') → None[source]

Plot ROC curves for all models.

Parameters:
  • save_name (str) -- Filename to save the plot

  • title (str) -- Chart title

plot_dca(save_name: str = 'DCA.pdf', title: str = 'evaluation') → None[source]

Plot Decision Curve Analysis (DCA) for all models.

Parameters:
  • save_name (str) -- Filename to save the plot

  • title (str) -- Chart title

plot_calibration(save_name: str = 'Calibration.pdf', n_bins: int = 5, title: str = 'evaluation') → None[source]

Plot calibration curves for all models.

Parameters:
  • save_name (str) -- Filename to save the plot

  • n_bins (int) -- Number of bins for the calibration curve

  • title (str) -- Chart title

plot_pr_curve(save_name: str = 'PR_curve.pdf', title: str = 'evaluation') → None[source]

Plot precision-recall curves for all models.

Parameters:
  • save_name (str) -- Filename to save the plot

  • title (str) -- Chart title

run_delong_test(output_json: str | None = 'delong_test_results.json') → List[Dict][source]

Run the DeLong test for all pairs of models.

Parameters:

output_json (Optional[str]) -- Output JSON filename; set to None if saving is not needed

Returns:

List of DeLong test results

Return type:

List[Dict]

Metrics Calculation Module. Provides a registry-based system for calculating model evaluation metrics.

Optimizations:
  • Confusion matrix caching for performance

  • Extended target metrics support (PPV, NPV, F1, etc.)

  • Fallback mechanism when no threshold satisfies all targets

  • Pareto optimal threshold selection

  • Category-based metric filtering

class habit.core.machine_learning.evaluation.metrics.MetricsCache(y_true: ndarray, y_pred: ndarray, y_prob: ndarray)[source]

Bases: object

Cache for the confusion matrix and derived metrics to avoid repeated calculations. Provides roughly an 8x performance improvement when calculating multiple metrics.

__init__(y_true: ndarray, y_pred: ndarray, y_prob: ndarray)[source]
property confusion_matrix: ndarray

Lazy evaluation with caching of the confusion matrix.

get_metric(metric_name: str, calculator: Callable) → float[source]

Get the cached metric, or calculate and cache it.
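All of the 'basic' metrics below derive from a single confusion matrix, which is why caching it pays off. A self-contained sketch of computing the matrix once and deriving accuracy, sensitivity, specificity, PPV, and NPV from it (illustrative helpers, not the MetricsCache implementation):

```python
import numpy as np


def binary_confusion(y_true, y_pred):
    """Return (TN, FP, FN, TP) counts for binary 0/1 labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    return tn, fp, fn, tp


def metrics_from_cm(tn, fp, fn, tp):
    """Derive the 'basic' metrics from one confusion matrix, computed once."""
    return {
        "accuracy":    (tp + tn) / (tp + tn + fp + fn),
        "sensitivity": tp / (tp + fn) if tp + fn else 0.0,  # recall / TPR
        "specificity": tn / (tn + fp) if tn + fp else 0.0,  # TNR
        "ppv":         tp / (tp + fp) if tp + fp else 0.0,  # precision
        "npv":         tn / (tn + fn) if tn + fn else 0.0,
    }


cm = binary_confusion([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
print(metrics_from_cm(*cm))
```

Computing the four counts once and reusing them across every derived metric is the same idea MetricsCache applies via its lazily evaluated confusion_matrix property.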

habit.core.machine_learning.evaluation.metrics.register_metric(name: str, display_name: str, category: str = 'basic')[source]

Decorator to register a metric function.

Parameters:
  • name -- Internal unique key for the metric

  • display_name -- Pretty name for reports and plots

  • category -- 'basic', 'statistical', or 'clinical'

habit.core.machine_learning.evaluation.metrics.calc_accuracy(y_true, y_pred, y_prob, cm=None)[source]
habit.core.machine_learning.evaluation.metrics.calc_sensitivity(y_true, y_pred, y_prob, cm=None)[source]

Calculate sensitivity (recall, true positive rate).

habit.core.machine_learning.evaluation.metrics.calc_specificity(y_true, y_pred, y_prob, cm=None)[source]

Calculate specificity (true negative rate).

habit.core.machine_learning.evaluation.metrics.calc_ppv(y_true, y_pred, y_prob, cm=None)[source]

Calculate Positive Predictive Value (precision).

habit.core.machine_learning.evaluation.metrics.calc_npv(y_true, y_pred, y_prob, cm=None)[source]

Calculate Negative Predictive Value.

habit.core.machine_learning.evaluation.metrics.calc_f1(y_true, y_pred, y_prob, cm=None)[source]

Calculate F1-score (harmonic mean of precision and recall).

habit.core.machine_learning.evaluation.metrics.calc_auc(y_true, y_pred, y_prob)[source]
habit.core.machine_learning.evaluation.metrics.calc_hl_p(y_true, y_pred, y_prob)[source]
habit.core.machine_learning.evaluation.metrics.calc_spiegelhalter_p(y_true, y_pred, y_prob)[source]
habit.core.machine_learning.evaluation.metrics.calculate_metrics(y_true: ndarray | PredictionContainer, y_pred: ndarray | None = None, y_pred_proba: ndarray | None = None, categories: List[str] | None = None, use_cache: bool = True) → Dict[str, float][source]

Calculate all registered metrics using PredictionContainer.

Parameters:
  • y_true -- True labels or PredictionContainer

  • y_pred -- Predicted labels (optional if using PredictionContainer)

  • y_pred_proba -- Predicted probabilities (optional if using PredictionContainer)

  • categories -- Filter metrics by category, e.g., ['basic', 'statistical']. If None, calculate all metrics

  • use_cache -- If True, use confusion matrix caching for better performance

Returns:

Dictionary of metric_name -> value

habit.core.machine_learning.evaluation.metrics.apply_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) → Dict[str, float][source]

Apply a given threshold to predicted probabilities and calculate metrics.

habit.core.machine_learning.evaluation.metrics.calculate_metrics_youden(y_true: ndarray, y_pred_proba: ndarray) → Dict[str, float | Dict[str, float]][source]

Calculate metrics based on the optimal Youden index.

habit.core.machine_learning.evaluation.metrics.apply_youden_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) → Dict[str, float | Dict[str, float]][source]

Apply a pre-determined Youden threshold.
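The Youden-based helpers select the cutoff maximizing J = sensitivity + specificity - 1. A minimal sketch of that selection, scanning the observed probabilities as candidate thresholds (illustrative, not the library implementation):

```python
import numpy as np


def youden_threshold(y_true, y_prob):
    """Pick the probability cutoff maximizing Youden's J = sens + spec - 1."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    best_t, best_j = 0.5, -1.0
    for t in np.unique(y_prob):  # each observed probability is a candidate cutoff
        y_pred = (y_prob >= t).astype(int)
        tp = np.sum((y_true == 1) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        sens = tp / (tp + fn) if tp + fn else 0.0
        spec = tn / (tn + fp) if tn + fp else 0.0
        j = sens + spec - 1.0
        if j > best_j:
            best_j, best_t = j, t
    return best_t, best_j


t, j = youden_threshold([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(t, j)
```

apply_youden_threshold exists so that a cutoff chosen this way on the training set can be frozen and re-applied to validation or test data, avoiding optimistic bias from re-optimizing the threshold per dataset.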

habit.core.machine_learning.evaluation.metrics.calculate_metrics_at_target(y_true: ndarray, y_pred_proba: ndarray, target_metrics: Dict[str, float], threshold_selection: str = 'pareto+youden', fallback_to_closest: bool = True, distance_metric: str = 'euclidean') → Dict[str, float | Dict[str, float]][source]

Calculate metrics at thresholds that achieve target metric values.

Enhanced version with:
  • Support for any metric (sensitivity, specificity, ppv, npv, f1_score, accuracy)

  • Fallback mechanism when no threshold satisfies all targets

  • Multiple threshold selection strategies

Parameters:
  • y_true -- True labels

  • y_pred_proba -- Predicted probabilities

  • target_metrics -- Target values, e.g., {'sensitivity': 0.91, 'specificity': 0.91}

  • threshold_selection -- Strategy for selecting the best threshold from multiple candidates: 'first' (use the first satisfying threshold), 'youden' (maximum Youden index among satisfying thresholds), or 'pareto+youden' (Pareto optimal with the highest Youden index; recommended)

  • fallback_to_closest -- If True, find the closest threshold when no perfect match exists

  • distance_metric -- Distance metric for the fallback ('euclidean', 'manhattan', 'max')

Returns:

Dictionary containing:

  • 'thresholds': Individual thresholds for each target metric

  • 'metrics_at_thresholds': Full metrics at each individual threshold

  • 'combined_results': Thresholds satisfying all targets

  • 'best_threshold': Selected best threshold (if multiple candidates exist)

  • 'closest_threshold': Fallback threshold (if no perfect match)

Return type:

Dict[str, float | Dict[str, float]]

habit.core.machine_learning.evaluation.metrics.apply_target_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) → Dict[str, float | Dict[str, float]][source]

Apply a pre-determined target threshold.

habit.core.machine_learning.evaluation.metrics.delong_roc_ci(y_true: ndarray, y_pred_proba: ndarray, alpha: float = 0.95) → Tuple[float, ndarray][source]

Calculate DeLong confidence intervals for the ROC curve.

habit.core.machine_learning.evaluation.metrics.calculate_net_benefit(y_true, y_pred_proba, threshold)[source]

Calculate the net benefit at a given threshold; used for DCA plotting.
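Net benefit at a probability threshold pt is conventionally defined as TP/N - FP/N * pt/(1 - pt). A minimal sketch under that standard definition (illustrative; the exact formula used by calculate_net_benefit is not shown in this documentation):

```python
import numpy as np


def net_benefit(y_true, y_prob, threshold):
    """Standard decision-curve net benefit: TP/N - FP/N * pt / (1 - pt)."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n = len(y_true)
    y_pred = (y_prob >= threshold).astype(int)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    # The pt/(1-pt) factor weights false positives by the odds implied
    # by the threshold: a clinician willing to treat at pt accepts that
    # trade-off between unnecessary and beneficial treatments.
    return tp / n - fp / n * threshold / (1 - threshold)


nb = net_benefit([1, 1, 0, 0], [0.9, 0.6, 0.7, 0.2], threshold=0.5)
print(nb)
```

A DCA plot evaluates this quantity over a sweep of thresholds and compares each model against the treat-all and treat-none reference strategies.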

MultifileEvaluator usage example: demonstrates how to use the multi-file evaluation tool to assess the performance of multiple models.

class habit.core.machine_learning.workflows.comparison_workflow.ModelComparison(config: Dict[str, Any] | ModelComparisonConfig, evaluator: MultifileEvaluator, reporter: ReportExporter, threshold_manager: ThresholdManager, plot_manager: PlotManager, metrics_store: MetricsStore, logger: Any)[source]

Bases: object

Tool for comparing and evaluating multiple machine learning models.

Note: Dependencies should be provided via ServiceConfigurator or explicitly.

__init__(config: Dict[str, Any] | ModelComparisonConfig, evaluator: MultifileEvaluator, reporter: ReportExporter, threshold_manager: ThresholdManager, plot_manager: PlotManager, metrics_store: MetricsStore, logger: Any) → None[source]

Initialize the model comparison tool.

Parameters:
  • config -- Parsed config dict or validated config object.

  • evaluator -- MultifileEvaluator instance (required).

  • reporter -- ReportExporter instance (required).

  • threshold_manager -- ThresholdManager instance (required).

  • plot_manager -- PlotManager instance (required).

  • metrics_store -- MetricsStore instance (required).

  • logger -- Logger instance (required).

setup() → None[source]

Set up the tool by reading prediction files and preparing data.

save_merged_data() → None[source]

Save merged data to a file.

run_evaluation() → None[source]

Run the entire evaluation process.

save_all_metrics() → None[source]

Save all metrics to a single JSON file.

run() → None[source]

Visualization

Plotting Module. Provides various evaluation chart plotting functions.

class habit.core.machine_learning.visualization.plotting.Plotter(output_dir: str, dpi: int = 600)[source]

Bases: object

__init__(output_dir: str, dpi: int = 600) → None[source]

Initialize the plotter.

Parameters:
  • output_dir (str) -- Output directory path

  • dpi (int) -- Resolution for non-PDF format images

plot_roc_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'ROC.pdf', title: str = 'test') → None[source]

Plot ROC curves for a single dataset (optimized version).

Parameters:
  • models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values

  • save_name -- Name of the file to save the plot

  • title -- Data type for title display ('train' or 'test')

plot_dca_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'DCA.pdf', title: str = 'test') → None[source]

Plot Decision Curve Analysis (DCA) for a single dataset (optimized version).

Parameters:
  • models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values

  • save_name -- Name of the file to save the plot

  • title -- Data type for title display ('train' or 'test')

plot_confusion_matrix(y_true: ndarray, y_pred: ndarray, save_name: str = 'Confusion_Matrix.pdf', title: str = 'Confusion Matrix', class_names: List[str] | None = None, normalize: bool = False, cmap: str = 'Blues') → None[source]

Plot a confusion matrix.

Parameters:
  • y_true (np.ndarray) -- True labels

  • y_pred (np.ndarray) -- Predicted labels

  • save_name (str) -- Name of the file to save the plot

  • title (str) -- Title of the plot

  • class_names (List[str]) -- Names of the classes (default: None, which uses '0', '1' for binary classification)

  • normalize (bool) -- Whether to normalize the confusion matrix (default: False)

  • cmap (str) -- Colormap to use (default: 'Blues')

plot_calibration_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'Calibration.pdf', n_bins: int = 5, title: str = 'test') → None[source]

Plot calibration curves for a single dataset (optimized version).

Parameters:
  • models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values

  • save_name -- Name of the file to save the plot

  • n_bins -- Number of bins to use for the calibration curve

  • title -- Data type for title display ('train' or 'test')

plot_shap(model: Any, X: ndarray, feature_names: List[str], save_name: str = 'SHAP.pdf') → None[source]

Plot SHAP values with bar and beeswarm plots.

Parameters:
  • model (Any) -- Trained model

  • X (np.ndarray) -- Feature data

  • feature_names (List[str]) -- List of feature names

  • save_name (str) -- Name of the file to save the plot

plot_pr_curve(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'PR_curve.pdf', title: str = 'evaluation') → None[source]

Plot Precision-Recall curves for multiple models.

Parameters:
  • models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values

  • save_name -- Name of the file to save the plot

  • title -- Data type for title display ('train', 'test', or 'evaluation')

Kaplan-Meier survival plotting utilities.

Features aligned with the requirements of top-tier medical imaging journals:
  • Publication-quality styling (font, line widths, vector export)

  • KM curves with confidence bands

  • Log-rank p-value (two-group or multi-group)

  • Cox model hazard ratio (HR) with 95% CI (binary; pairwise vs reference if >2 groups)

  • Number-at-risk table

  • Median survival per group (optional)

Usage example (programmatic):

  plotter = KMSurvivalPlotter(output_dir="./results/km")
  fig, ax = plotter.plot_km(
      df=dataframe,
      time_col="os_time",
      event_col="os_event",
      group_col="risk_group",
      save_name="KM_OS.pdf",
      time_unit="Months",
  )

class habit.core.machine_learning.visualization.km_survival.KMSurvivalPlotter(output_dir: 'str', dpi: 'int' = 600, font_family: 'str' = 'Arial', font_size: 'int' = 11)[source]

Bases: object

output_dir: str
dpi: int = 600
font_family: str = 'Arial'
font_size: int = 11
plot_km(df: DataFrame, time_col: str, event_col: str, group_col: str, save_name: str = 'KM_Curve.pdf', time_unit: str = 'Months', group_order: Sequence | None = None, palette: Sequence[str] | None = None, show_ci: bool = True, show_risk_table: bool = True, show_hr: bool = False, hr_reference: str | None = None, figsize: Tuple[float, float] = (5.5, 5.0), y_label: str = 'Survival probability', x_label: str | None = None, xlim: Tuple[float, float] | None = None, ylim: Tuple[float, float] = (0.0, 1.0), annotate_median: bool = False, legend_loc: str = 'best', legend_ncol: int = 1, legend_outside: bool = False) → Tuple[Figure, Axes][source]

Plot KM curves by group with a risk table and annotations.

Parameters:
  • df -- DataFrame containing survival data

  • time_col -- Duration column (numeric)

  • event_col -- Event column (1=event, 0=censored)

  • group_col -- Grouping column (categorical)

  • save_name -- Output file name; extension controls format

  • time_unit -- Label for x-axis (e.g., 'Months')

  • group_order -- Optional manual ordering of groups

  • palette -- Matplotlib/seaborn palette name or list of colors

  • show_ci -- Whether to draw confidence bands

  • show_risk_table -- Whether to render number-at-risk table

  • show_hr -- Whether to compute and display HR

  • hr_reference -- Reference group for HR (default: first in order)

  • figsize -- Figure size in inches

  • y_label -- Y-axis label

  • x_label -- X-axis label (defaults to time_unit)

  • xlim -- Optional x-axis range

  • ylim -- Y-axis range

  • annotate_median -- Add median survival to legend label

  • legend_loc -- Legend location ('best', 'upper right', 'lower left', etc.)

  • legend_ncol -- Number of columns in legend

  • legend_outside -- Whether to place legend outside the plot area

__init__(output_dir: str, dpi: int = 600, font_family: str = 'Arial', font_size: int = 11) → None