# Source code for habit.core.machine_learning.visualization.plotting

"""
Plotting Module
Provides various evaluation chart plotting functions
"""

import os
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import shap
from sklearn.calibration import calibration_curve  # Calibration curve related
from sklearn.metrics import roc_curve, precision_recall_curve, confusion_matrix, auc

from ..evaluation.metrics import calculate_net_benefit
from ....utils.visualization_utils import process_shap_explanation
from ....utils.font_config import setup_publication_font

class Plotter:
    """
    Draw model-evaluation figures and save them into one output directory.

    Every public ``plot_*`` method builds a matplotlib figure, writes it to
    ``output_dir`` under the requested file name and closes the figure again,
    even when plotting fails part-way through.
    """

    def __init__(self, output_dir: str, dpi: int = 600) -> None:
        """
        Initialize the plotter

        Args:
            output_dir (str): Output directory path (created if missing)
            dpi (int): Resolution for non-PDF format images
        """
        self.output_dir = output_dir
        self.dpi = dpi
        os.makedirs(output_dir, exist_ok=True)

        # Set plotting style with Arial font for publication quality
        setup_publication_font()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _build_model_curve_styles(self, model_names: List[str]) -> Dict[str, Dict[str, Any]]:
        """
        Build high-contrast styles for multi-model line plots.

        The generator is deterministic: the same model order always yields
        the same visual encoding within one plotting call.

        Design choices:
            - Use color-blind friendly palettes first.
            - If model count exceeds palette size, cycle line styles and
              then markers.
            - Baseline curves (e.g., Treat All / Treat None) are styled by
              the callers, not here.

        Args:
            model_names: Ordered model names to style.

        Returns:
            Dict[str, Dict[str, Any]]: Mapping from model name to a dict
            with the keys ``"color"``, ``"linestyle"`` and ``"marker"``.
        """
        # Primary high-contrast palettes (color-blind friendly).
        palette = (
            sns.color_palette("colorblind", 10)
            + sns.color_palette("tab10", 10)
            + sns.color_palette("Set2", 8)
        )
        line_styles = ["-", "--", "-.", ":"]
        markers = ["o", "s", "^", "D", "v", "P", "X", "*", "<", ">"]

        palette_size = len(palette)
        styles: Dict[str, Dict[str, Any]] = {}
        for idx, model_name in enumerate(model_names):
            # Colors cycle fastest, then line styles, then markers.
            styles[model_name] = {
                "color": palette[idx % palette_size],
                "linestyle": line_styles[(idx // palette_size) % len(line_styles)],
                "marker": markers[(idx // (palette_size * len(line_styles))) % len(markers)],
            }
        return styles

    @staticmethod
    def _rescale_to_unit_interval(scores: np.ndarray, model_name: Optional[str] = None) -> np.ndarray:
        """
        Return ``scores`` unchanged when they already lie in [0, 1];
        otherwise min-max rescale them into [0, 1].

        Args:
            scores: Predicted scores / probabilities.
            model_name: When given, a warning naming the model is printed
                if rescaling is necessary.

        Returns:
            np.ndarray: The original array when no value is outside [0, 1],
            else a rescaled copy. The input is never modified in place.
        """
        scores = np.asarray(scores)
        if not (np.any(scores > 1) or np.any(scores < 0)):
            return scores

        if model_name is not None:
            print(f"Warning: Model {model_name} has predicted probabilities outside [0, 1]")

        low = np.min(scores)
        span = np.max(scores) - low
        if span == 0:
            # Constant out-of-range scores: min-max scaling would divide by
            # zero, so clamp to the nearest bound instead.
            return np.clip(scores, 0.0, 1.0)
        return (scores - low) / span

    @staticmethod
    def _style_axes(ax: Any) -> None:
        """Show only the left and bottom spines and set their width to 1.5."""
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        for side in ('bottom', 'left'):
            ax.spines[side].set_visible(True)
            ax.spines[side].set_linewidth(1.5)

    @staticmethod
    def _row_normalize(counts: np.ndarray) -> np.ndarray:
        """
        Divide every row of ``counts`` by its row sum.

        Rows whose sum is zero are returned as all zeros instead of NaN.
        """
        counts = counts.astype(float)
        row_totals = counts.sum(axis=1, keepdims=True)
        return np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

    @staticmethod
    def _summarize_confusion_counts(counts: np.ndarray) -> str:
        """
        Build the metrics caption from a matrix of raw (un-normalized) counts.

        Accuracy is reported for any square matrix. Sensitivity and
        specificity are added only for the 2x2 (binary) case, where the
        flattened matrix is read as ``tn, fp, fn, tp``.
        """
        total = counts.sum()
        accuracy = np.trace(counts) / total if total > 0 else 0.0
        if counts.shape != (2, 2):
            return f'Accuracy: {accuracy:.2f}'

        tn, fp, fn, tp = counts.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        return f'Accuracy: {accuracy:.2f}, Sensitivity: {sensitivity:.2f}, Specificity: {specificity:.2f}'

    def _savefig_options(self, save_name: str) -> Dict[str, Any]:
        """
        Choose ``savefig`` keyword arguments from the file extension.

        - ``.pdf``: tight bounding box only (no DPI).
        - ``.tif`` / ``.tiff``: tight bounding box, ``self.dpi`` and LZW
          compression. The compression option is handed to Pillow through
          ``pil_kwargs``, which is how matplotlib forwards writer options.
        - anything else: tight bounding box and ``self.dpi``.
        """
        file_ext = os.path.splitext(save_name)[1].lower()
        options: Dict[str, Any] = {'bbox_inches': 'tight'}
        if file_ext == '.pdf':
            return options

        options['dpi'] = self.dpi
        if file_ext in ('.tif', '.tiff'):
            options['format'] = 'tif'
            options['pil_kwargs'] = {'compression': 'tiff_lzw'}
        return options

    def _save_figure(self, save_name: str) -> None:
        """
        Helper method to save the current figure with appropriate format and DPI

        Args:
            save_name (str): Name of the file to save (relative to ``output_dir``)
        """
        plt.savefig(os.path.join(self.output_dir, save_name), **self._savefig_options(save_name))

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_roc_v2(self, models_data: Dict[str, Tuple[np.ndarray, np.ndarray]],
                    save_name: str = 'ROC.pdf', title: str = 'test') -> None:
        """
        Plot ROC curves for a single dataset (optimized version)

        Args:
            models_data: Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
            save_name: Name of the file to save the plot
            title: Text used as the plot title (e.g. 'train' or 'test')
        """
        # Create figure - optimized for SCI journal requirements (single column)
        fig = plt.figure(figsize=(5, 5))
        try:
            model_styles = self._build_model_curve_styles(list(models_data.keys()))

            # Plot ROC curves for each model
            for model_name, (y_true, y_pred_proba) in models_data.items():
                fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
                # Trapezoidal area under the ROC curve.
                roc_auc = auc(fpr, tpr)
                style = model_styles[model_name]
                plt.plot(
                    fpr,
                    tpr,
                    label=f'{model_name} (AUC = {roc_auc:.2f})',
                    linewidth=1.8,
                    color=style["color"],
                    linestyle=style["linestyle"],
                )

            # Add diagonal (chance) line
            plt.plot([0, 1], [0, 1], 'k--', linewidth=1.5)
            plt.xlim([-0.02, 1.02])
            plt.ylim([-0.02, 1.02])
            plt.xlabel('False Positive Rate', fontsize=10, fontfamily='Arial')
            plt.ylabel('True Positive Rate', fontsize=10, fontfamily='Arial')
            plt.title(title, fontsize=11, fontfamily='Arial')
            plt.legend(loc="lower right", fontsize=9)
            plt.grid(True, linestyle='--', alpha=0.7)
            self._style_axes(plt.gca())

            # Save figure
            plt.tight_layout()
            self._save_figure(save_name)
        finally:
            plt.close(fig)

    def plot_dca_v2(self, models_data: Dict[str, Tuple[np.ndarray, np.ndarray]],
                    save_name: str = 'DCA.pdf', title: str = 'test') -> None:
        """
        Plot Decision Curve Analysis (DCA) for a single dataset (optimized version)

        Predictions outside [0, 1] are min-max rescaled (with a printed
        warning) on a private copy; the caller's ``models_data`` is left
        untouched. With empty ``models_data`` a message is printed and
        nothing is drawn or saved.

        Args:
            models_data: Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
            save_name: Name of the file to save the plot
            title: Text used as the plot title (e.g. 'train' or 'test')
        """
        # Bail out before any figure is created so nothing is leaked.
        if not models_data:
            print("No data provided for DCA plot")
            return

        # Check whether each model's outputs leave [0, 1]; rescale if so.
        prepared = {
            model_name: (y_true, self._rescale_to_unit_interval(y_pred_proba, model_name))
            for model_name, (y_true, y_pred_proba) in models_data.items()
        }
        model_styles = self._build_model_curve_styles(list(prepared.keys()))

        # Define threshold range
        thresholds = np.linspace(0, 1, 100)

        # Use the first model's y_true as reference (y_true should be consistent across models)
        y_true = next(iter(prepared.values()))[0]

        # Create figure - optimized for SCI journal requirements (single column)
        fig = plt.figure(figsize=(5, 5))
        try:
            # Calculate and plot "Treat All" curve
            net_benefit_all = np.array(
                [calculate_net_benefit(y_true, np.ones_like(y_true), t) for t in thresholds]
            )
            plt.plot(thresholds, net_benefit_all, 'k--', label='Treat All', linewidth=1.5)

            # Calculate and plot "Treat None" curve
            net_benefit_none = np.array(
                [calculate_net_benefit(y_true, np.zeros_like(y_true), t) for t in thresholds]
            )
            plt.plot(thresholds, net_benefit_none, 'k-', label='Treat None', linewidth=1.5)

            # Plot decision curves for each model
            for model_name, (model_y_true, y_pred_proba) in prepared.items():
                net_benefits = np.array(
                    [calculate_net_benefit(model_y_true, y_pred_proba, t) for t in thresholds]
                )
                style = model_styles[model_name]
                plt.plot(
                    thresholds,
                    net_benefits,
                    linewidth=1.8,
                    label=model_name,
                    color=style["color"],
                    linestyle=style["linestyle"],
                )

            # Beautify the plot
            plt.xlabel('Threshold Probability', fontsize=10, fontfamily='Arial')
            plt.ylabel('Net Benefit', fontsize=10, fontfamily='Arial')
            plt.title(title, fontsize=11, fontfamily='Arial')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend(loc='best', fontsize=9)
            self._style_axes(plt.gca())

            # Safely set y-axis range, ignoring NaN / Inf net-benefit values
            y_min = -0.05  # Default minimum
            y_max = 0.5  # Default maximum
            finite_none = net_benefit_none[np.isfinite(net_benefit_none)]
            if finite_none.size > 0:
                y_min = min(y_min, finite_none.min())
            finite_all = net_benefit_all[np.isfinite(net_benefit_all)]
            if finite_all.size > 0:
                y_max = max(y_max, finite_all.max() + 0.1)
            plt.ylim([y_min, y_max])

            # Save image
            plt.tight_layout()
            self._save_figure(save_name)
        finally:
            plt.close(fig)

    def plot_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray,
                              save_name: str = 'Confusion_Matrix.pdf',
                              title: str = 'Confusion Matrix',
                              class_names: Optional[List[str]] = None,
                              normalize: bool = False,
                              cmap: str = 'Blues') -> None:
        """
        Plot confusion matrix

        The caption under the heatmap is always computed from the raw
        counts, so it is the same with and without ``normalize``. For a
        2x2 matrix it shows accuracy, sensitivity and specificity; for any
        other size only accuracy is shown.

        Args:
            y_true (np.ndarray): True labels
            y_pred (np.ndarray): Predicted labels
            save_name (str): Name of the file to save the plot
            title (str): Title of the plot
            class_names (List[str]): Names of the classes (default: None, which
                uses 'Negative'/'Positive' for a 2x2 matrix and '0', '1', ...
                otherwise)
            normalize (bool): Whether to row-normalize the displayed matrix (default: False)
            cmap (str): Colormap to use (default: 'Blues')
        """
        # Calculate confusion matrix (raw counts)
        counts = confusion_matrix(y_true, y_pred)
        n_classes = counts.shape[0]

        # Set class names if not provided
        if class_names is None:
            if n_classes == 2:  # Binary classification
                class_names = ['Negative', 'Positive']
            else:  # Multi-class classification
                class_names = [str(i) for i in range(n_classes)]

        # Metrics must come from counts, not from the normalized matrix.
        metrics_text = self._summarize_confusion_counts(counts)

        # Normalize the displayed matrix if requested
        if normalize:
            displayed = self._row_normalize(counts)
            fmt = '.2f'
            title = f'Normalized {title}'
        else:
            displayed = counts
            fmt = 'd'

        # Create figure and plot confusion matrix - optimized for SCI journal
        fig = plt.figure(figsize=(5, 4))
        try:
            sns.heatmap(displayed, annot=True, fmt=fmt, cmap=cmap,
                        xticklabels=class_names, yticklabels=class_names,
                        cbar=True, square=True, linewidths=0.5)

            # Add labels and title
            plt.xlabel('Predicted Label', fontsize=10, fontfamily='Arial')
            plt.ylabel('True Label', fontsize=10, fontfamily='Arial')
            plt.title(title, fontsize=11, fontfamily='Arial')

            # Add metrics caption below the heatmap
            plt.figtext(0.5, 0.01, metrics_text,
                        ha='center', fontsize=8, fontfamily='Arial',
                        bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'))

            # Leave room for the caption and the title
            plt.tight_layout(rect=[0, 0.05, 1, 0.95])

            # Save figure
            self._save_figure(save_name)
        finally:
            plt.close(fig)

    def plot_calibration_v2(self, models_data: Dict[str, Tuple[np.ndarray, np.ndarray]],
                            save_name: str = 'Calibration.pdf', n_bins: int = 5,
                            title: str = 'test') -> None:
        """
        Plot calibration curves for a single dataset (optimized version)

        Predictions are rescaled only when they fall outside [0, 1];
        in-range probabilities are plotted as given, because rescaling
        them would shift the mean predicted probability of every bin.

        Args:
            models_data: Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
            save_name: Name of the file to save the plot
            n_bins: Number of bins to use for calibration curve
            title: Text used as the plot title (e.g. 'train' or 'test')
        """
        # Create figure - optimized for SCI journal requirements (single column)
        fig = plt.figure(figsize=(5, 5))
        try:
            model_styles = self._build_model_curve_styles(list(models_data.keys()))

            # Plot calibration curves for each model
            for model_name, (y_true, y_pred_proba) in models_data.items():
                # Ensure predicted probabilities are within 0-1 range
                y_pred_unit = self._rescale_to_unit_interval(y_pred_proba, model_name)

                # Calculate calibration curve
                prob_true, prob_pred = calibration_curve(
                    y_true, y_pred_unit, n_bins=n_bins, strategy='quantile'
                )
                style = model_styles[model_name]
                plt.plot(
                    prob_pred,
                    prob_true,
                    linewidth=1.8,
                    markersize=5,
                    marker=style["marker"],
                    color=style["color"],
                    linestyle=style["linestyle"],
                    label=model_name,
                )

            # Add ideal calibration line and beautify the plot
            plt.plot([0, 1], [0, 1], 'k--', linewidth=1.5, label='Perfectly Calibrated')
            plt.xlabel('Mean Predicted Probability', fontsize=10, fontfamily='Arial')
            plt.ylabel('Positive Sample Proportion', fontsize=10, fontfamily='Arial')
            plt.title(title, fontsize=11, fontfamily='Arial')
            plt.legend(loc='best', frameon=True, facecolor='white', framealpha=0.9, fontsize=9)
            plt.grid(True, linestyle='--', alpha=0.7)
            self._style_axes(plt.gca())

            # Expand axis range
            plt.xlim([-0.05, 1.05])
            plt.ylim([-0.05, 1.05])

            # Save image
            plt.tight_layout()
            self._save_figure(save_name)
        finally:
            plt.close(fig)

    def plot_shap(self, model: Any, X: np.ndarray, feature_names: List[str],
                  save_name: str = 'SHAP.pdf') -> None:
        """
        Plot SHAP values with bar and beeswarm plots

        Two files are written: ``<stem>_bar<ext>`` for the bar plot and
        ``save_name`` itself for the beeswarm plot.

        Args:
            model (Any): Trained model. Its optional ``model_type`` attribute
                ('linear' or 'tree') selects the explainer; anything else
                falls back to ``shap.KernelExplainer`` on ``model.predict_proba``.
            X (np.ndarray): Feature data
            feature_names (List[str]): List of feature names
            save_name (str): Name of the file to save the plot
        """
        # Get model type - None when the attribute is not available
        model_type = getattr(model, 'model_type', None)

        # Choose the SHAP explainer based on model type
        if model_type == 'linear':
            if hasattr(model, 'model'):
                # Wrapper object: explain the wrapped estimator
                explainer = shap.LinearExplainer(model.model, X)
            elif hasattr(model, 'coef_') and hasattr(model, 'intercept_'):
                # Model exposes coefficients and intercept directly
                explainer = shap.LinearExplainer((model.coef_, model.intercept_), X)
            else:
                # Fallback to KernelExplainer if the model structure is not accessible
                explainer = shap.KernelExplainer(model.predict_proba, X)
        elif model_type == 'tree':
            # Explain the wrapped estimator when present, else the model itself
            explainer = shap.TreeExplainer(model.model if hasattr(model, 'model') else model)
        else:
            # Default to KernelExplainer for other model types
            explainer = shap.KernelExplainer(model.predict_proba, X)

        # Get SHAP values, then post-process them with the project helper
        # (NOTE(review): exact normalization done by process_shap_explanation
        # is defined elsewhere - confirm there).
        shap_values = process_shap_explanation(explainer.shap_values(X))

        stem, ext = os.path.splitext(save_name)

        # Plot 1: Feature importance bar plot - optimized for SCI journal
        plt.figure(figsize=(6, 5))
        try:
            plt.title('Feature Importance', fontsize=11, fontfamily='Arial')
            shap.summary_plot(
                shap_values,
                X,
                feature_names=feature_names,
                plot_type="bar",
                show=False
            )
            plt.tight_layout()
            self._save_figure(stem + '_bar' + ext)
        finally:
            plt.close()

        # Plot 2: Beeswarm plot - optimized for SCI journal
        plt.figure(figsize=(6, 5))
        try:
            plt.title('Feature Impact Distribution', fontsize=11, fontfamily='Arial')
            shap.summary_plot(
                shap_values,
                X,
                feature_names=feature_names,
                show=False
            )
            plt.tight_layout()
            self._save_figure(save_name)
        finally:
            plt.close()

    def plot_pr_curve(self, models_data: Dict[str, Tuple[np.ndarray, np.ndarray]],
                      save_name: str = 'PR_curve.pdf', title: str = 'evaluation') -> None:
        """
        Plot Precision-Recall curve for multiple models

        The value shown in the legend is the trapezoidal area under the
        precision-recall curve.

        Args:
            models_data: Dictionary with model names as keys and (y_true, y_pred_proba) tuples as values
            save_name: Name of the file to save the plot
            title: Text used as the plot title ('train', 'test', or 'evaluation')
        """
        # Create figure - optimized for SCI journal requirements (single column)
        fig = plt.figure(figsize=(5, 5))
        try:
            model_styles = self._build_model_curve_styles(list(models_data.keys()))

            # Plot PR curves for each model
            for model_name, (y_true, y_pred_proba) in models_data.items():
                precision, recall, _ = precision_recall_curve(
                    y_true, y_pred_proba, drop_intermediate=True
                )

                # Area under the PR curve (trapezoidal rule)
                pr_auc = auc(recall, precision)
                style = model_styles[model_name]
                plt.plot(
                    recall,
                    precision,
                    linewidth=1.8,
                    color=style["color"],
                    linestyle=style["linestyle"],
                    label=f'{model_name} (AUPRC = {pr_auc:.2f})',
                )

            # Beautify the plot
            plt.xlabel('Recall', fontsize=10, fontfamily='Arial')
            plt.ylabel('Precision', fontsize=10, fontfamily='Arial')
            plt.title(f'{title}', fontsize=11, fontfamily='Arial')
            plt.legend(loc='best', fontsize=9)
            plt.grid(True, linestyle='--', alpha=0.7)

            # Set axis limits for left-to-right, bottom-to-top direction
            plt.xlim([-0.02, 1.02])
            plt.ylim([-0.02, 1.02])
            self._style_axes(plt.gca())

            plt.tight_layout()
            self._save_figure(save_name)
        finally:
            plt.close(fig)