habit.utils.visualization 源代码

"""
Visualization utilities for habitat analysis
"""

import os
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
from .font_config import setup_publication_font, get_font_config

# Setup publication-quality Arial font.
# NOTE: this runs at import time, so merely importing this module changes the
# global matplotlib font configuration via .font_config.setup_publication_font.
setup_publication_font()


def _find_knee_index(x_values, y_values) -> Optional[int]:
    """Return the index of the knee/elbow of a curve, or None if undefined.

    The scores are min-max normalised, and the point farthest (vertically) from
    the chord joining the first and last points is chosen. This is the
    geometric core of the Kneedle method, without its sensitivity threshold.
    None is returned for fewer than three points, mismatched lengths,
    non-finite scores, or a flat curve/x-range.
    """
    x = np.asarray(x_values, dtype=float)
    y = np.asarray(y_values, dtype=float)
    if x.size < 3 or x.shape != y.shape or not np.all(np.isfinite(y)):
        return None
    x_span = x[-1] - x[0]
    y_span = y.max() - y.min()
    if x_span == 0 or y_span == 0:
        return None
    # Position along the chord: 0 at the first point, 1 at the last.
    t = (x - x[0]) / x_span
    y_norm = (y - y.min()) / y_span
    chord = y_norm[0] + (y_norm[-1] - y_norm[0]) * t
    return int(np.argmax(np.abs(y_norm - chord)))


def _select_best_index(scores, optimization, cluster_values) -> Optional[int]:
    """Pick the index of the optimal score for the given optimisation direction.

    Returns None when no score is finite, so the caller can skip marking.
    NaN scores are ignored rather than being treated as the maximum.
    """
    values = np.asarray(scores, dtype=float)
    if values.size == 0 or not np.any(np.isfinite(values)):
        return None
    if optimization == 'minimize':
        return int(np.nanargmin(values))
    if optimization in ('kneedle', 'inertia', 'elbow'):
        knee = _find_knee_index(cluster_values, values)
        if knee is not None:
            return knee
    # 'maximize', unknown directions, and curves without a defined knee keep
    # the historical "largest score wins" fallback.
    return int(np.nanargmax(values))


def plot_cluster_scores(scores_dict: Dict[str, List[float]], cluster_range: List[int],
                        methods: Optional[Union[List[str], str]] = None,
                        clustering_algorithm: str = 'kmeans',
                        figsize: Tuple[int, int] = (10, 10),
                        outdir: Optional[str] = None,
                        save_path: Optional[str] = None,
                        show: bool = True,
                        dpi: int = 600,
                        best_n_clusters: Optional[Dict[str, int]] = None):
    """
    Plot the scoring curves for cluster evaluation, one figure per method.

    Args:
        scores_dict: Dictionary of scores, with method names as keys and score lists as values
        cluster_range: Cluster numbers matching each score list (list, range or array)
        methods: Methods to plot, can be a string or list of strings, None means plot all methods
        clustering_algorithm: Name of the clustering algorithm
        figsize: Size of the figure
        outdir: Directory to save figures, None means do not save
        save_path: Explicit file path to save the figure (overrides outdir). When several
            methods are plotted, each figure is written to this same path in turn.
        show: Whether to display the figure; when False the figure is closed
        dpi: Image resolution
        best_n_clusters: Precomputed best cluster number per method to mark on the plot.
            When absent for a method, the best value is derived from the scores using the
            method's optimisation direction (max, min, or knee for 'kneedle'/'inertia'/'elbow').

    Methods missing from ``scores_dict`` or with empty score lists are skipped.
    """
    from habit.core.habitat_analysis.algorithms.cluster_validation_methods import (
        get_method_description,
        get_optimization_direction,
    )

    # If methods is None, use all methods in scores_dict
    if methods is None:
        methods = list(scores_dict.keys())
    elif isinstance(methods, str):
        methods = [methods]

    # Materialise as a list so ranges/numpy arrays support ``.index`` below.
    cluster_values = list(cluster_range)

    figdir = None
    if outdir:
        figdir = os.path.join(outdir, 'visualizations', 'habitat_clustering')
        os.makedirs(figdir, exist_ok=True)

    for method in methods:
        if method not in scores_dict:
            continue
        scores = scores_dict[method]
        if len(scores) == 0:
            continue

        optimization = get_optimization_direction(clustering_algorithm, method)
        # Readable criterion label, also used when the best value is supplied.
        if optimization in ('kneedle', 'inertia', 'elbow'):
            criterion = "Kneedle"
        elif optimization == 'minimize':
            criterion = "Minimum"
        else:
            criterion = "Maximum"

        # Prefer an externally provided best value to avoid re-deriving the selection.
        if best_n_clusters is not None and method in best_n_clusters:
            best_n_clusters_value = best_n_clusters[method]
        else:
            best_idx = _select_best_index(scores, optimization, cluster_values)
            best_n_clusters_value = None if best_idx is None else cluster_values[best_idx]

        # Map the best cluster number back to its score; skip marking if it is
        # outside the plotted range (or could not be determined).
        best_score = None
        if best_n_clusters_value in cluster_values:
            best_score = scores[cluster_values.index(best_n_clusters_value)]
        else:
            criterion = "N/A"

        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ax.plot(cluster_values, scores, 'o-', linewidth=2, markersize=8)
        if best_score is not None:
            ax.plot(best_n_clusters_value, best_score, 'rx', markersize=12, markeredgewidth=3)

        method_desc = get_method_description(clustering_algorithm, method)
        ax.set_title(
            f"{method_desc}\nOptimal Clusters = {best_n_clusters_value} ({criterion})",
            fontfamily='Arial'
        )
        ax.set_xlabel("Number of Clusters", fontfamily='Arial')
        ax.set_ylabel(f"{method.capitalize()} Score", fontfamily='Arial')
        ax.grid(True)
        fig.tight_layout()

        if save_path:
            os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        elif figdir:
            fig_path = os.path.join(figdir, f'{clustering_algorithm}_{method}_cluster_validation_scores.png')
            fig.savefig(fig_path, dpi=dpi, bbox_inches='tight')

        # Show or close figure
        if show:
            plt.show()
        else:
            plt.close(fig)
def plot_elbow_curve(cluster_range, scores, score_type, title=None, save_path=None,
                     *, show=True, dpi=None):
    """
    Plot the elbow curve (score versus number of clusters).

    Args:
        cluster_range: Range of cluster numbers
        scores: Corresponding scores
        score_type: Type of score for title and y-axis label
        title: Figure title, automatically generated if None
        save_path: Path to save the figure, do not save if None. Missing parent
            directories are created.
        show: Whether to display the figure with ``plt.show()`` (default True,
            the previous behaviour). When False the figure is closed instead, so
            batch runs do not accumulate open figures.
        dpi: Resolution used when saving; None keeps matplotlib's default.
    """
    if title is None:
        title = f"The {score_type} Method showing the optimal k"

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(cluster_range, scores, 'bx-')
    ax.set_xlabel('Number of clusters', fontfamily='Arial')
    ax.set_ylabel(score_type, fontfamily='Arial')
    ax.set_title(title, fontfamily='Arial')
    fig.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
        # Only forward dpi when given, so the default save resolution is unchanged.
        save_kwargs = {} if dpi is None else {'dpi': dpi}
        fig.savefig(save_path, **save_kwargs)

    if show:
        plt.show()
    else:
        plt.close(fig)
def plot_multiple_scores(cluster_range, scores_dict, title=None, save_path=None,
                         *, show=True, dpi=None,
                         lower_is_better=('bic', 'aic', 'inertia', 'davies_bouldin')):
    """
    Plot multiple scoring methods on the same graph, min-max normalised.

    Every curve is rescaled to [0, 1] so metrics with different units can be
    compared. Metrics where a lower value is better are flipped, so that every
    curve reads "higher is better", matching the y-axis label.

    Args:
        cluster_range: Range of cluster numbers
        scores_dict: Dictionary with scoring method names as keys and score lists as values.
            Series that are empty or contain no finite value are skipped; NaN
            entries are ignored when computing the normalisation range.
        title: Figure title, automatically generated if None
        save_path: Path to save the figure, do not save if None. Missing parent
            directories are created.
        show: Whether to display the figure (default True, the previous
            behaviour). When False the figure is closed instead.
        dpi: Resolution used when saving; None keeps matplotlib's default.
        lower_is_better: Metric names (case-insensitive) whose normalised curve
            is inverted.
    """
    if title is None:
        title = "Comparison of different cluster evaluation metrics"
    inverted = {str(name).lower() for name in lower_is_better}

    fig, ax = plt.subplots(figsize=(8, 6))
    for score_name, scores in scores_dict.items():
        values = np.asarray(scores, dtype=float)
        if values.size == 0 or not np.any(np.isfinite(values)):
            continue
        # nan-aware bounds: a single NaN must not wipe out the whole curve.
        low, high = np.nanmin(values), np.nanmax(values)
        # The epsilon keeps constant curves at 0 instead of dividing by zero.
        normalized_scores = (values - low) / (high - low + 1e-10)
        if str(score_name).lower() in inverted:
            normalized_scores = 1 - normalized_scores
        ax.plot(cluster_range, normalized_scores, 'o-', label=score_name)

    ax.set_xlabel('Number of clusters', fontfamily='Arial')
    ax.set_ylabel('Normalized score (higher is better)', fontfamily='Arial')
    ax.set_title(title, fontfamily='Arial')
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)
    fig.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
        save_kwargs = {} if dpi is None else {'dpi': dpi}
        fig.savefig(save_path, **save_kwargs)

    if show:
        plt.show()
    else:
        plt.close(fig)
def _reduce_for_plot(X, centers, n_components, reduction_method):
    """Project X (and centers, when possible) down to ``n_components`` dimensions.

    Returns ``(X_reduced, centers_reduced, explained_var, method_used)``.
    ``method_used`` is 'pca', 'tsne', or None when X already has at most
    ``n_components`` columns and no reduction was applied.
    """
    if X.shape[1] <= n_components:
        return X, centers, None, None
    method = reduction_method.lower()
    if method == 'pca':
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
        X_reduced = reducer.fit_transform(X)
        centers_reduced = reducer.transform(centers) if centers is not None else None
        return X_reduced, centers_reduced, reducer.explained_variance_ratio_, 'pca'
    if method == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components, random_state=42,
                       perplexity=min(30, X.shape[0] - 1))
        # TSNE cannot transform centers (no out-of-sample transform), so drop them.
        return reducer.fit_transform(X), None, None, 'tsne'
    raise ValueError(f"Unknown reduction method: {reduction_method}. Use 'pca' or 'tsne'.")


def _cluster_colors(cmap, n_clusters):
    """Return one colour per cluster from ``cmap``.

    Qualitative colormaps are indexed directly (wrapping once their colours run
    out); all other colormaps are sampled evenly across [0, 1].
    """
    # pyplot.get_cmap works on all supported matplotlib versions, unlike
    # matplotlib.cm.get_cmap, which was removed in 3.9.
    cmap_obj = plt.get_cmap(cmap)
    if cmap in ('tab10', 'tab20', 'Set1', 'Set2', 'Set3', 'Paired', 'Accent'):
        return [cmap_obj(i % cmap_obj.N) for i in range(n_clusters)]
    return cmap_obj(np.linspace(0, 1, n_clusters))


def _axis_labels(n_components, explained_var, method_used, feature_names):
    """Choose one axis label per plotted dimension."""
    ratios = None if explained_var is None else np.ravel(explained_var)
    if ratios is not None and len(ratios) >= n_components:
        return [f'PC{i + 1} ({ratios[i] * 100:.1f}%)' for i in range(n_components)]
    if method_used == 'tsne':
        return [f't-SNE {i + 1}' for i in range(n_components)]
    # Feature names only describe the axes when the data was not projected.
    if method_used is None and feature_names is not None and len(feature_names) >= n_components:
        return [str(name) for name in list(feature_names)[:n_components]]
    return [f'Feature {i + 1}' for i in range(n_components)]


def plot_cluster_results(
    X,
    labels,
    centers=None,
    title=None,
    feature_names=None,
    save_path=None,
    show=False,
    dpi=600,
    plot_3d=False,
    explained_variance=None,
    # Configurable visual parameters
    figsize: Optional[Tuple[int, int]] = None,
    alpha: float = 0.7,
    marker_size: int = 20,
    marker: str = 'o',
    center_marker: str = 'X',
    center_size: int = 50,
    center_color: str = 'red',
    cmap: str = 'tab10',
    reduction_method: str = 'pca',
    show_colorbar: bool = True,
    show_grid: bool = True,
    grid_alpha: float = 0.3,
    max_legend_items: int = 10
):
    """
    Plot scatter plot of clustering results (2D or 3D).

    Args:
        X: Input data, shape (n_samples, n_features); array-like or DataFrame
        labels: Cluster labels, shape (n_samples,); array-like or Series
        centers: Cluster centers, plotted if not None (dropped when t-SNE is used)
        title: Figure title
        feature_names: Axis labels used when X is plotted without reduction
        save_path: Path to save the figure, do not save if None
        show: Whether to display the figure (default False for batch processing)
        dpi: Image resolution (default 600)
        plot_3d: Whether to plot 3D scatter plot (default False)
        explained_variance: Explained variance ratios used for "PCn (x%)" axis labels
            when X is already reduced (ignored when PCA is run here, which
            supplies its own ratios)
        figsize: Figure size as (width, height), default (6, 5) for 2D, (7, 6) for 3D
        alpha: Transparency of scatter points (0-1, default 0.7)
        marker_size: Size of scatter points (default 20)
        marker: Marker style for data points (default 'o')
        center_marker: Marker style for cluster centers (default 'X')
        center_size: Size of center markers (default 50)
        center_color: Color of center markers (default 'red')
        cmap: Colormap for clusters, 'tab10' is good for discrete categories (default 'tab10')
        reduction_method: Dimensionality reduction method, 'pca' or 'tsne' (default 'pca'),
            applied only when X has more columns than the plot has dimensions
        show_colorbar: Currently unused (clusters are identified through the legend);
            kept for backward compatibility
        show_grid: Whether to show grid on 2D plots (default True)
        grid_alpha: Transparency of grid lines (default 0.3)
        max_legend_items: Maximum number of legend items to show, hide legend if exceeded (default 10)

    Raises:
        ValueError: If X is not 2-D or has fewer columns than plot dimensions,
            or if reduction is needed and ``reduction_method`` is unknown.
    """
    # np.asarray converts DataFrames/Series (via __array__) and plain lists alike;
    # list labels would otherwise make ``labels == label`` a scalar False.
    X = np.asarray(X)
    labels = np.asarray(labels)
    if centers is not None:
        centers = np.asarray(centers)

    if figsize is None:
        figsize = (7, 6) if plot_3d else (6, 5)

    n_components = 3 if plot_3d else 2
    if X.ndim != 2 or X.shape[1] < n_components:
        raise ValueError(
            f"X must be 2-D with at least {n_components} feature columns for a "
            f"{n_components}-D plot; got shape {X.shape}."
        )

    X_reduced, centers_reduced, explained_var, method_used = _reduce_for_plot(
        X, centers, n_components, reduction_method
    )
    if method_used is None:
        explained_var = explained_variance

    unique_labels = np.unique(labels)
    n_clusters = len(unique_labels)
    colors = _cluster_colors(cmap, n_clusters)
    axis_labels = _axis_labels(n_components, explained_var, method_used, feature_names)

    if plot_3d:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers '3d' on older matplotlib
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # Clusters first (zorder=1) so the opaque centers (zorder=10) are drawn on top.
    for idx, label in enumerate(unique_labels):
        mask = labels == label
        ax.scatter(
            *(X_reduced[mask, dim] for dim in range(n_components)),
            c=[colors[idx]], label=f'Cluster {label}',
            alpha=alpha, s=marker_size, marker=marker, zorder=1
        )
    if centers_reduced is not None:
        ax.scatter(
            *(centers_reduced[:, dim] for dim in range(n_components)),
            c=center_color, marker=center_marker, s=center_size, label='Centers',
            edgecolors='None', linewidths=1, alpha=1.0, zorder=10
        )

    if plot_3d:
        # labelpad keeps the z-axis label visible in 3D views.
        ax.set_xlabel(axis_labels[0], fontfamily='Arial', labelpad=5)
        ax.set_ylabel(axis_labels[1], fontfamily='Arial', labelpad=5)
        ax.set_zlabel(axis_labels[2], fontfamily='Arial', labelpad=5)
    else:
        ax.set_xlabel(axis_labels[0], fontfamily='Arial')
        ax.set_ylabel(axis_labels[1], fontfamily='Arial')

    if n_clusters <= max_legend_items:
        ax.legend(loc='best', fontsize=8)
    if not plot_3d and show_grid:
        ax.grid(True, linestyle='--', alpha=grid_alpha)

    ax.set_title(title if title else f'Clustering Results (n_clusters={n_clusters})',
                 fontfamily='Arial', fontsize=11)
    fig.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)