machine_learning module
Machine Learning module for the HABIT package.
This module aggregates key components for ML workflows, including model training, evaluation, and feature selection.
Workflows
Workflow classes encapsulate the complete machine learning process, including training, validation, and prediction.
Standard Machine Learning Workflow (Train/Test Split). Inherits from BaseWorkflow for consistent infrastructure.
- class habit.core.machine_learning.workflows.holdout_workflow.MachineLearningWorkflow(config: MLConfig)[source]
Bases:
BaseWorkflow
Advanced K-Fold Cross-Validation Workflow. Inherits from BaseWorkflow for consistent infrastructure.
- class habit.core.machine_learning.workflows.kfold_workflow.MachineLearningKFoldWorkflow(config: MLConfig)[source]
Bases:
BaseWorkflow
Prediction Workflow module. Handles loading trained models and making predictions on new data.
- class habit.core.machine_learning.workflows.prediction_workflow.PredictionWorkflow(config: PredictionConfig, logger: Logger | None = None)[source]
Bases:
object
Workflow for running model predictions.
- config
Configuration object.
- Type:
PredictionConfig
- logger
Logger instance.
- Type:
Logger
- class habit.core.machine_learning.base_workflow.BaseWorkflow(config: MLConfig | Dict[str, Any], module_name: str)[source]
Bases:
ABC
Abstract base class for all machine learning workflows. Handles infrastructure such as logging, data loading, and basic results persistence.
Model Factory
Model factory module. Factory class for creating model instances.
- class habit.core.machine_learning.models.factory.ModelFactory[source]
Bases:
object
Factory class for creating model instances.
- classmethod register(name: str)[source]
Register a model class.
- Parameters:
name -- Model name
- Returns:
Decorator function
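Since register is a classmethod that returns a decorator, the factory presumably follows the common decorator-registry pattern. A minimal, self-contained sketch of that pattern (the class ModelFactorySketch, its _registry dict, and the create helper are illustrative assumptions, not HABIT's actual implementation):

```python
# Illustrative decorator-registry sketch; ModelFactorySketch, _registry,
# and create() are assumptions, not HABIT's actual code.
from typing import Any, Callable, Dict, Type


class ModelFactorySketch:
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        """Return a decorator that stores the decorated class under `name`."""
        def decorator(model_cls: Type) -> Type:
            cls._registry[name] = model_cls
            return model_cls
        return decorator

    @classmethod
    def create(cls, name: str, **kwargs: Any) -> Any:
        """Instantiate a registered model class by name."""
        return cls._registry[name](**kwargs)


@ModelFactorySketch.register("dummy")
class DummyModel:
    def __init__(self, c: float = 1.0) -> None:
        self.c = c
```

With this pattern, adding a new model requires no changes to the factory itself — the decorator registers the class at import time.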
Ensemble model wrapper for K-fold cross-validation. Allows treating a collection of K-fold models as a single scikit-learn estimator.
- class habit.core.machine_learning.models.ensemble.HabitEnsembleModel(estimators: List[Any], voting: str = 'soft')[source]
Bases:
BaseEstimator, ClassifierMixin
An ensemble wrapper that aggregates predictions from multiple fitted pipelines. Used primarily to wrap K-fold cross-validation results into a single predict-ready object.
- estimators
List of fitted scikit-learn pipelines/models.
- Type:
List[Any]
- fit(X, y=None)[source]
No-op fit method. The estimators passed to __init__ are assumed to be already fitted, which allows wrapping the K-fold results directly.
- predict_proba(X)[source]
Predict class probabilities for X.
- Returns:
array-like of shape (n_samples, n_classes)
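With voting='soft', the aggregation described above amounts to averaging predict_proba outputs across the fitted fold models. A runnable sketch of that idea (soft_vote and the ConstantProba stand-in are illustrative, not HABIT's code):

```python
# Sketch of soft-voting aggregation: average class probabilities across
# already-fitted models. soft_vote/ConstantProba are illustrative, not HABIT's API.
import numpy as np


def soft_vote(estimators, X):
    """Average predict_proba outputs from a list of fitted estimators."""
    probas = [est.predict_proba(X) for est in estimators]
    return np.mean(probas, axis=0)  # shape (n_samples, n_classes)


class ConstantProba:
    """Stand-in for a fitted pipeline that returns fixed probabilities."""

    def __init__(self, p1: float) -> None:
        self.p1 = p1

    def predict_proba(self, X):
        n = len(X)
        return np.column_stack([np.full(n, 1 - self.p1), np.full(n, self.p1)])


folds = [ConstantProba(0.2), ConstantProba(0.4), ConstantProba(0.9)]
avg = soft_vote(folds, X=[[0.0], [1.0]])  # mean positive-class prob is 0.5
```

Because the wrapper exposes fit and predict_proba, downstream code can treat the K-fold ensemble like any other scikit-learn classifier.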
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') HabitEnsembleModel
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Evaluation
Model evaluation module. Provides functions for model training, evaluation, and result analysis.
- class habit.core.machine_learning.evaluation.model_evaluation.ModelEvaluator(output_dir: str)[source]
Bases:
object
- __init__(output_dir: str)[source]
Initialize the model evaluator.
- Parameters:
output_dir (str) -- Directory where evaluation results and plots will be saved
- evaluate(model: Any, X: DataFrame, y: Series, dataset_name: str = 'test') Dict[str, Any][source]
Evaluate a single model on a single dataset.
- plot_curves(model_data: Dict[str, Dict[str, Dict[str, List]]], curve_type: Literal['roc', 'dca', 'calibration', 'pr', 'all'] = 'all', title: str = 'evaluation', output_dir: str | None = None, prefix: str = '', n_bins: int = 10) Dict[str, str][source]
Plot various evaluation curves using methods from plotting.py.
- Parameters:
model_data (Dict) -- Dictionary containing model evaluation results. Format: {dataset_name: {model_name: {'y_true': [...], 'y_pred_proba': [...]}}}
curve_type (Literal) -- Type of curve to plot; one of 'roc', 'dca', 'calibration', 'pr', or 'all'
title (str) -- Keyword for the chart title
output_dir (Optional[str]) -- Output directory; defaults to self.output_dir
prefix (str) -- Prefix for output filenames
n_bins (int) -- Number of bins for the calibration curve
- Returns:
Dictionary containing paths to generated chart files
- Return type:
Dict[str, str]
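For reference, the nested model_data argument documented above could be assembled like this (model names and values are made up for illustration):

```python
# Format: {dataset_name: {model_name: {'y_true': [...], 'y_pred_proba': [...]}}}
# Model names and values below are illustrative only.
model_data = {
    "test": {
        "LogisticRegression": {
            "y_true": [0, 1, 1, 0, 1],
            "y_pred_proba": [0.2, 0.8, 0.6, 0.3, 0.9],
        },
        "RandomForest": {
            "y_true": [0, 1, 1, 0, 1],
            "y_pred_proba": [0.1, 0.7, 0.4, 0.2, 0.95],
        },
    }
}
```

A call such as evaluator.plot_curves(model_data, curve_type='roc') would then plot ROC curves for both models on the 'test' dataset.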
- class habit.core.machine_learning.evaluation.model_evaluation.MultifileEvaluator(output_dir: str)[source]
Bases:
object
- read_prediction_files(files_config: List[Dict]) MultifileEvaluator[source]
Read prediction results from multiple files.
- Parameters:
files_config (List[Dict]) -- List of file configurations; each element contains:
- path: file path
- model_name: model name
- subject_id_col: subject ID column name
- label_col: ground-truth label column name
- prob_col: predicted probability column name
- pred_col: predicted label column name (optional)
- Returns:
Self, for method chaining
- Return type:
MultifileEvaluator
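A files_config matching the documented keys might look like this (paths and column names are placeholders, not HABIT defaults):

```python
# Illustrative files_config for read_prediction_files; paths and column
# names are placeholders.
files_config = [
    {
        "path": "results/model_a_predictions.csv",
        "model_name": "ModelA",
        "subject_id_col": "subject_id",
        "label_col": "label",
        "prob_col": "probability",
        # "pred_col" is optional and omitted here
    },
    {
        "path": "results/model_b_predictions.csv",
        "model_name": "ModelB",
        "subject_id_col": "subject_id",
        "label_col": "label",
        "prob_col": "probability",
        "pred_col": "prediction",
    },
]
```

Because read_prediction_files returns self, calls can be chained, e.g. MultifileEvaluator(out_dir).read_prediction_files(files_config).save_merged_data().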
- save_merged_data(filename: str = 'merged_predictions.csv') None[source]
Save the merged data to a CSV file.
- Parameters:
filename (str) -- Output file name
- plot_calibration(save_name: str = 'Calibration.pdf', n_bins: int = 5, title: str = 'evaluation') None[source]
Plot calibration curves for all models.
Metrics calculation module. Provides a registry-based system for calculating model evaluation metrics.
Optimizations:
- Confusion matrix caching for performance
- Extended target metrics support (PPV, NPV, F1, etc.)
- Fallback mechanism when no threshold satisfies all targets
- Pareto optimal threshold selection
- Category-based metric filtering
- class habit.core.machine_learning.evaluation.metrics.MetricsCache(y_true: ndarray, y_pred: ndarray, y_prob: ndarray)[source]
Bases:
object
Cache for the confusion matrix and derived metrics to avoid repeated calculations. Provides roughly 8x performance improvement when calculating multiple metrics.
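The caching pays off because sensitivity, specificity, PPV, and NPV all derive from the same four confusion-matrix counts, so computing the matrix once amortizes the cost across metrics. A minimal sketch of the idea (BinaryCM is illustrative, not HABIT's MetricsCache):

```python
# Compute the four confusion-matrix counts once, then derive each metric
# from the cached counts. BinaryCM is an illustrative sketch.
import numpy as np


class BinaryCM:
    def __init__(self, y_true, y_pred):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        self.tp = int(np.sum((y_true == 1) & (y_pred == 1)))
        self.tn = int(np.sum((y_true == 0) & (y_pred == 0)))
        self.fp = int(np.sum((y_true == 0) & (y_pred == 1)))
        self.fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    def sensitivity(self):   # recall, true positive rate
        return self.tp / (self.tp + self.fn)

    def specificity(self):   # true negative rate
        return self.tn / (self.tn + self.fp)

    def ppv(self):           # precision
        return self.tp / (self.tp + self.fp)

    def npv(self):
        return self.tn / (self.tn + self.fn)


cm = BinaryCM([1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 1, 0])
```

Each metric call below is then O(1), which is where the multi-metric speedup comes from.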
- habit.core.machine_learning.evaluation.metrics.register_metric(name: str, display_name: str, category: str = 'basic')[source]
Decorator to register a metric function.
- Parameters:
name -- Internal unique key for the metric
display_name -- Pretty name for reports and plots
category -- 'basic', 'statistical', or 'clinical'
- habit.core.machine_learning.evaluation.metrics.calc_sensitivity(y_true, y_pred, y_prob, cm=None)[source]
Calculate sensitivity (recall, true positive rate).
- habit.core.machine_learning.evaluation.metrics.calc_specificity(y_true, y_pred, y_prob, cm=None)[source]
Calculate specificity (true negative rate).
- habit.core.machine_learning.evaluation.metrics.calc_ppv(y_true, y_pred, y_prob, cm=None)[source]
Calculate positive predictive value (precision).
- habit.core.machine_learning.evaluation.metrics.calc_npv(y_true, y_pred, y_prob, cm=None)[source]
Calculate negative predictive value.
- habit.core.machine_learning.evaluation.metrics.calc_f1(y_true, y_pred, y_prob, cm=None)[source]
Calculate F1-score (harmonic mean of precision and recall).
- habit.core.machine_learning.evaluation.metrics.calculate_metrics(y_true: ndarray | PredictionContainer, y_pred: ndarray | None = None, y_pred_proba: ndarray | None = None, categories: List[str] | None = None, use_cache: bool = True) Dict[str, float][source]
Calculate all registered metrics using PredictionContainer.
- Parameters:
y_true -- True labels or PredictionContainer
y_pred -- Predicted labels (optional if using PredictionContainer)
y_pred_proba -- Predicted probabilities (optional if using PredictionContainer)
categories -- Filter metrics by category, e.g., ['basic', 'statistical']. If None, calculate all metrics
use_cache -- If True, use confusion matrix caching for better performance
- Returns:
Dictionary of metric_name -> value
- habit.core.machine_learning.evaluation.metrics.apply_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) Dict[str, float][source]
Apply a given threshold to predicted probabilities and calculate metrics.
- habit.core.machine_learning.evaluation.metrics.calculate_metrics_youden(y_true: ndarray, y_pred_proba: ndarray) Dict[str, float | Dict[str, float]][source]
Calculate metrics based on the optimal Youden index.
- habit.core.machine_learning.evaluation.metrics.apply_youden_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) Dict[str, float | Dict[str, float]][source]
Apply a pre-determined Youden threshold.
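For context, the Youden index is J = sensitivity + specificity − 1, and the optimal threshold is the cut-off that maximizes J. A brute-force sketch over candidate thresholds (youden_threshold is illustrative; HABIT's implementation may differ):

```python
# Brute-force Youden-index threshold selection. Illustrative sketch only.
import numpy as np


def youden_threshold(y_true, y_prob):
    """Return (threshold, J) maximizing J = sensitivity + specificity - 1."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    best_t, best_j = 0.5, -np.inf
    for t in np.unique(y_prob):  # each observed probability as a candidate cut-off
        y_pred = (y_prob >= t).astype(int)
        sens = float(np.mean(y_pred[y_true == 1] == 1))
        spec = float(np.mean(y_pred[y_true == 0] == 0))
        j = sens + spec - 1
        if j > best_j:
            best_t, best_j = t, j
    return float(best_t), best_j


t, j = youden_threshold([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

apply_youden_threshold then reuses a threshold chosen this way on the training set, e.g. to evaluate a held-out test set at the same cut-off.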
- habit.core.machine_learning.evaluation.metrics.calculate_metrics_at_target(y_true: ndarray, y_pred_proba: ndarray, target_metrics: Dict[str, float], threshold_selection: str = 'pareto+youden', fallback_to_closest: bool = True, distance_metric: str = 'euclidean') Dict[str, float | Dict[str, float]][source]
Calculate metrics at thresholds that achieve target metric values.
Enhanced version with:
- Support for any metric (sensitivity, specificity, ppv, npv, f1_score, accuracy)
- Fallback mechanism when no threshold satisfies all targets
- Multiple threshold selection strategies
- Parameters:
y_true -- True labels
y_pred_proba -- Predicted probabilities
target_metrics -- Target values, e.g., {'sensitivity': 0.91, 'specificity': 0.91}
threshold_selection -- Strategy for selecting the best threshold from multiple candidates:
- 'first': Use the first satisfying threshold
- 'youden': Maximum Youden index among satisfying thresholds
- 'pareto+youden': Pareto optimal with highest Youden (recommended)
fallback_to_closest -- If True, find the closest threshold when no perfect match exists
distance_metric -- Distance metric for the fallback ('euclidean', 'manhattan', 'max')
- Returns:
Dictionary containing:
'thresholds': Individual thresholds for each target metric
'metrics_at_thresholds': Full metrics at each individual threshold
'combined_results': Thresholds satisfying all targets
'best_threshold': Selected best threshold (if multiple candidates exist)
'closest_threshold': Fallback threshold (if no perfect match)
- habit.core.machine_learning.evaluation.metrics.apply_target_threshold(y_true: ndarray, y_pred_proba: ndarray, threshold: float) Dict[str, float | Dict[str, float]][source]
Apply a pre-determined target threshold.
- habit.core.machine_learning.evaluation.metrics.delong_roc_ci(y_true: ndarray, y_pred_proba: ndarray, alpha: float = 0.95) Tuple[float, ndarray][source]
Calculate DeLong confidence intervals for the ROC AUC.
- habit.core.machine_learning.evaluation.metrics.calculate_net_benefit(y_true, y_pred_proba, threshold)[source]
Calculate net benefit; used for DCA plotting.
MultifileEvaluator usage example. Demonstrates how to use the multi-file evaluation tool to assess the performance of multiple models.
- class habit.core.machine_learning.workflows.comparison_workflow.ModelComparison(config: Dict[str, Any] | ModelComparisonConfig, evaluator: MultifileEvaluator, reporter: ReportExporter, threshold_manager: ThresholdManager, plot_manager: PlotManager, metrics_store: MetricsStore, logger: Any)[source]
Bases:
object
Tool for comparing and evaluating multiple machine learning models.
Note: Dependencies should be provided via ServiceConfigurator or explicitly.
- __init__(config: Dict[str, Any] | ModelComparisonConfig, evaluator: MultifileEvaluator, reporter: ReportExporter, threshold_manager: ThresholdManager, plot_manager: PlotManager, metrics_store: MetricsStore, logger: Any) None[source]
Initialize the model comparison tool.
Initialize the model comparison tool.
- Parameters:
config -- Parsed config dict or validated config object.
evaluator -- MultifileEvaluator instance (required).
reporter -- ReportExporter instance (required).
threshold_manager -- ThresholdManager instance (required).
plot_manager -- PlotManager instance (required).
metrics_store -- MetricsStore instance (required).
logger -- Logger instance (required).
Visualization
Plotting module. Provides various evaluation chart plotting functions.
- class habit.core.machine_learning.visualization.plotting.Plotter(output_dir: str, dpi: int = 600)[source]
Bases:
object
- plot_roc_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'ROC.pdf', title: str = 'test') None[source]
Plot ROC curves for a single dataset (optimized version).
- Parameters:
models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
save_name -- Name of the file to save the plot
title -- Data type for title display ('train' or 'test')
- plot_dca_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'DCA.pdf', title: str = 'test') None[source]
Plot Decision Curve Analysis (DCA) curves for a single dataset (optimized version).
- Parameters:
models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
save_name -- Name of the file to save the plot
title -- Data type for title display ('train' or 'test')
- plot_confusion_matrix(y_true: ndarray, y_pred: ndarray, save_name: str = 'Confusion_Matrix.pdf', title: str = 'Confusion Matrix', class_names: List[str] | None = None, normalize: bool = False, cmap: str = 'Blues') None[source]
Plot a confusion matrix.
- Parameters:
y_true (np.ndarray) -- True labels
y_pred (np.ndarray) -- Predicted labels
save_name (str) -- Name of the file to save the plot
title (str) -- Title of the plot
class_names (List[str]) -- Names of the classes (default: None, will use '0', '1' for binary classification)
normalize (bool) -- Whether to normalize the confusion matrix (default: False)
cmap (str) -- Colormap to use (default: 'Blues')
- plot_calibration_v2(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'Calibration.pdf', n_bins: int = 5, title: str = 'test') None[source]
Plot calibration curves for a single dataset (optimized version).
- Parameters:
models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
save_name -- Name of the file to save the plot
n_bins -- Number of bins to use for calibration curve
title -- Data type for title display ('train' or 'test')
- plot_shap(model: Any, X: ndarray, feature_names: List[str], save_name: str = 'SHAP.pdf') None[source]
Plot SHAP values with bar and beeswarm plots.
- plot_pr_curve(models_data: Dict[str, Tuple[ndarray, ndarray]], save_name: str = 'PR_curve.pdf', title: str = 'evaluation') None[source]
Plot precision-recall curves for multiple models.
- Parameters:
models_data -- Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
save_name -- Name of the file to save the plot
title -- Data type for title display ('train', 'test', or 'evaluation')
Kaplan-Meier survival plotting utilities.
Features aligned to the requirements of top-tier medical imaging journals:
- Publication-quality styling (font, line widths, vector export)
- KM curves with confidence bands
- Log-rank p-value (two-group or multi-group)
- Cox model hazard ratio (HR) with 95% CI (binary; pairwise vs reference if >2 groups)
- Number-at-risk table
- Median survival per group (optional)
Usage example (programmatic):
plotter = KMSurvivalPlotter(output_dir="./results/km")
fig, ax = plotter.plot_km(
    df=dataframe,
    time_col="os_time",
    event_col="os_event",
    group_col="risk_group",
    save_name="KM_OS.pdf",
    time_unit="Months",
)
- class habit.core.machine_learning.visualization.km_survival.KMSurvivalPlotter(output_dir: str, dpi: int = 600, font_family: str = 'Arial', font_size: int = 11)[source]
Bases:
object
- plot_km(df: DataFrame, time_col: str, event_col: str, group_col: str, save_name: str = 'KM_Curve.pdf', time_unit: str = 'Months', group_order: Sequence | None = None, palette: Sequence[str] | None = None, show_ci: bool = True, show_risk_table: bool = True, show_hr: bool = False, hr_reference: str | None = None, figsize: Tuple[float, float] = (5.5, 5.0), y_label: str = 'Survival probability', x_label: str | None = None, xlim: Tuple[float, float] | None = None, ylim: Tuple[float, float] = (0.0, 1.0), annotate_median: bool = False, legend_loc: str = 'best', legend_ncol: int = 1, legend_outside: bool = False) Tuple[Figure, Axes][source]
Plot KM curves by groups with risk table and annotations.
Plot KM curves by groups with risk table and annotations.
- Parameters:
df -- DataFrame containing survival data
time_col -- Duration column (numeric)
event_col -- Event column (1=event, 0=censored)
group_col -- Grouping column (categorical)
save_name -- Output file name; extension controls format
time_unit -- Label for x-axis (e.g., 'Months')
group_order -- Optional manual ordering of groups
palette -- Matplotlib/seaborn palette name or list of colors
show_ci -- Whether to draw confidence bands
show_risk_table -- Whether to render number-at-risk table
show_hr -- Whether to compute and display HR
hr_reference -- Reference group for HR (default: first in order)
figsize -- Figure size in inches
y_label -- Y-axis label
x_label -- X-axis label (defaults to time_unit)
xlim -- Optional x-axis range
ylim -- Y-axis range
annotate_median -- Add median survival to legend label
legend_loc -- Legend location ('best', 'upper right', 'lower left', etc.)
legend_ncol -- Number of columns in legend
legend_outside -- Whether to place legend outside the plot area