from typing import Dict, Any, Optional, Union, List, Tuple
import SimpleITK as sitk
import numpy as np
from .base_preprocessor import BasePreprocessor
from .preprocessor_factory import PreprocessorFactory
from ...utils.log_utils import get_module_logger
# Get module logger
logger = get_module_logger(__name__)
@PreprocessorFactory.register("zscore_normalization")
class ZScoreNormalization(BasePreprocessor):
    """Apply Z-score normalization to medical images.

    This preprocessor normalizes image intensities by subtracting the mean and
    dividing by the standard deviation, resulting in a distribution with
    zero mean and unit variance. The statistics are computed either over the
    whole image or, when ``only_inmask`` is set and a usable mask is present,
    only over the voxels where the mask is non-zero.

    Integer-typed images are cast to 32-bit float before normalization so that
    neither the mean nor the resulting z-scores are truncated to integers.
    """

    # Standard deviations below this value are treated as zero; std=1 is used
    # instead so that the division stays well defined.
    _MIN_STD = 1e-10

    def __init__(
        self,
        keys: Union[str, List[str]],
        only_inmask: bool = False,
        mask_key: Optional[str] = None,
        clip_values: Optional[Tuple[float, float]] = None,
        allow_missing_keys: bool = False,
        **kwargs
    ):
        """Initialize the Z-score normalization preprocessor.

        Args:
            keys (Union[str, List[str]]): Keys of the images to be normalized.
            only_inmask (bool): If True, only calculate statistics within the mask.
            mask_key (Optional[str]): Key of the mask to use when only_inmask is True.
            clip_values (Optional[Tuple[float, float]]): Optional tuple of (min, max)
                values to clip normalized results. Useful to prevent extreme
                values, e.g. (-3, 3).
            allow_missing_keys (bool): If True, allows missing keys in the input data.
            **kwargs: Additional parameters (accepted for interface
                compatibility; not used by this class).

        Raises:
            ValueError: If ``clip_values`` does not contain exactly two values
                or its lower bound is greater than its upper bound.
        """
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)

        # Convert single key to list
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

        # Handle mask settings
        self.only_inmask = only_inmask
        self.mask_key = mask_key
        if only_inmask and mask_key is None:
            # Without a mask key there is nothing to restrict the statistics
            # to, so make the whole-image fallback visible.
            logger.warning(
                "only_inmask=True but no mask_key was given. "
                "Statistics will be computed over the whole image."
            )

        # Clipping values: fail early on a malformed range instead of at
        # normalization time.
        if clip_values is not None:
            if len(clip_values) != 2:
                raise ValueError(
                    f"clip_values must contain exactly two values (min, max), got {clip_values!r}"
                )
            if clip_values[0] > clip_values[1]:
                raise ValueError(
                    f"clip_values lower bound must not exceed upper bound, got {clip_values!r}"
                )
        self.clip_values = clip_values

    def _compute_statistics(
        self,
        sitk_image: sitk.Image,
        sitk_mask: Optional[sitk.Image],
        subj_info: str
    ) -> Tuple[float, float]:
        """Compute the mean and standard deviation used for normalization.

        Args:
            sitk_image (sitk.Image): Image whose intensities are measured.
            sitk_mask (Optional[sitk.Image]): If given, only voxels where the
                mask is non-zero contribute to the statistics.
            subj_info (str): Prefix for log messages.

        Returns:
            Tuple[float, float]: ``(mean, std)``.

        Raises:
            ValueError: If the mask size differs from the image size.
        """
        if sitk_mask is not None:
            if sitk_mask.GetSize() != sitk_image.GetSize():
                raise ValueError(
                    f"Mask size {sitk_mask.GetSize()} does not match image size {sitk_image.GetSize()}"
                )
            # Select only the voxels inside the mask. Zeroing the outside
            # voxels and measuring the whole image would count those zeros
            # and bias both the mean and the standard deviation.
            inside = sitk.GetArrayFromImage(sitk_mask) != 0
            voxels = sitk.GetArrayFromImage(sitk_image)[inside].astype(np.float64)
            if voxels.size > 0:
                mean_val = float(voxels.mean())
                # ddof=1 (sample standard deviation) is intended to match the
                # whole-image path below; a single voxel has no spread.
                std_val = float(voxels.std(ddof=1)) if voxels.size > 1 else 0.0
                return mean_val, std_val
            logger.warning(f"{subj_info}Warning: Mask is empty. Using whole-image statistics.")

        stats_filter = sitk.StatisticsImageFilter()
        stats_filter.Execute(sitk_image)
        return stats_filter.GetMean(), stats_filter.GetSigma()

    def _apply_zscore_normalization(
        self,
        sitk_image: sitk.Image,
        sitk_mask: Optional[sitk.Image] = None,
        subj: Optional[str] = None
    ) -> sitk.Image:
        """Apply Z-score normalization to a SimpleITK image.

        Args:
            sitk_image (sitk.Image): Input SimpleITK image to normalize
            sitk_mask (Optional[sitk.Image]): Optional mask for computing statistics
            subj (Optional[str]): Subject identifier for logging

        Returns:
            sitk.Image: Normalized SimpleITK image. Inputs that are not
            already 32- or 64-bit float are returned as 32-bit float.
        """
        subj_info = f"[{subj}] " if subj else ""

        mean_val, std_val = self._compute_statistics(sitk_image, sitk_mask, subj_info)
        logger.debug(f"{subj_info}Mean: {mean_val}, std: {std_val}")

        # Avoid division by zero, very small values, or a non-finite std
        if not np.isfinite(std_val) or std_val < self._MIN_STD:
            logger.warning(f"{subj_info}Warning: Standard deviation is very small ({std_val}). Using std=1 to avoid division issues.")
            std_val = 1.0

        # Do the arithmetic in floating point: on an integer image the mean
        # would be truncated and the division would be an integer division.
        if sitk_image.GetPixelID() in (sitk.sitkFloat32, sitk.sitkFloat64):
            float_image = sitk_image
        else:
            float_image = sitk.Cast(sitk_image, sitk.sitkFloat32)

        # z = (x - mean) / std
        centered_image = sitk.Subtract(float_image, mean_val)
        normalized_image = sitk.Divide(centered_image, std_val)

        # Log the value range before any clipping
        sample_array = sitk.GetArrayFromImage(normalized_image)
        logger.debug(f"{subj_info}Normalized range: [{np.min(sample_array)}, {np.max(sample_array)}]")

        # Clip values if specified
        if self.clip_values is not None:
            clamp_filter = sitk.ClampImageFilter()
            clamp_filter.SetLowerBound(float(self.clip_values[0]))
            clamp_filter.SetUpperBound(float(self.clip_values[1]))
            normalized_image = clamp_filter.Execute(normalized_image)

        return normalized_image

    def _resolve_mask(self, data: Dict[str, Any], subj: Any) -> Optional[sitk.Image]:
        """Return the mask image to use for statistics, or None for no mask.

        A mask is only returned when ``only_inmask`` is True, ``mask_key`` is
        set, and ``data[mask_key]`` is a SimpleITK image.
        """
        if not self.only_inmask or self.mask_key is None:
            return None
        if self.mask_key not in data:
            logger.warning(f"[{subj}] Warning: mask key {self.mask_key} not found in data. Using no mask.")
            return None
        sitk_mask = data[self.mask_key]
        if not isinstance(sitk_mask, sitk.Image):
            logger.warning(f"[{subj}] Warning: {self.mask_key} is not a SimpleITK Image object. Using no mask.")
            return None
        return sitk_mask

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply Z-score normalization to the specified images.

        For every key that is normalized, ``data[key]`` is replaced by the
        normalized image and ``data[f"{key}_meta_dict"]["zscore_normalized"]``
        is set to True. Entries that are not SimpleITK images are skipped with
        a warning.

        Args:
            data (Dict[str, Any]): Input data dictionary containing SimpleITK image objects.

        Returns:
            Dict[str, Any]: Data dictionary with normalized images (the same
            dictionary object, modified in place).
        """
        self._check_keys(data)
        subj = data.get('subj', 'unknown')
        logger.debug(f"[{subj}] Z-score normalization")

        # The mask does not depend on the image key, so look it up once. This
        # also keeps the original mask in use if the mask key is itself
        # listed in self.keys and gets normalized below.
        sitk_mask = self._resolve_mask(data, subj)

        # Process each image
        for key in self.keys:
            # Honor allow_missing_keys instead of failing on the lookup below
            if key not in data and self.allow_missing_keys:
                logger.debug(f"[{subj}] Key {key} not found in data. Skipping.")
                continue

            sitk_image = data[key]

            # Ensure we have a SimpleITK image object
            if not isinstance(sitk_image, sitk.Image):
                logger.warning(f"[{subj}] Warning: {key} is not a SimpleITK Image object. Skipping.")
                continue

            try:
                normalized_image = self._apply_zscore_normalization(sitk_image, sitk_mask, subj)

                # Final (post-clipping) range, for verification at debug level
                normalized_array = sitk.GetArrayFromImage(normalized_image)
                logger.debug(f"[{subj}] Normalized range: [{np.min(normalized_array):.2f}, {np.max(normalized_array):.2f}]")

                # Store the normalized image
                data[key] = normalized_image

                # Update metadata
                data.setdefault(f"{key}_meta_dict", {})["zscore_normalized"] = True
            except Exception as e:
                logger.error(f"[{subj}] Error applying Z-score normalization to {key}: {e}")
                # NOTE: allow_missing_keys also controls whether a failed
                # normalization is tolerated; the image is then left as-is.
                if not self.allow_missing_keys:
                    raise

        return data