自定义扩展指南 ============ 本节介绍如何自定义和扩展 HABIT 的各种组件,包括预处理器、特征提取器、聚类算法、模型等。 概述 ---- HABIT 的设计理念之一是"无限扩展",通过工厂模式和注册机制,用户可以轻松添加自定义组件。 **可扩展的组件:** - **预处理器**: 添加自定义的图像预处理方法 - **特征提取器**: 添加自定义的聚类特征提取方法 - **聚类算法**: 添加自定义的聚类算法 - **策略**: 添加自定义的生境分割策略 - **模型**: 添加自定义的机器学习模型 - **特征选择器**: 添加自定义的特征选择方法 **扩展机制:** HABIT 使用工厂模式和注册机制实现扩展: 1. **工厂模式**: 所有可扩展组件都使用工厂模式创建 2. **注册机制**: 通过装饰器注册自定义组件 3. **统一接口**: 所有自定义组件都遵循统一的接口规范 4. **即插即用**: 注册后即可在配置文件中使用 扩展原则 -------- **1. 遵循接口规范** 所有自定义组件都必须继承相应的基类并实现必需的方法: - **预处理器**: 继承 `BasePreprocessor`,实现 `__call__` 方法 - **特征提取器**: 继承 `BaseClusteringExtractor`,实现 `extract_features` 方法 - **聚类算法**: 继承 `BaseClusteringAlgorithm`,实现 `fit_predict` 方法 - **模型**: 继承 `BaseModel`,实现 `fit`、`predict`、`predict_proba` 方法 **2. 使用注册装饰器** 使用相应的注册装饰器注册自定义组件: - **预处理器**: `@PreprocessorFactory.register("name")` - **特征提取器**: `@register_feature_extractor('name')` - **聚类算法**: `@ClusteringFactory.register("name")` - **模型**: `@ModelFactory.register("name")` **3. 提供清晰的文档** 为自定义组件提供清晰的文档说明: - **功能描述**: 描述组件的功能和用途 - **参数说明**: 说明所有参数的含义和默认值 - **使用示例**: 提供使用示例 - **注意事项**: 说明使用时的注意事项 **4. 测试和验证** 对自定义组件进行充分的测试和验证: - **单元测试**: 测试组件的基本功能 - **集成测试**: 测试组件与其他组件的集成 - **性能测试**: 测试组件的性能 - **验证**: 验证组件的正确性 自定义预处理器 ------------ **步骤 1: 创建自定义预处理器** .. code-block:: python from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory from habit.core.preprocessing.base_preprocessor import BasePreprocessor @PreprocessorFactory.register("my_preprocessor") class MyPreprocessor(BasePreprocessor): def __init__(self, keys, allow_missing_keys=False, **kwargs): super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) self.param1 = kwargs.get('param1', default_value) self.param2 = kwargs.get('param2', default_value) def __call__(self, data): self._check_keys(data) for key in self.keys: data[key] = self._process_item(data[key]) return data def _process_item(self, item): # 实现您的预处理逻辑 return processed_item **步骤 2: 在配置文件中使用** .. code-block:: yaml Preprocessing: my_preprocessor: images: [T1, T2] param1: value1 param2: value2 **步骤 3: 运行预处理** .. code-block:: bash habit preprocess --config config_with_custom_preprocessor.yaml **示例: 自定义高斯滤波预处理器** .. code-block:: python import numpy as np from scipy.ndimage import gaussian_filter from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory from habit.core.preprocessing.base_preprocessor import BasePreprocessor @PreprocessorFactory.register("gaussian_filter") class GaussianFilterPreprocessor(BasePreprocessor): def __init__(self, keys, allow_missing_keys=False, **kwargs): super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) self.sigma = kwargs.get('sigma', 1.0) self.order = kwargs.get('order', 0) def __call__(self, data): self._check_keys(data) for key in self.keys: data[key] = self._process_item(data[key]) return data def _process_item(self, item): return gaussian_filter(item, sigma=self.sigma, order=self.order) 自定义特征提取器 ------------ **步骤 1: 创建自定义特征提取器** .. code-block:: python from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor @register_feature_extractor('my_feature_extractor') class MyFeatureExtractor(BaseClusteringExtractor): def __init__(self, **kwargs): super().__init__(**kwargs) self.feature_names = ['feature1', 'feature2', 'feature3'] def extract_features(self, image_data, **kwargs): # 实现特征提取逻辑 n_samples = image_data.shape[0] features = np.random.random((n_samples, 3)) return features **步骤 2: 在配置文件中使用** .. code-block:: yaml FeatureConstruction: voxel_level: method: my_feature_extractor(raw(delay2), raw(delay3)) params: param1: value1 **步骤 3: 运行生境分析** .. code-block:: bash habit get-habitat --config config_with_custom_extractor.yaml **示例: 自定义局部对比度特征提取器** .. code-block:: python import numpy as np from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor @register_feature_extractor('local_contrast') class LocalContrastExtractor(BaseClusteringExtractor): def __init__(self, **kwargs): super().__init__(**kwargs) self.radius = kwargs.get('radius', 3) self.feature_names = ['local_contrast'] def extract_features(self, image_data, **kwargs): n_samples = image_data.shape[0] features = np.zeros((n_samples, 1)) for i in range(n_samples): features[i, 0] = self._compute_local_contrast(image_data[i]) return features def _compute_local_contrast(self, image): local_mean = self._compute_local_mean(image) local_contrast = np.abs(image - local_mean) return local_contrast def _compute_local_mean(self, image): from scipy.ndimage import uniform_filter return uniform_filter(image, size=self.radius * 2 + 1) 自定义聚类算法 ------------ **步骤 1: 创建自定义聚类算法** .. code-block:: python from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory @ClusteringFactory.register("my_clustering") class MyClusteringAlgorithm(BaseClusteringAlgorithm): def __init__(self, n_clusters=3, random_state=None, **kwargs): super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs) self.param1 = kwargs.get('param1', default_value) def fit_predict(self, X, **kwargs): # 实现聚类逻辑 labels = self._cluster(X) return labels def _cluster(self, X): # 实现具体的聚类算法 return labels **步骤 2: 在配置文件中使用** .. code-block:: yaml HabitatsSegmention: clustering_mode: two_step supervoxel: algorithm: my_clustering n_clusters: 50 param1: value1 **步骤 3: 运行生境分析** .. code-block:: bash habit get-habitat --config config_with_custom_clustering.yaml **示例: 自定义谱聚类算法** .. code-block:: python import numpy as np from sklearn.cluster import SpectralClustering from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory @ClusteringFactory.register("spectral") class SpectralClusteringAlgorithm(BaseClusteringAlgorithm): def __init__(self, n_clusters=3, random_state=None, **kwargs): super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs) self.gamma = kwargs.get('gamma', 1.0) self.n_neighbors = kwargs.get('n_neighbors', 10) def fit_predict(self, X, **kwargs): clustering = SpectralClustering( n_clusters=self.n_clusters, gamma=self.gamma, n_neighbors=self.n_neighbors, random_state=self.random_state ) labels = clustering.fit_predict(X) return labels 自定义模型 ------------ **步骤 1: 创建自定义模型** .. code-block:: python from habit.core.machine_learning.models.base import BaseModel from habit.core.machine_learning.models.factory import ModelFactory @ModelFactory.register("my_model") class MyModel(BaseModel): def __init__(self, **kwargs): super().__init__(**kwargs) self.param1 = kwargs.get('param1', default_value) self.model = None def fit(self, X, y, **kwargs): # 实现模型训练逻辑 self.model = self._train(X, y) return self def predict(self, X, **kwargs): # 实现预测逻辑 return self.model.predict(X) def predict_proba(self, X, **kwargs): # 实现概率预测逻辑 return self.model.predict_proba(X) def _train(self, X, y): # 实现具体的训练算法 return model **步骤 2: 在配置文件中使用** .. code-block:: yaml models: my_model: params: param1: value1 **步骤 3: 运行机器学习** .. code-block:: bash habit model --config config_with_custom_model.yaml **示例: 自定义神经网络模型** .. code-block:: python import numpy as np from sklearn.neural_network import MLPClassifier from habit.core.machine_learning.models.base import BaseModel from habit.core.machine_learning.models.factory import ModelFactory @ModelFactory.register("neural_network") class NeuralNetworkModel(BaseModel): def __init__(self, **kwargs): super().__init__(**kwargs) self.hidden_layer_sizes = kwargs.get('hidden_layer_sizes', (100,)) self.activation = kwargs.get('activation', 'relu') self.solver = kwargs.get('solver', 'adam') self.max_iter = kwargs.get('max_iter', 200) self.random_state = kwargs.get('random_state', None) self.model = None def fit(self, X, y, **kwargs): self.model = MLPClassifier( hidden_layer_sizes=self.hidden_layer_sizes, activation=self.activation, solver=self.solver, max_iter=self.max_iter, random_state=self.random_state ) self.model.fit(X, y) return self def predict(self, X, **kwargs): return self.model.predict(X) def predict_proba(self, X, **kwargs): return self.model.predict_proba(X) 自定义特征选择器 ------------ **步骤 1: 创建自定义特征选择器** .. code-block:: python from sklearn.base import BaseEstimator, TransformerMixin from habit.core.machine_learning.feature_selectors.selector_registry import register_selector @register_selector('my_selector') class MyFeatureSelector(BaseEstimator, TransformerMixin): def __init__(self, param1=default_value, param2=default_value): self.param1 = param1 self.param2 = param2 self.selected_features_ = None def fit(self, X, y=None): # 实现特征选择逻辑 self.selected_features_ = self._select_features(X, y) return self def transform(self, X): # 实现特征转换逻辑 return X[:, self.selected_features_] def _select_features(self, X, y): # 实现具体的特征选择算法 return selected_indices **步骤 2: 在配置文件中使用** .. code-block:: yaml feature_selection_methods: - method: my_selector params: param1: value1 param2: value2 **步骤 3: 运行机器学习** .. code-block:: bash habit model --config config_with_custom_selector.yaml **示例: 自定义互信息特征选择器** .. code-block:: python import numpy as np from sklearn.feature_selection import mutual_info_classif from sklearn.base import BaseEstimator, TransformerMixin from habit.core.machine_learning.feature_selectors.selector_registry import register_selector @register_selector('mutual_info') class MutualInfoSelector(BaseEstimator, TransformerMixin): def __init__(self, k_features=10, random_state=None): self.k_features = k_features self.random_state = random_state self.selected_features_ = None self.scores_ = None def fit(self, X, y): scores = mutual_info_classif(X, y, random_state=self.random_state) self.scores_ = scores self.selected_features_ = np.argsort(scores)[-self.k_features:] return self def transform(self, X): return X[:, self.selected_features_] 最佳实践 -------- **1. 命名规范** - 使用清晰的、描述性的名称 - 使用小写字母和下划线 - 避免使用缩写 **示例:** .. code-block:: python # 好的命名 @PreprocessorFactory.register("gaussian_filter") @register_feature_extractor('local_contrast') @ClusteringFactory.register("spectral") # 不好的命名 @PreprocessorFactory.register("gf") @register_feature_extractor('lc') @ClusteringFactory.register("spec") **2. 参数验证** 对输入参数进行验证,确保参数的有效性。 **示例:** .. code-block:: python def __init__(self, sigma=1.0, **kwargs): super().__init__(**kwargs) if sigma <= 0: raise ValueError("sigma must be positive") self.sigma = sigma **3. 文档字符串** 为自定义组件提供清晰的文档字符串。 **示例:** .. code-block:: python @PreprocessorFactory.register("gaussian_filter") class GaussianFilterPreprocessor(BasePreprocessor): """ 高斯滤波预处理器。 对图像应用高斯滤波,平滑图像并减少噪声。 参数 ---------- sigma : float, default=1.0 高斯核的标准差。值越大,平滑效果越强。 order : int, default=0 高斯滤波的阶数。0 表示平滑,1 表示一阶导数,2 表示二阶导数。 注意事项 ---------- - 高斯滤波会模糊图像细节 - 较大的 sigma 值会导致更强的平滑效果 """ def __init__(self, keys, allow_missing_keys=False, **kwargs): super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) self.sigma = kwargs.get('sigma', 1.0) self.order = kwargs.get('order', 0) **4. 错误处理** 提供清晰的错误信息,便于调试。 **示例:** .. code-block:: python def __call__(self, data): self._check_keys(data) for key in self.keys: try: data[key] = self._process_item(data[key]) except Exception as e: raise RuntimeError(f"Failed to process {key}: {str(e)}") return data **5. 测试** 为自定义组件编写测试,确保功能的正确性。 **示例:** .. code-block:: python import unittest import numpy as np class TestGaussianFilterPreprocessor(unittest.TestCase): def setUp(self): self.preprocessor = GaussianFilterPreprocessor( keys=['image'], sigma=1.0 ) def test_gaussian_filter(self): data = {'image': np.random.random((10, 10, 10))} result = self.preprocessor(data) self.assertIn('image', result) self.assertEqual(result['image'].shape, (10, 10, 10)) if __name__ == '__main__': unittest.main() 常见问题 -------- **Q1: 如何调试自定义组件?** A: 可以使用以下方法: 1. 使用 `debug` 模式启用详细日志 2. 在代码中添加 `print` 语句 3. 使用 Python 调试器(pdb) 4. 编写单元测试 **Q2: 如何分享自定义组件?** A: 可以通过以下方式分享: 1. 将代码分享给其他研究者 2. 创建 GitHub 仓库 3. 提交到 HABIT 项目 4. 编写文档和示例 **Q3: 如何优化自定义组件的性能?** A: 可以尝试以下方法: 1. 使用向量化操作 2. 使用并行计算 3. 使用 C/C++ 扩展 4. 优化算法 **Q4: 如何确保自定义组件的正确性?** A: 可以通过以下方法验证: 1. 编写单元测试 2. 与已知结果对比 3. 使用可视化验证 4. 进行交叉验证 **Q5: 如何处理自定义组件的依赖?** A: 可以通过以下方式处理: 1. 在文档中说明依赖 2. 提供安装说明 3. 使用虚拟环境 4. 提供依赖文件(requirements.txt) 下一步 ------- 自定义扩展完成后,您可以: - :doc:`../configuration_zh`: 了解配置文件的详细说明 - :doc:`../design_philosophy_zh`: 了解 HABIT 的设计哲学 - :doc:`../cli_zh`: 了解 CLI 命令的详细说明