from typing import Dict, Any, Tuple, Optional
import numpy as np
import SimpleITK as sitk
import ants
# Optional dependency: torch is only needed by the tensor <-> NumPy helpers of
# ImageConverter; the ANTs <-> SimpleITK conversions do not use it.
try:
    import torch
except ImportError:
    torch = None
TORCH_AVAILABLE = torch is not None
class ImageConverter:
    """Utility class for converting between different image formats.

    All conversions are exposed as static methods:

    * torch tensor <-> NumPy array (requires the optional ``torch`` package);
    * ANTsPy image <-> SimpleITK image, carrying over origin, spacing and
      direction.
    """

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        """Convert a torch tensor to a NumPy array.

        The tensor is detached from the autograd graph and moved to the CPU
        before conversion, so tensors that require gradients or live on an
        accelerator are accepted.

        Args:
            tensor: Input tensor, e.g. in ``[C, Z, Y, X]`` or ``[C, H, W]``
                layout (the layout is not validated).

        Returns:
            np.ndarray: The tensor data. If the leading (channel) axis has
            size 1 it is removed; a 0-d tensor yields a 0-d array.

        Raises:
            ImportError: If torch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "torch is required for tensor_to_numpy. "
                "Install it with: pip install torch"
            )
        # Tensor.numpy() refuses tensors that require grad, so detach first.
        array = tensor.detach().cpu().numpy()
        # The ndim guard keeps 0-d tensors from raising IndexError on shape[0].
        if array.ndim > 0 and array.shape[0] == 1:
            array = array.squeeze(0)  # drop the singleton channel axis
        return array

    @staticmethod
    def numpy_to_tensor(array: np.ndarray, dtype=None, device=None):
        """Convert a NumPy array to a torch tensor.

        A singleton channel axis is prepended to 2-D (``[Y, X]``) and 3-D
        (``[Z, Y, X]``) inputs; arrays of any other rank are converted as-is.

        Args:
            array (np.ndarray): Input array, typically ``[Z, Y, X]`` or
                ``[Y, X]``.
            dtype: Optional target tensor dtype.
            device: Optional target tensor device.

        Returns:
            torch.Tensor: Tensor of shape ``[1, Z, Y, X]`` (3-D input) or
            ``[1, Y, X]`` (2-D input).

        Raises:
            ImportError: If torch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "torch is required for numpy_to_tensor. "
                "Install it with: pip install torch"
            )
        if array.ndim in (2, 3):
            array = array[np.newaxis, ...]  # add the channel axis
        # torch.from_numpy rejects arrays with negative strides (e.g. the
        # output of np.flip or [::-1] slicing); copy those to a fresh layout.
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        tensor = torch.from_numpy(array)
        if dtype is not None or device is not None:
            tensor = tensor.to(dtype=dtype, device=device)
        return tensor

    @staticmethod
    def ants_2_itk(image):
        """Convert an ANTsPy image to a SimpleITK image.

        ``image.numpy()`` is indexed ``[x, y, (z)]`` whereas SimpleITK's
        array bridge expects ``[(z,) y, x]``, so the spatial axes are
        reversed. Origin, spacing and direction are copied over. Works for
        scalar images of any spatial dimension; multi-component images are
        not supported (the axis reversal raises ValueError for them).

        Args:
            image: ANTsPy image.

        Returns:
            SimpleITK.Image with the same voxel data and physical geometry.
        """
        dim = image.dimension
        reversed_axes = tuple(range(dim - 1, -1, -1))
        image_itk = sitk.GetImageFromArray(np.transpose(image.numpy(), reversed_axes))
        image_itk.SetOrigin(image.origin)
        image_itk.SetSpacing(image.spacing)
        # ANTs holds the direction as a (dim, dim) matrix; SimpleITK expects
        # it flattened in row-major order.
        image_itk.SetDirection(np.asarray(image.direction).ravel().tolist())
        return image_itk

    @staticmethod
    def itk_2_ants(image):
        """Convert a SimpleITK image to an ANTsPy image.

        ``sitk.GetArrayFromImage`` returns data indexed ``[(z,) y, x]``; the
        spatial axes are reversed to the ``[x, y, (z)]`` order ANTsPy uses.
        Origin, spacing and direction are copied over. Works for scalar
        images of any spatial dimension; multi-component images are not
        supported (the axis reversal raises ValueError for them).

        Args:
            image: SimpleITK.Image.

        Returns:
            ANTsPy image with the same voxel data and physical geometry.
        """
        dim = image.GetDimension()
        reversed_axes = tuple(range(dim - 1, -1, -1))
        array = np.transpose(sitk.GetArrayFromImage(image), reversed_axes)
        # SimpleITK returns the direction as a flat row-major tuple of
        # dim * dim values; ANTs wants the square matrix.
        direction = np.array(image.GetDirection()).reshape(dim, dim)
        return ants.from_numpy(
            array,
            origin=image.GetOrigin(),
            spacing=image.GetSpacing(),
            direction=direction,
        )