自定义扩展指南

本节介绍如何自定义和扩展 HABIT 的各种组件,包括预处理器、特征提取器、聚类算法、模型等。

概述

HABIT 的设计理念之一是"无限扩展",通过工厂模式和注册机制,用户可以轻松添加自定义组件。

可扩展的组件:

  • 预处理器: 添加自定义的图像预处理方法

  • 特征提取器: 添加自定义的聚类特征提取方法

  • 聚类算法: 添加自定义的聚类算法

  • 策略: 添加自定义的生境分割策略

  • 模型: 添加自定义的机器学习模型

  • 特征选择器: 添加自定义的特征选择方法

扩展机制:

HABIT 使用工厂模式和注册机制实现扩展:

  1. 工厂模式: 所有可扩展组件都使用工厂模式创建

  2. 注册机制: 通过装饰器注册自定义组件

  3. 统一接口: 所有自定义组件都遵循统一的接口规范

  4. 即插即用: 注册后即可在配置文件中使用

扩展原则

1. 遵循接口规范

所有自定义组件都必须继承相应的基类并实现必需的方法:

  • 预处理器: 继承 BasePreprocessor,实现 __call__ 方法

  • 特征提取器: 继承 BaseClusteringExtractor,实现 extract_features 方法

  • 聚类算法: 继承 BaseClusteringAlgorithm,实现 fit_predict 方法

  • 模型: 继承 BaseModel,实现 fitpredictpredict_proba 方法

2. 使用注册装饰器

使用相应的注册装饰器注册自定义组件:

  • 预处理器: @PreprocessorFactory.register("name")

  • 特征提取器: @register_feature_extractor('name')

  • 聚类算法: @ClusteringFactory.register("name")

  • 模型: @ModelFactory.register("name")

3. 提供清晰的文档

为自定义组件提供清晰的文档说明:

  • 功能描述: 描述组件的功能和用途

  • 参数说明: 说明所有参数的含义和默认值

  • 使用示例: 提供使用示例

  • 注意事项: 说明使用时的注意事项

4. 测试和验证

对自定义组件进行充分的测试和验证:

  • 单元测试: 测试组件的基本功能

  • 集成测试: 测试组件与其他组件的集成

  • 性能测试: 测试组件的性能

  • 验证: 验证组件的正确性

自定义预处理器

步骤 1: 创建自定义预处理器

from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory
from habit.core.preprocessing.base_preprocessor import BasePreprocessor

@PreprocessorFactory.register("my_preprocessor")
class MyPreprocessor(BasePreprocessor):
    def __init__(self, keys, allow_missing_keys=False, **kwargs):
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.param1 = kwargs.get('param1', default_value)
        self.param2 = kwargs.get('param2', default_value)

    def __call__(self, data):
        self._check_keys(data)
        for key in self.keys:
            data[key] = self._process_item(data[key])
        return data

    def _process_item(self, item):
        # 实现您的预处理逻辑
        return processed_item

步骤 2: 在配置文件中使用

Preprocessing:
  my_preprocessor:
    images: [T1, T2]
    param1: value1
    param2: value2

步骤 3: 运行预处理

habit preprocess --config config_with_custom_preprocessor.yaml

示例: 自定义高斯滤波预处理器

import numpy as np
from scipy.ndimage import gaussian_filter
from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory
from habit.core.preprocessing.base_preprocessor import BasePreprocessor

@PreprocessorFactory.register("gaussian_filter")
class GaussianFilterPreprocessor(BasePreprocessor):
    def __init__(self, keys, allow_missing_keys=False, **kwargs):
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.sigma = kwargs.get('sigma', 1.0)
        self.order = kwargs.get('order', 0)

    def __call__(self, data):
        self._check_keys(data)
        for key in self.keys:
            data[key] = self._process_item(data[key])
        return data

    def _process_item(self, item):
        return gaussian_filter(item, sigma=self.sigma, order=self.order)

自定义特征提取器

步骤 1: 创建自定义特征提取器

from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor
from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor

@register_feature_extractor('my_feature_extractor')
class MyFeatureExtractor(BaseClusteringExtractor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.feature_names = ['feature1', 'feature2', 'feature3']

    def extract_features(self, image_data, **kwargs):
        # 实现特征提取逻辑
        n_samples = image_data.shape[0]
        features = np.random.random((n_samples, 3))
        return features

步骤 2: 在配置文件中使用

FeatureConstruction:
  voxel_level:
    method: my_feature_extractor(raw(delay2), raw(delay3))
    params:
      param1: value1

步骤 3: 运行生境分析

habit get-habitat --config config_with_custom_extractor.yaml

示例: 自定义局部对比度特征提取器

import numpy as np
from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor
from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor

@register_feature_extractor('local_contrast')
class LocalContrastExtractor(BaseClusteringExtractor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.radius = kwargs.get('radius', 3)
        self.feature_names = ['local_contrast']

    def extract_features(self, image_data, **kwargs):
        n_samples = image_data.shape[0]
        features = np.zeros((n_samples, 1))
        for i in range(n_samples):
            features[i, 0] = self._compute_local_contrast(image_data[i])
        return features

    def _compute_local_contrast(self, image):
        local_mean = self._compute_local_mean(image)
        local_contrast = np.abs(image - local_mean)
        return local_contrast

    def _compute_local_mean(self, image):
        from scipy.ndimage import uniform_filter
        return uniform_filter(image, size=self.radius * 2 + 1)

自定义聚类算法

步骤 1: 创建自定义聚类算法

from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm
from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory

@ClusteringFactory.register("my_clustering")
class MyClusteringAlgorithm(BaseClusteringAlgorithm):
    def __init__(self, n_clusters=3, random_state=None, **kwargs):
        super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs)
        self.param1 = kwargs.get('param1', default_value)

    def fit_predict(self, X, **kwargs):
        # 实现聚类逻辑
        labels = self._cluster(X)
        return labels

    def _cluster(self, X):
        # 实现具体的聚类算法
        return labels

步骤 2: 在配置文件中使用

HabitatsSegmention:
  clustering_mode: two_step
  supervoxel:
    algorithm: my_clustering
    n_clusters: 50
    param1: value1

步骤 3: 运行生境分析

habit get-habitat --config config_with_custom_clustering.yaml

示例: 自定义谱聚类算法

import numpy as np
from sklearn.cluster import SpectralClustering
from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm
from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory

@ClusteringFactory.register("spectral")
class SpectralClusteringAlgorithm(BaseClusteringAlgorithm):
    def __init__(self, n_clusters=3, random_state=None, **kwargs):
        super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs)
        self.gamma = kwargs.get('gamma', 1.0)
        self.n_neighbors = kwargs.get('n_neighbors', 10)

    def fit_predict(self, X, **kwargs):
        clustering = SpectralClustering(
            n_clusters=self.n_clusters,
            gamma=self.gamma,
            n_neighbors=self.n_neighbors,
            random_state=self.random_state
        )
        labels = clustering.fit_predict(X)
        return labels

自定义模型

步骤 1: 创建自定义模型

from habit.core.machine_learning.models.base import BaseModel
from habit.core.machine_learning.models.factory import ModelFactory

@ModelFactory.register("my_model")
class MyModel(BaseModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.param1 = kwargs.get('param1', default_value)
        self.model = None

    def fit(self, X, y, **kwargs):
        # 实现模型训练逻辑
        self.model = self._train(X, y)
        return self

    def predict(self, X, **kwargs):
        # 实现预测逻辑
        return self.model.predict(X)

    def predict_proba(self, X, **kwargs):
        # 实现概率预测逻辑
        return self.model.predict_proba(X)

    def _train(self, X, y):
        # 实现具体的训练算法
        return model

步骤 2: 在配置文件中使用

models:
  my_model:
    params:
      param1: value1

步骤 3: 运行机器学习

habit model --config config_with_custom_model.yaml

示例: 自定义神经网络模型

import numpy as np
from sklearn.neural_network import MLPClassifier
from habit.core.machine_learning.models.base import BaseModel
from habit.core.machine_learning.models.factory import ModelFactory

@ModelFactory.register("neural_network")
class NeuralNetworkModel(BaseModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden_layer_sizes = kwargs.get('hidden_layer_sizes', (100,))
        self.activation = kwargs.get('activation', 'relu')
        self.solver = kwargs.get('solver', 'adam')
        self.max_iter = kwargs.get('max_iter', 200)
        self.random_state = kwargs.get('random_state', None)
        self.model = None

    def fit(self, X, y, **kwargs):
        self.model = MLPClassifier(
            hidden_layer_sizes=self.hidden_layer_sizes,
            activation=self.activation,
            solver=self.solver,
            max_iter=self.max_iter,
            random_state=self.random_state
        )
        self.model.fit(X, y)
        return self

    def predict(self, X, **kwargs):
        return self.model.predict(X)

    def predict_proba(self, X, **kwargs):
        return self.model.predict_proba(X)

自定义特征选择器

步骤 1: 创建自定义特征选择器

from sklearn.base import BaseEstimator, TransformerMixin
from habit.core.machine_learning.feature_selectors.selector_registry import register_selector

@register_selector('my_selector')
class MyFeatureSelector(BaseEstimator, TransformerMixin):
    def __init__(self, param1=default_value, param2=default_value):
        self.param1 = param1
        self.param2 = param2
        self.selected_features_ = None

    def fit(self, X, y=None):
        # 实现特征选择逻辑
        self.selected_features_ = self._select_features(X, y)
        return self

    def transform(self, X):
        # 实现特征转换逻辑
        return X[:, self.selected_features_]

    def _select_features(self, X, y):
        # 实现具体的特征选择算法
        return selected_indices

步骤 2: 在配置文件中使用

feature_selection_methods:
  - method: my_selector
    params:
      param1: value1
      param2: value2

步骤 3: 运行机器学习

habit model --config config_with_custom_selector.yaml

示例: 自定义互信息特征选择器

import numpy as np
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, TransformerMixin
from habit.core.machine_learning.feature_selectors.selector_registry import register_selector

@register_selector('mutual_info')
class MutualInfoSelector(BaseEstimator, TransformerMixin):
    def __init__(self, k_features=10, random_state=None):
        self.k_features = k_features
        self.random_state = random_state
        self.selected_features_ = None
        self.scores_ = None

    def fit(self, X, y):
        scores = mutual_info_classif(X, y, random_state=self.random_state)
        self.scores_ = scores
        self.selected_features_ = np.argsort(scores)[-self.k_features:]
        return self

    def transform(self, X):
        return X[:, self.selected_features_]

最佳实践

1. 命名规范

  • 使用清晰的、描述性的名称

  • 使用小写字母和下划线

  • 避免使用缩写

示例:

# 好的命名
@PreprocessorFactory.register("gaussian_filter")
@register_feature_extractor('local_contrast')
@ClusteringFactory.register("spectral")

# 不好的命名
@PreprocessorFactory.register("gf")
@register_feature_extractor('lc')
@ClusteringFactory.register("spec")

2. 参数验证

对输入参数进行验证,确保参数的有效性。

示例:

def __init__(self, sigma=1.0, **kwargs):
    super().__init__(**kwargs)
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    self.sigma = sigma

3. 文档字符串

为自定义组件提供清晰的文档字符串。

示例:

@PreprocessorFactory.register("gaussian_filter")
class GaussianFilterPreprocessor(BasePreprocessor):
    """
    高斯滤波预处理器。

    对图像应用高斯滤波,平滑图像并减少噪声。

    参数
    ----------
    sigma : float, default=1.0
        高斯核的标准差。值越大,平滑效果越强。
    order : int, default=0
        高斯滤波的阶数。0 表示平滑,1 表示一阶导数,2 表示二阶导数。

    注意事项
    ----------
    - 高斯滤波会模糊图像细节
    - 较大的 sigma 值会导致更强的平滑效果
    """

    def __init__(self, keys, allow_missing_keys=False, **kwargs):
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.sigma = kwargs.get('sigma', 1.0)
        self.order = kwargs.get('order', 0)

4. 错误处理

提供清晰的错误信息,便于调试。

示例:

def __call__(self, data):
    self._check_keys(data)
    for key in self.keys:
        try:
            data[key] = self._process_item(data[key])
        except Exception as e:
            raise RuntimeError(f"Failed to process {key}: {str(e)}")
    return data

5. 测试

为自定义组件编写测试,确保功能的正确性。

示例:

import unittest
import numpy as np

class TestGaussianFilterPreprocessor(unittest.TestCase):
    def setUp(self):
        self.preprocessor = GaussianFilterPreprocessor(
            keys=['image'],
            sigma=1.0
        )

    def test_gaussian_filter(self):
        data = {'image': np.random.random((10, 10, 10))}
        result = self.preprocessor(data)
        self.assertIn('image', result)
        self.assertEqual(result['image'].shape, (10, 10, 10))

if __name__ == '__main__':
    unittest.main()

常见问题

Q1: 如何调试自定义组件?

A: 可以使用以下方法: 1. 使用 debug 模式启用详细日志 2. 在代码中添加 print 语句 3. 使用 Python 调试器(pdb) 4. 编写单元测试

Q2: 如何分享自定义组件?

A: 可以通过以下方式分享: 1. 将代码分享给其他研究者 2. 创建 GitHub 仓库 3. 提交到 HABIT 项目 4. 编写文档和示例

Q3: 如何优化自定义组件的性能?

A: 可以尝试以下方法: 1. 使用向量化操作 2. 使用并行计算 3. 使用 C/C++ 扩展 4. 优化算法

Q4: 如何确保自定义组件的正确性?

A: 可以通过以下方法验证: 1. 编写单元测试 2. 与已知结果对比 3. 使用可视化验证 4. 进行交叉验证

Q5: 如何处理自定义组件的依赖?

A: 可以通过以下方式处理: 1. 在文档中说明依赖 2. 提供安装说明 3. 使用虚拟环境 4. 提供依赖文件(requirements.txt)

下一步

自定义扩展完成后,您可以: