# habit.core.preprocessing.n4_correction — source code

from typing import Dict, Any, Optional, Union, List, Tuple, Sequence
import numpy as np
import SimpleITK as sitk
import logging
from .base_preprocessor import BasePreprocessor
from .preprocessor_factory import PreprocessorFactory
from ...utils.image_converter import ImageConverter
from ...utils.log_utils import get_module_logger

# Module-level logger for this file, obtained from the project's logging
# helper and keyed by this module's name (__name__).
logger = get_module_logger(__name__)

@PreprocessorFactory.register("n4_correction")
class N4BiasFieldCorrection(BasePreprocessor):
    """Apply N4 bias field correction to images using SimpleITK.

    This preprocessor corrects intensity inhomogeneity (bias field) in
    medical images with ``sitk.N4BiasFieldCorrectionImageFilter``. When
    ``shrink_factor > 1`` the bias field is estimated on a shrunken copy of
    the image (and mask) to save time. The field is then evaluated on the
    full-resolution grid and divided out of the original image.
    """

    def __init__(
        self,
        keys: Union[str, List[str]],
        mask_keys: Optional[Union[str, List[str]]] = None,
        num_fitting_levels: int = 4,
        num_iterations: Optional[List[int]] = None,
        convergence_threshold: float = 0.001,
        shrink_factor: int = 4,
        allow_missing_keys: bool = False,
        **kwargs,
    ):
        """Initialize the N4 bias field correction preprocessor.

        Args:
            keys (Union[str, List[str]]): Keys of the images to be corrected.
            mask_keys (Optional[Union[str, List[str]]]): Keys of the masks to
                use for correction. A single key is shared by every image.
                Otherwise the i-th mask key is paired with the i-th image key.
                If the configured mask key is absent from the data, the
                legacy key ``f"mask_{key}"`` is tried. If None, no mask is
                used.
            num_fitting_levels (int): Number of fitting levels for the bias
                field correction. Only used to build the default iteration
                schedule. When ``num_iterations`` is given, its length
                determines the number of fitting levels.
            num_iterations (Optional[List[int]]): Maximum number of iterations
                at each fitting level. If None, ``[50] * num_fitting_levels``
                is used.
            convergence_threshold (float): Convergence threshold for the
                correction.
            shrink_factor (int): Shrink factor used to accelerate the bias
                field estimation (default 4). Values <= 1 disable shrinking.
            allow_missing_keys (bool): If True, image keys missing from the
                input data are skipped. Errors during correction are also
                logged instead of re-raised.
            **kwargs: Accepted for interface compatibility. They are currently
                ignored.
        """
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)

        # Normalize a single key to a list so the rest of the class can iterate.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

        if mask_keys is None:
            self.mask_keys = None
        elif isinstance(mask_keys, str):
            self.mask_keys = [mask_keys]
        else:
            self.mask_keys = list(mask_keys)

        self.num_fitting_levels = num_fitting_levels
        if num_iterations is None:
            self.num_iterations = [50] * num_fitting_levels
        else:
            # Copy so later mutation of the caller's list cannot affect us.
            self.num_iterations = list(num_iterations)
            if len(self.num_iterations) != num_fitting_levels:
                logger.warning(
                    f"num_iterations has {len(self.num_iterations)} entries but "
                    f"num_fitting_levels is {num_fitting_levels}; using "
                    f"{len(self.num_iterations)} fitting levels"
                )
        self.convergence_threshold = convergence_threshold
        self.shrink_factor = shrink_factor

    def _resolve_mask_key(
        self, key: str, index: int, data: Dict[str, Any]
    ) -> Optional[str]:
        """Return the data key holding the mask for image ``key``, or None.

        Args:
            key (str): Image key being corrected.
            index (int): Position of ``key`` in ``self.keys``. It is used to
                pair images with masks when several mask keys are configured.
            data (Dict[str, Any]): Input data dictionary.

        Returns:
            Optional[str]: The first candidate mask key present in ``data``.
            Returns None if masks are disabled or no candidate is present.
        """
        if self.mask_keys is None:
            return None

        candidates: List[str] = []
        if len(self.mask_keys) == 1:
            # One mask shared by all images.
            candidates.append(self.mask_keys[0])
        elif index < len(self.mask_keys):
            # One mask per image, paired by position.
            candidates.append(self.mask_keys[index])
        # Legacy naming convention used by earlier versions of this class.
        candidates.append(f"mask_{key}")

        return next((cand for cand in candidates if cand in data), None)

    def _apply_n4_correction(
        self,
        sitk_image: sitk.Image,
        sitk_mask: Optional[sitk.Image] = None,
        subj: Optional[str] = None,
    ) -> sitk.Image:
        """Apply N4 bias field correction to a SimpleITK image.

        Args:
            sitk_image (sitk.Image): Input image to correct. It is cast to
                ``sitkFloat32`` before correction.
            sitk_mask (Optional[sitk.Image]): Optional mask restricting where
                the bias field is estimated. Any non-zero voxel is treated as
                foreground. The mask is assumed to lie on the same grid as the
                image (TODO confirm callers guarantee this).
            subj (Optional[str]): Subject identifier, used only for logging.

        Returns:
            sitk.Image: Bias-corrected image with ``sitkFloat32`` pixels.
        """
        subj_info = f"[{subj}] " if subj else ""

        # N4 works on real-valued images. Keep a handle to the full-resolution
        # float image so the estimated field can be applied to it later.
        original_image = sitk.Cast(sitk_image, sitk.sitkFloat32)
        work_image = original_image

        if sitk_mask is not None:
            # N4 requires a UInt8 mask and only uses voxels equal to its mask
            # label (1 by default). Binarizing with "> 0" yields UInt8 0/1, so
            # masks stored as 0/255 or with a float pixel type also work.
            sitk_mask = sitk.Cast(sitk_mask > 0, sitk.sitkUInt8)

        if self.shrink_factor > 1:
            logger.debug(f"{subj_info}Applying shrinking with factor {self.shrink_factor}")
            factors = [self.shrink_factor] * original_image.GetDimension()
            work_image = sitk.Shrink(original_image, factors)
            if sitk_mask is not None:
                sitk_mask = sitk.Shrink(sitk_mask, factors)

        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        iterations = [int(n) for n in self.num_iterations]
        corrector.SetMaximumNumberOfIterations(iterations)
        # Keep the filter's level count in sync with the per-level schedule.
        # Otherwise num_fitting_levels would never reach the filter.
        corrector.SetNumberOfFittingLevels(len(iterations))
        corrector.SetConvergenceThreshold(self.convergence_threshold)

        if sitk_mask is not None:
            corrected_image = corrector.Execute(work_image, sitk_mask)
        else:
            corrected_image = corrector.Execute(work_image)

        if self.shrink_factor > 1:
            logger.debug(f"{subj_info}Applying correction to full resolution")
            # Evaluate the fitted log bias field on the full-resolution grid
            # and divide it out of the original (un-shrunken) image.
            log_bias_field = corrector.GetLogBiasFieldAsImage(original_image)
            corrected_image = original_image / sitk.Exp(log_bias_field)

        return corrected_image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply N4 bias field correction to the specified images.

        Args:
            data (Dict[str, Any]): Input data dictionary containing SimpleITK
                image objects. The optional ``'subj'`` entry is used for
                logging.

        Returns:
            Dict[str, Any]: The same dictionary with each corrected image
            replaced in place. ``data[f"{key}_meta_dict"]["n4_corrected"]``
            is set to True for each corrected image.
        """
        self._check_keys(data)
        subj = data.get('subj', 'unknown')
        logger.debug(f"[{subj}] Applying N4 bias field correction")

        for index, key in enumerate(self.keys):
            if key not in data and self.allow_missing_keys:
                # Missing images are tolerated; nothing to correct.
                continue

            sitk_image = data[key]
            if not isinstance(sitk_image, sitk.Image):
                logger.warning(f"[{subj}] Warning: {key} is not a SimpleITK Image object. Skipping.")
                continue

            sitk_mask = None
            mask_key = self._resolve_mask_key(key, index, data)
            if mask_key is not None:
                sitk_mask = data[mask_key]
                if not isinstance(sitk_mask, sitk.Image):
                    logger.warning(f"[{subj}] Warning: {mask_key} is not a SimpleITK Image object. Using no mask.")
                    sitk_mask = None
            elif self.mask_keys is not None:
                logger.debug(f"[{subj}] No mask found for {key}; correcting without mask")

            try:
                corrected_image = self._apply_n4_correction(sitk_image, sitk_mask, subj)

                data[key] = corrected_image

                meta_key = f"{key}_meta_dict"
                if meta_key not in data:
                    data[meta_key] = {}
                data[meta_key]["n4_corrected"] = True
            except Exception as e:
                logger.error(f"[{subj}] Error applying N4 correction to {key}: {e}")
                if not self.allow_missing_keys:
                    raise

        return data