habit.core.preprocessing.registration 源代码

from typing import Dict, Any, Optional, Union, List, Tuple, Sequence
import numpy as np
import SimpleITK as sitk
import ants
from habit.utils.image_converter import ImageConverter
from .base_preprocessor import BasePreprocessor
from .preprocessor_factory import PreprocessorFactory

from ...utils.log_utils import get_module_logger

# Module-level logger from the project's log utilities, named after this
# module's dotted import path (``__name__``).
logger = get_module_logger(__name__)

[文档] @PreprocessorFactory.register("registration") class RegistrationPreprocessor(BasePreprocessor): """Register images to a reference image using ANTs. This preprocessor performs image registration using ANTs (Advanced Normalization Tools). It supports various registration methods including SyN, Rigid, Affine, etc. """
def __init__(
    self,
    keys: Union[str, List[str]],
    fixed_image: str,
    mask_keys: Optional[Union[str, List[str]]] = None,
    type_of_transform: str = "SyN",
    metric: str = "MI",
    optimizer: Optional[str] = None,
    use_mask: bool = False,
    allow_missing_keys: bool = False,
    replace_by_fixed_image_mask: bool = True,
    **kwargs
):
    """Configure registration of ``keys`` onto the image stored under ``fixed_image``.

    Args:
        keys (Union[str, List[str]]): Key or keys of the moving images to
            register. A single string is treated as a one-element list.
        fixed_image (str): Key of the reference (fixed) image.
        mask_keys (Optional[Union[str, List[str]]]): Mask key or keys. A
            single string is wrapped in a list; ``None`` stays ``None``.
            NOTE(review): the registration code in this module looks masks up
            as ``f"mask_{key}"`` and never reads ``self.mask_keys`` — confirm
            whether this argument is still needed.
        type_of_transform (str): ANTsPy transform preset passed to
            ``ants.registration``, for example:
            - "Rigid": rotation + translation
            - "Affine": affine
            - "SyN": affine + deformable symmetric normalization
            - "SyNRA": rigid + affine + deformable SyN
            - "SyNOnly": deformable SyN without an initial rigid/affine stage
            - "TRSAA": translation, rigid, similarity, affine (twice)
            - "Elastic": affine + elastic deformation
            - "SyNCC": SyN using cross-correlation as the metric
            - "SyNabp": SyN tuned for abpBrainExtraction
            - "SyNBold": SyN tuned for BOLD-to-T1 registration
            - "SyNBoldAff": SyNBold with an additional affine stage
            - "SyNAggro": more aggressive SyN (finer matching, more deformation)
            - "TVMSQ": time-varying diffeomorphism with a mean-squares metric
            See the ``ants.registration`` documentation for the full list.
        metric (str): Similarity metric label ("CC", "MI", "MeanSquares",
            "Demons"). Stored as ``self.metric``.
            NOTE(review): ``ants.registration`` selects metrics through its
            ``aff_metric``/``syn_metric`` arguments — confirm this label
            actually takes effect.
        optimizer (Optional[str]): Optimizer label ("gradient_descent",
            "lbfgsb", "amoeba"). Stored as ``self.optimizer``; the same review
            note as for ``metric`` applies.
        use_mask (bool): If True, masks are handed to the registration call.
        allow_missing_keys (bool): Forwarded to ``BasePreprocessor``; if True,
            missing keys in the input data are tolerated.
        replace_by_fixed_image_mask (bool): If True, each moving mask is
            replaced by the fixed image's mask after registration instead of
            being warped.
        **kwargs: Extra registration parameters, stored as ``self.reg_params``.
    """
    super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)

    # Normalise the key arguments: a lone string stands for a one-item list.
    self.keys = [keys] if isinstance(keys, str) else keys
    self.fixed_image = fixed_image
    if isinstance(mask_keys, str):
        self.mask_keys = [mask_keys]
    else:
        # Either None or an already list-like collection of mask keys.
        self.mask_keys = mask_keys

    # Registration settings.
    self.type_of_transform = type_of_transform
    self.metric = metric
    self.optimizer = optimizer
    self.use_mask = use_mask
    self.replace_by_fixed_image_mask = replace_by_fixed_image_mask

    # Any other keyword arguments are kept verbatim for the ANTs call.
    self.reg_params = kwargs
def _register_image(
    self,
    fixed_image: ants.ANTsImage,
    moving_image: ants.ANTsImage,
    fixed_mask: Optional[ants.ANTsImage] = None,
    moving_mask: Optional[ants.ANTsImage] = None,
) -> Tuple[ants.ANTsImage, List[str]]:
    """Register a moving image to a fixed image with ``ants.registration``.

    Extra keyword arguments given to the constructor (``self.reg_params``)
    are forwarded verbatim. ANTsPy options such as ``aff_metric``,
    ``syn_metric``, ``reg_iterations`` or ``random_seed`` can therefore be
    set there.

    ``self.metric`` and ``self.optimizer`` are deliberately NOT forwarded.
    ``ants.registration`` has no parameters with those names, so they would
    only disappear into its ``**kwargs`` with no effect. Instead, a warning
    is logged once per instance when a non-default value was requested, so
    the setting no longer fails silently.

    Args:
        fixed_image (ants.ANTsImage): Reference image.
        moving_image (ants.ANTsImage): Image to be registered.
        fixed_mask (Optional[ants.ANTsImage]): Mask in the reference image
            space (passed as ``mask``).
        moving_mask (Optional[ants.ANTsImage]): Mask in the moving image space
            (passed as ``moving_mask``).

    Returns:
        Tuple[ants.ANTsImage, List[str]]:
            - The moving image warped into the fixed space (``warpedmovout``).
            - The forward transform file list (``fwdtransforms``).
    """
    self._warn_unforwarded_options()

    # Copy so the per-call mask entries never leak into the stored kwargs.
    reg_params: Dict[str, Any] = dict(self.reg_params)
    if fixed_mask is not None:
        reg_params['mask'] = fixed_mask
    if moving_mask is not None:
        reg_params['moving_mask'] = moving_mask

    reg_result = ants.registration(
        fixed=fixed_image,
        moving=moving_image,
        type_of_transform=self.type_of_transform,
        **reg_params,
    )
    return reg_result['warpedmovout'], reg_result['fwdtransforms']

def _warn_unforwarded_options(self) -> None:
    """Log (once per instance) that ``metric``/``optimizer`` have no ANTsPy effect."""
    if getattr(self, "_unforwarded_options_warned", False):
        return
    self._unforwarded_options_warned = True

    ignored = []
    # "MI" is the constructor default and corresponds to ANTsPy's own default
    # Mattes mutual-information metric, so it does not deserve a warning.
    if self.metric not in (None, "MI"):
        ignored.append(f"metric={self.metric!r}")
    if self.optimizer is not None:
        ignored.append(f"optimizer={self.optimizer!r}")
    if ignored:
        logger.warning(
            f"{', '.join(ignored)} not supported by ants.registration and ignored; "
            "pass ANTsPy options such as aff_metric/syn_metric as extra keyword arguments instead."
        )
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Register every configured moving image (and its mask) to ``self.fixed_image``.

    Each key in ``self.keys`` other than the fixed key goes through the same
    steps. The image is cast to float32, converted to ANTs, registered with
    ``self._register_image`` and written back into ``data`` as a SimpleITK
    image. The forward transform files are stored under
    ``f"{key}_transform_files"`` and the registration settings under
    ``f"{key}_meta_dict"``.

    A second pass handles masks stored under ``f"mask_{key}"``. Each mask is
    either replaced by a UInt8 copy of the fixed image's mask (when
    ``replace_by_fixed_image_mask`` is True) or warped with the stored
    transforms using nearest-neighbour interpolation. Failures in the mask
    pass are logged and skipped.

    Args:
        data (Dict[str, Any]): Sample dictionary. Image and mask entries are
            SimpleITK images (they are passed to ``sitk.Cast``). An optional
            ``'subj'`` entry is used only for logging.

    Returns:
        Dict[str, Any]: Data dictionary with registered images. Entries of the
        input dictionary are updated in place.

    Raises:
        KeyError: If ``self.fixed_image`` is not present in ``data``.
        Exception: Any error raised while registering an image is re-raised
            unless ``self.allow_missing_keys`` is True.
    """
    subj = data.get('subj', 'unknown')
    logger.debug(f"[{subj}] Registering to {self.fixed_image}")
    self._check_keys(data)

    # The reference image is mandatory regardless of allow_missing_keys.
    if self.fixed_image not in data:
        raise KeyError(f"Reference key {self.fixed_image} not found in data dictionary")

    # Convert the fixed image once (float32 -> ANTs); it is reused for every
    # moving image and again when warping masks below.
    fixed_image_sitk = data[self.fixed_image]
    fixed_image_sitk = sitk.Cast(fixed_image_sitk, sitk.sitkFloat32)
    fixed_image_ants = ImageConverter.itk_2_ants(fixed_image_sitk)

    # Fixed-space mask for registration, only when use_mask is enabled and a
    # "mask_<fixed key>" entry exists; otherwise registration runs unmasked.
    fixed_mask_ants = None
    if self.use_mask:
        mask_key = f"mask_{self.fixed_image}"
        if mask_key in data:
            fixed_mask = data[mask_key]
            fixed_mask = sitk.Cast(fixed_mask, sitk.sitkUInt8)
            fixed_mask_ants = ImageConverter.itk_2_ants(fixed_mask)

    # ---- Pass 1: register each moving image to the fixed image ----
    for key in self.keys:
        if key == self.fixed_image:
            continue

        # NOTE(review): with allow_missing_keys=True, a key absent from `data`
        # presumably passes _check_keys, but this lookup sits outside the try
        # below and would raise KeyError — confirm, and consider skipping
        # absent keys.
        moving_image = data[key]
        moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)
        moving_image = ImageConverter.itk_2_ants(moving_image)

        # Optional moving-space mask, looked up as "mask_<key>".
        moving_mask = None
        if self.use_mask:
            mask_key = f"mask_{key}"
            if mask_key in data:
                moving_mask = data[mask_key]
                moving_mask = sitk.Cast(moving_mask, sitk.sitkUInt8)
                moving_mask = ImageConverter.itk_2_ants(moving_mask)

        try:
            registered_image, transform_files = self._register_image(
                fixed_image_ants, moving_image, fixed_mask_ants, moving_mask
            )

            # Back to SimpleITK so downstream preprocessors see the same type.
            registered_sitk = ImageConverter.ants_2_itk(registered_image)
            # sitk.GetArrayFromImage(registered_sitk)

            data[key] = registered_sitk

            # Keep the forward transforms; pass 2 uses them to warp the mask.
            transform_key = f"{key}_transform_files"
            data[transform_key] = transform_files

            # Record the registration settings in the image's metadata.
            # NOTE(review): metric/optimizer are recorded as configured;
            # confirm they reflect what ants.registration actually used.
            meta_key = f"{key}_meta_dict"
            if meta_key not in data:
                data[meta_key] = {}
            data[meta_key]["registered"] = True
            data[meta_key]["fixed_image"] = self.fixed_image
            data[meta_key]["type_of_transform"] = self.type_of_transform
            data[meta_key]["metric"] = self.metric
            data[meta_key]["optimizer"] = self.optimizer

            # Disabled: rename the image path to mark it as registered.
            # data[meta_key]["image_path"] = data[meta_key]["image_path"].replace(".nii.gz", "_registered.nii.gz")

        except Exception as e:
            # Best-effort only when missing keys are allowed; otherwise propagate.
            logger.error(f"Error registering image {key}: {e}")
            if not self.allow_missing_keys:
                raise

    # ============================
    # ---- Pass 2: bring each moving image's mask into the fixed space ----
    for key in self.keys:
        if key == self.fixed_image:
            continue

        mask_key = f"mask_{key}"
        fixed_mask_key = f"mask_{self.fixed_image}"
        transform_key = f"{key}_transform_files"

        # Nothing to do when the moving image has no mask.
        if mask_key not in data:
            continue

        # Replacement requested but there is no fixed mask to copy: the moving
        # mask is left unchanged (neither replaced nor warped).
        if self.replace_by_fixed_image_mask and fixed_mask_key not in data:
            logger.warning(f"Warning: Cannot replace mask for {key} because fixed mask {fixed_mask_key} not found.")
            continue

        # Replace the moving mask with a UInt8 copy of the fixed mask.
        # NOTE(review): this branch runs before the transform-files check
        # below, so the mask is replaced even when this image's registration
        # failed (allow_missing_keys=True) — confirm this is intended.
        if self.replace_by_fixed_image_mask:
            logger.debug(f"Replacing mask for {key} with fixed image mask")
            fixed_mask = data[fixed_mask_key]
            data[mask_key] = sitk.Cast(fixed_mask, sitk.sitkUInt8)

            meta_key = f"{mask_key}_meta_dict"
            if meta_key not in data:
                data[meta_key] = {}
            data[meta_key]["registered"] = True
            data[meta_key]["fixed_image"] = self.fixed_image
            data[meta_key]["replaced_by_fixed_mask"] = True
            continue

        # Warp the mask with the image's transforms. Missing transform files
        # mean pass 1 did not store them (e.g. registration failed).
        if transform_key not in data:
            logger.warning(f"Warning: No transform files found for {key}. Skipping mask registration.")
            continue

        moving_mask = data[mask_key]
        moving_mask = sitk.Cast(moving_mask, sitk.sitkUInt8)
        moving_mask_ants = ImageConverter.itk_2_ants(moving_mask)

        transform_files = data[transform_key]

        try:
            # Resample onto the fixed image grid (fixed_image_ants from above).
            transformed_mask = ants.apply_transforms(
                fixed=fixed_image_ants,
                moving=moving_mask_ants,
                transformlist=transform_files,
                interpolator="nearestNeighbor"  # avoid blending label values
            )

            transformed_mask_sitk = ImageConverter.ants_2_itk(transformed_mask)

            # Restore the UInt8 pixel type; note the cast alone does not
            # binarize the mask.
            transformed_mask_sitk = sitk.Cast(transformed_mask_sitk, sitk.sitkUInt8)

            data[mask_key] = transformed_mask_sitk

            meta_key = f"{mask_key}_meta_dict"
            if meta_key not in data:
                data[meta_key] = {}
            data[meta_key]["registered"] = True
            data[meta_key]["fixed_image"] = self.fixed_image
            data[meta_key]["type_of_transform"] = self.type_of_transform
            data[meta_key]["metric"] = self.metric
            data[meta_key]["optimizer"] = self.optimizer

        except Exception as e:
            logger.error(f"Error applying transform to mask {mask_key}: {e}")
            # Continue even if error occurs for one mask
    # NOTE(review): the documented return value requires `return data` after
    # this loop — confirm it is present.