# Source code for habit.core.machine_learning.visualization.km_survival

"""
Kaplan-Meier survival plotting utilities

Features aligned with the requirements of top-tier medical imaging journals:
- Publication-quality styling (font, line widths, vector export)
- KM curves with confidence bands
- Log-rank p-value (two-group or multi-group)
- Cox model hazard ratio (HR) with 95% CI (optional; a single HR for two groups, each group vs. a reference group if >2 groups)
- Number-at-risk table
- Median survival per group (optional)

Usage example (programmatic):

    plotter = KMSurvivalPlotter(output_dir="./results/km")
    fig, ax = plotter.plot_km(
        df=dataframe,
        time_col="os_time",
        event_col="os_event",
        group_col="risk_group",
        save_name="KM_OS.pdf",
        time_unit="Months",
    )
"""

from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

# from lifelines import KaplanMeierFitter, CoxPHFitter
# from lifelines.statistics import logrank_test, multivariate_logrank_test
# from lifelines.plotting import add_at_risk_counts


def _ensure_output_dir(path: str) -> None:
    """Make sure the directory *path* exists, creating missing parents.

    An already-existing directory is left untouched; a non-directory at
    *path* still makes ``os.makedirs`` raise.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)


def _save_figure(fig: plt.Figure, output_dir: str, save_name: str, dpi: int) -> None:
    """Save *fig* to ``output_dir/save_name``; the file extension selects the format.

    * ``.pdf``: vector output; ``dpi`` is not passed.
    * ``.tif`` / ``.tiff`` (case-insensitive): raster at ``dpi`` with LZW
      compression. The compression option is forwarded to Pillow through
      ``pil_kwargs``, which is how matplotlib documents passing Pillow save
      options (a bare ``compression=`` keyword is not part of that API).
    * anything else: saved at ``dpi``; matplotlib infers the format from
      the extension.

    Every output uses ``bbox_inches="tight"``.
    """
    file_ext = os.path.splitext(save_name)[1].lower()
    output_path = os.path.join(output_dir, save_name)
    if file_ext == ".pdf":
        fig.savefig(output_path, bbox_inches="tight")
    elif file_ext in (".tif", ".tiff"):
        fig.savefig(
            output_path,
            bbox_inches="tight",
            dpi=dpi,
            format="tif",
            pil_kwargs={"compression": "tiff_lzw"},
        )
    else:
        fig.savefig(output_path, bbox_inches="tight", dpi=dpi)


def _as_color_list(palette: Optional[Sequence[str]], n: int) -> List:
    """Return exactly *n* colors for plotting *n* groups.

    Args:
        palette: ``None`` (seaborn ``"Set1"``), a seaborn/matplotlib palette
            name, or a sequence of colors. An empty sequence or empty string
            falls back to ``"Set1"`` instead of failing.
        n: Number of colors required.

    Returns:
        A list of ``n`` colors. An explicit color sequence is cycled when it
        has fewer than ``n`` entries and truncated when it has more.
    """
    # len() check (not truthiness) so array-like palettes don't raise.
    if palette is None or len(palette) == 0:
        return sns.color_palette("Set1", n)
    if isinstance(palette, str):
        return sns.color_palette(palette, n)
    # List-like of colors: cycle to exactly n entries.
    colors = list(palette)
    return [colors[i % len(colors)] for i in range(n)]


@dataclass
class KMSurvivalPlotter:
    """Kaplan-Meier plotter with log-rank, Cox HR and number-at-risk annotations.

    Dataclass fields:

    * ``output_dir``: directory figures are saved to; created in
      ``__post_init__``.
    * ``dpi``: resolution used for raster exports; also written to the
      ``savefig.dpi`` and ``figure.dpi`` rcParams.
    * ``font_family`` / ``font_size``: base typography written to rcParams
      (tick and legend labels use ``font_size - 1``, titles ``font_size + 1``).

    Note:
        ``__post_init__`` updates the process-wide ``matplotlib.rcParams``,
        so every figure drawn after an instance is created is affected.
        Curve fitting and statistics need the ``lifelines`` package, which is
        imported inside the methods that use it.
    """

    output_dir: str
    dpi: int = 600
    font_family: str = "Arial"
    font_size: int = 11

    def __post_init__(self) -> None:
        """Create the output directory and apply publication-style rcParams."""
        _ensure_output_dir(self.output_dir)
        # Publication-quality defaults. NOTE: this mutates matplotlib's global
        # configuration, not just the figures produced by this instance.
        mpl.rcParams.update(
            {
                "font.family": self.font_family,
                "font.size": self.font_size,
                "axes.linewidth": 1.2,
                "axes.labelsize": self.font_size,
                "axes.titlesize": self.font_size + 1,
                "xtick.labelsize": self.font_size - 1,
                "ytick.labelsize": self.font_size - 1,
                "legend.fontsize": self.font_size - 1,
                "pdf.fonttype": 42,  # embed TrueType
                "ps.fonttype": 42,
                "savefig.dpi": self.dpi,
                "figure.dpi": self.dpi,
            }
        )

    def plot_km(
        self,
        df: pd.DataFrame,
        time_col: str,
        event_col: str,
        group_col: str,
        save_name: str = "KM_Curve.pdf",
        time_unit: str = "Months",
        group_order: Optional[Sequence] = None,
        palette: Optional[Sequence[str]] = None,
        show_ci: bool = True,
        show_risk_table: bool = True,
        show_hr: bool = False,
        hr_reference: Optional[str] = None,
        figsize: Tuple[float, float] = (5.5, 5.0),
        y_label: str = "Survival probability",
        x_label: Optional[str] = None,
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Tuple[float, float] = (0.0, 1.0),
        annotate_median: bool = False,
        legend_loc: str = "best",
        legend_ncol: int = 1,
        legend_outside: bool = False,
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Plot KM curves by group with risk table and statistical annotations.

        Rows with a missing time, event or group value are dropped. When
        ``group_order`` is given, only the listed groups that are present in
        the data are plotted. The log-rank test and the Cox model are computed
        on exactly those plotted groups.

        Args:
            df: DataFrame containing survival data.
            time_col: Duration column (numeric; cast to float).
            event_col: Event column (1=event, 0=censored; cast to int).
            group_col: Grouping column (categorical).
            save_name: Output file name inside ``output_dir``; its extension
                controls the format.
            time_unit: Unit used in the default x-axis label.
            group_order: Optional manual ordering/subset of groups. Default
                is the order of first appearance in ``df``.
            palette: Palette name or list of colors (see ``_as_color_list``).
            show_ci: Whether to draw confidence bands.
            show_risk_table: Whether to render the number-at-risk table.
            show_hr: Whether to fit a Cox model and display hazard ratios.
            hr_reference: Reference group label for the HR. Falls back to the
                first plotted group when ``None`` or not present.
            figsize: Figure size in inches.
            y_label: Y-axis label.
            x_label: X-axis label. Default is ``f"Time ({time_unit})"``.
            xlim: Optional x-axis range.
            ylim: Y-axis range.
            annotate_median: Append the median survival time to each legend
                label when it is finite.
            legend_loc: Legend location ('best', 'upper right', 'lower left',
                'center right', ...).
            legend_ncol: Number of columns in the legend.
            legend_outside: Place the legend to the right of the axes.

        Returns:
            The ``(figure, axes)`` pair. The figure has already been saved to
            ``os.path.join(output_dir, save_name)``.

        Raises:
            KeyError: If ``time_col``, ``event_col`` or ``group_col`` is missing.
            ValueError: If no groups remain after filtering.
            ImportError: If the ``lifelines`` package is not installed.
        """
        for col in (time_col, event_col, group_col):
            if col not in df.columns:
                raise KeyError(f"Column '{col}' not found in DataFrame")

        work_df = df[[time_col, event_col, group_col]].dropna().copy()

        groups = self._resolve_groups(work_df, group_col, group_order)
        if len(groups) < 1:
            raise ValueError("No groups available after filtering")
        # Keep only plotted groups so the log-rank p-value and the Cox HR
        # describe exactly the curves shown in the figure.
        work_df = work_df[work_df[group_col].isin(groups)]

        try:
            from lifelines import KaplanMeierFitter
        except ImportError as err:
            raise ImportError(
                "KMSurvivalPlotter.plot_km requires the 'lifelines' package"
            ) from err

        colors = _as_color_list(palette, len(groups))
        fig, ax = plt.subplots(figsize=figsize)

        km_fitters: List = []
        legend_labels: List[str] = []
        legend_colors: List = []
        # Censor marker style, shared by the curves and the legend handles.
        censor_styles = {"ms": 4, "marker": "+"}

        for idx, group_value in enumerate(groups):
            mask = work_df[group_col] == group_value
            if not mask.any():
                continue
            durations = work_df.loc[mask, time_col].astype(float).values
            events = work_df.loc[mask, event_col].astype(int).values

            kmf = KaplanMeierFitter(label=str(group_value))
            kmf.fit(durations, event_observed=events)
            km_fitters.append(kmf)

            median_txt = ""
            if annotate_median:
                try:
                    median_val = kmf.median_survival_time_
                    if np.isfinite(median_val):
                        median_txt = f" (median {median_val:.1f})"
                except Exception:
                    # Cosmetic label only; never fail the plot over it.
                    median_txt = ""
            legend_labels.append(f"{group_value}{median_txt}")
            # Track colors alongside labels so they stay aligned even if a
            # group is skipped above.
            legend_colors.append(colors[idx])

            kmf.plot_survival_function(
                ax=ax,
                ci_show=show_ci,
                linewidth=1.0,
                color=colors[idx],
                show_censors=True,
                censor_styles=censor_styles,
            )

        # Axes formatting
        if x_label is None:
            x_label = f"Time ({time_unit})"
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_ylim(ylim)
        if xlim is not None:
            ax.set_xlim(xlim)
        ax.grid(True, linestyle="--", alpha=0.4)
        for spine in ("top", "right"):
            ax.spines[spine].set_visible(False)

        self._setup_legend(
            ax,
            legend_labels,
            legend_colors,
            censor_styles,
            legend_loc,
            legend_ncol,
            legend_outside,
        )

        # Statistical annotations (bottom-right text box)
        p_text = self._compute_logrank_text(work_df, time_col, event_col, group_col)
        hr_text = self._compute_hr_text(
            work_df, time_col, event_col, group_col, groups, show_hr, hr_reference
        )
        annotation_lines = [line for line in (p_text, hr_text) if line]
        if annotation_lines:
            ax.text(
                0.98,
                0.04 if show_risk_table else 0.06,
                "\n".join(annotation_lines),
                transform=ax.transAxes,
                ha="right",
                va="bottom",
                fontsize=self.font_size - 1,
                bbox=dict(facecolor="white", alpha=0.9, boxstyle="round,pad=0.4"),
            )

        # Number-at-risk table (best effort: a failure must not lose the plot)
        if show_risk_table and km_fitters:
            try:
                from lifelines.plotting import add_at_risk_counts

                # Default labels from the fitters avoid attribute differences
                # across lifelines versions.
                add_at_risk_counts(*km_fitters, ax=ax)
            except Exception as err:
                warnings.warn(
                    f"Number-at-risk table could not be drawn: {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        fig.tight_layout()
        _save_figure(fig, self.output_dir, save_name, self.dpi)
        return fig, ax

    # ---- Internals ----

    @staticmethod
    def _resolve_groups(
        work_df: pd.DataFrame,
        group_col: str,
        group_order: Optional[Sequence],
    ) -> List:
        """Return the groups to plot.

        Groups come in first-appearance order, or follow ``group_order``
        restricted to values present in ``work_df[group_col]``.
        """
        present = list(pd.unique(work_df[group_col]))
        if group_order is None:
            return present
        present_set = set(present)  # built once, O(1) membership per item
        return [g for g in group_order if g in present_set]

    def _compute_logrank_text(
        self,
        df: pd.DataFrame,
        time_col: str,
        event_col: str,
        group_col: str,
    ) -> str:
        """Return ``"Log-rank p = ..."`` for the groups in ``df``, or ``""``.

        Two groups use ``lifelines.statistics.logrank_test``. Three or more use
        ``multivariate_logrank_test``. Fewer than two groups yield ``""``. Any
        failure (including a missing ``lifelines``) emits a ``RuntimeWarning``
        and yields ``""`` so plotting can continue.
        """
        groups = list(pd.unique(df[group_col]))
        if len(groups) < 2:
            return ""
        try:
            from lifelines.statistics import logrank_test, multivariate_logrank_test

            if len(groups) == 2:
                d1 = df[df[group_col] == groups[0]]
                d2 = df[df[group_col] == groups[1]]
                res = logrank_test(
                    d1[time_col],
                    d2[time_col],
                    event_observed_A=d1[event_col],
                    event_observed_B=d2[event_col],
                )
            else:
                res = multivariate_logrank_test(
                    event_durations=df[time_col],
                    groups=df[group_col],
                    event_observed=df[event_col],
                )
            return f"Log-rank p = {res.p_value:.3g}"
        except Exception as err:
            warnings.warn(
                f"Log-rank test could not be computed: {err}",
                RuntimeWarning,
                stacklevel=3,
            )
            return ""

    def _compute_hr_text(
        self,
        df: pd.DataFrame,
        time_col: str,
        event_col: str,
        group_col: str,
        groups: Sequence,
        enabled: bool,
        hr_reference: Optional[str],
    ) -> str:
        """Fit a Cox model on group indicators and format the hazard ratios.

        Args:
            df: Survival data (already restricted to the plotted groups).
            time_col, event_col, group_col: Column names in ``df``.
            groups: Plotted groups. ``groups[0]`` is the default reference.
            enabled: When false, nothing is computed.
            hr_reference: Requested reference group. Ignored if absent.

        Returns:
            ``"HR (95% CI) = h (lo-hi)"`` when exactly one comparison exists.
            Otherwise ``"<group> vs <reference>: h (lo-hi)"`` entries joined by
            ``"; "``. ``""`` when disabled, with fewer than two groups, or when
            fitting fails (a ``RuntimeWarning`` is emitted in that case).
        """
        if not enabled or len(groups) < 2:
            return ""
        try:
            from lifelines import CoxPHFitter

            cox_df = df[[time_col, event_col, group_col]].copy()
            cox_df[time_col] = cox_df[time_col].astype(float)
            cox_df[event_col] = cox_df[event_col].astype(int)

            present = set(cox_df[group_col])
            if hr_reference is not None and hr_reference in present:
                reference = hr_reference
            else:
                reference = groups[0]

            group_values = cox_df[group_col]
            if isinstance(group_values.dtype, pd.CategoricalDtype):
                # Unused levels would become all-zero (non-estimable) columns.
                group_values = group_values.cat.remove_unused_categories()
            # Float 0/1 indicators (pandas >= 2 would otherwise produce bools).
            dummies = pd.get_dummies(group_values, dtype=float)
            if reference not in dummies.columns:
                # Defensive: use the first category as baseline and label the
                # comparisons against that category, not the requested one.
                reference = dummies.columns[0]
            # Dropping the reference column makes it the Cox baseline.
            dummies = dummies.drop(columns=[reference])
            # String names keep the design matrix homogeneous for numeric labels.
            dummies.columns = [str(c) for c in dummies.columns]

            design = pd.concat([cox_df[[time_col, event_col]], dummies], axis=1)
            cph = CoxPHFitter()
            cph.fit(
                design,
                duration_col=time_col,
                event_col=event_col,
                show_progress=False,
                robust=False,
            )
            summary = cph.summary

            rows = [
                (
                    name,
                    float(np.exp(summary.loc[name, "coef"])),
                    float(np.exp(summary.loc[name, "coef lower 95%"])),
                    float(np.exp(summary.loc[name, "coef upper 95%"])),
                )
                for name in summary.index
            ]
        except Exception as err:
            warnings.warn(
                f"Cox hazard ratio could not be computed: {err}",
                RuntimeWarning,
                stacklevel=3,
            )
            return ""

        if len(rows) == 1:
            _, hr, lower, upper = rows[0]
            return f"HR (95% CI) = {hr:.2f} ({lower:.2f}-{upper:.2f})"
        # Multi-group: each dummy (named after its group) vs the reference.
        return "; ".join(
            f"{name} vs {reference}: {hr:.2f} ({lower:.2f}-{upper:.2f})"
            for name, hr, lower, upper in rows
        )

    def _setup_legend(
        self,
        ax: plt.Axes,
        legend_labels: List[str],
        colors: List,
        censor_styles: Dict,
        legend_loc: str,
        legend_ncol: int,
        legend_outside: bool,
    ) -> None:
        """Draw a styled legend with one line+censor-marker handle per group.

        Args:
            ax: Matplotlib axes object.
            legend_labels: Legend labels, one per plotted group.
            colors: Colors matching ``legend_labels`` position by position.
            censor_styles: Censor marker style (``"marker"``, ``"ms"`` keys).
            legend_loc: Matplotlib legend location string. Side locations
                ("upper/center/lower" + "left/right") are inset slightly
                from the axes edge.
            legend_ncol: Number of legend columns.
            legend_outside: Place the legend to the right of the axes instead.
        """
        if not legend_labels:
            return

        # Custom handles show both the curve colour and the censor symbol.
        marker_type = censor_styles.get("marker", "+")
        marker_size = censor_styles.get("ms", 4) * 2  # scaled up for legibility
        custom_handles = [
            Line2D(
                [0],
                [0],
                color=color,
                linewidth=2.5,
                marker=marker_type,
                markersize=marker_size,
                markerfacecolor=color,
                markeredgecolor=color,
                markeredgewidth=2,
                label=label,
            )
            for label, color in zip(legend_labels, colors)
        ]

        legend_props = {
            "handles": custom_handles,
            "frameon": True,
            "fancybox": True,  # rounded corners
            "shadow": True,
            "framealpha": 0.95,
            "facecolor": "white",
            "edgecolor": "lightgray",
            "ncol": legend_ncol,
            "fontsize": self.font_size - 1,
            "columnspacing": 1.2,  # space between columns
            "handletextpad": 0.5,  # space between marker and text
            "handlelength": 1.5,  # length of legend handles
        }

        if legend_outside:
            # Anchor the legend's upper-left corner just right of the axes.
            legend_props.update(
                {"bbox_to_anchor": (1.02, 1), "loc": "upper left", "borderaxespad": 0}
            )
        else:
            legend_props["loc"] = legend_loc
            # Small inset (axes coordinates) for side-anchored locations so the
            # frame does not touch the axes edge; "center" stays vertically
            # centred instead of falling to the bottom corner.
            side_anchors = {
                "upper right": (0.98, 0.98),
                "center right": (0.98, 0.5),
                "lower right": (0.98, 0.02),
                "upper left": (0.02, 0.98),
                "center left": (0.02, 0.5),
                "lower left": (0.02, 0.02),
            }
            if legend_loc in side_anchors:
                legend_props["bbox_to_anchor"] = side_anchors[legend_loc]

        legend = ax.legend(**legend_props)
        # Alpha and facecolor are already set via framealpha/facecolor above.
        legend.get_frame().set_linewidth(0.8)