自定义扩展指南
本节介绍如何自定义和扩展 HABIT 的各种组件,包括预处理器、特征提取器、聚类算法、模型等。
概述
HABIT 的设计理念之一是"无限扩展",通过工厂模式和注册机制,用户可以轻松添加自定义组件。
可扩展的组件:
预处理器: 添加自定义的图像预处理方法
特征提取器: 添加自定义的聚类特征提取方法
聚类算法: 添加自定义的聚类算法
策略: 添加自定义的生境分割策略
模型: 添加自定义的机器学习模型
特征选择器: 添加自定义的特征选择方法
扩展机制:
HABIT 使用工厂模式和注册机制实现扩展:
工厂模式: 所有可扩展组件都使用工厂模式创建
注册机制: 通过装饰器注册自定义组件
统一接口: 所有自定义组件都遵循统一的接口规范
即插即用: 注册后即可在配置文件中使用
扩展原则
1. 遵循接口规范
所有自定义组件都必须继承相应的基类并实现必需的方法:
预处理器: 继承 BasePreprocessor,实现 __call__ 方法
特征提取器: 继承 BaseClusteringExtractor,实现 extract_features 方法
聚类算法: 继承 BaseClusteringAlgorithm,实现 fit_predict 方法
模型: 继承 BaseModel,实现 fit、predict、predict_proba 方法
2. 使用注册装饰器
使用相应的注册装饰器注册自定义组件:
预处理器: @PreprocessorFactory.register("name")
特征提取器: @register_feature_extractor('name')
聚类算法: @ClusteringFactory.register("name")
模型: @ModelFactory.register("name")
3. 提供清晰的文档
为自定义组件提供清晰的文档说明:
功能描述: 描述组件的功能和用途
参数说明: 说明所有参数的含义和默认值
使用示例: 提供使用示例
注意事项: 说明使用时的注意事项
4. 测试和验证
对自定义组件进行充分的测试和验证:
单元测试: 测试组件的基本功能
集成测试: 测试组件与其他组件的集成
性能测试: 测试组件的性能
验证: 验证组件的正确性
自定义预处理器
步骤 1: 创建自定义预处理器
from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory
from habit.core.preprocessing.base_preprocessor import BasePreprocessor
@PreprocessorFactory.register("my_preprocessor")
class MyPreprocessor(BasePreprocessor):
def __init__(self, keys, allow_missing_keys=False, **kwargs):
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.param1 = kwargs.get('param1', default_value)
self.param2 = kwargs.get('param2', default_value)
def __call__(self, data):
self._check_keys(data)
for key in self.keys:
data[key] = self._process_item(data[key])
return data
def _process_item(self, item):
# 实现您的预处理逻辑
return processed_item
步骤 2: 在配置文件中使用
Preprocessing:
my_preprocessor:
images: [T1, T2]
param1: value1
param2: value2
步骤 3: 运行预处理
habit preprocess --config config_with_custom_preprocessor.yaml
示例: 自定义高斯滤波预处理器
import numpy as np
from scipy.ndimage import gaussian_filter
from habit.core.preprocessing.preprocessor_factory import PreprocessorFactory
from habit.core.preprocessing.base_preprocessor import BasePreprocessor
@PreprocessorFactory.register("gaussian_filter")
class GaussianFilterPreprocessor(BasePreprocessor):
def __init__(self, keys, allow_missing_keys=False, **kwargs):
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.sigma = kwargs.get('sigma', 1.0)
self.order = kwargs.get('order', 0)
def __call__(self, data):
self._check_keys(data)
for key in self.keys:
data[key] = self._process_item(data[key])
return data
def _process_item(self, item):
return gaussian_filter(item, sigma=self.sigma, order=self.order)
自定义特征提取器
步骤 1: 创建自定义特征提取器
from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor
from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor
@register_feature_extractor('my_feature_extractor')
class MyFeatureExtractor(BaseClusteringExtractor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.feature_names = ['feature1', 'feature2', 'feature3']
def extract_features(self, image_data, **kwargs):
# 实现特征提取逻辑
n_samples = image_data.shape[0]
features = np.random.random((n_samples, 3))
return features
步骤 2: 在配置文件中使用
FeatureConstruction:
voxel_level:
method: my_feature_extractor(raw(delay2), raw(delay3))
params:
param1: value1
步骤 3: 运行生境分析
habit get-habitat --config config_with_custom_extractor.yaml
示例: 自定义局部对比度特征提取器
import numpy as np
from habit.core.habit_analysis.extractors.base_extractor import BaseClusteringExtractor
from habit.core.habit_analysis.extractors.base_extractor import register_feature_extractor
@register_feature_extractor('local_contrast')
class LocalContrastExtractor(BaseClusteringExtractor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.radius = kwargs.get('radius', 3)
self.feature_names = ['local_contrast']
def extract_features(self, image_data, **kwargs):
n_samples = image_data.shape[0]
features = np.zeros((n_samples, 1))
for i in range(n_samples):
features[i, 0] = self._compute_local_contrast(image_data[i])
return features
def _compute_local_contrast(self, image):
local_mean = self._compute_local_mean(image)
local_contrast = np.abs(image - local_mean)
return local_contrast
def _compute_local_mean(self, image):
from scipy.ndimage import uniform_filter
return uniform_filter(image, size=self.radius * 2 + 1)
自定义聚类算法
步骤 1: 创建自定义聚类算法
from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm
from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory
@ClusteringFactory.register("my_clustering")
class MyClusteringAlgorithm(BaseClusteringAlgorithm):
def __init__(self, n_clusters=3, random_state=None, **kwargs):
super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs)
self.param1 = kwargs.get('param1', default_value)
def fit_predict(self, X, **kwargs):
# 实现聚类逻辑
labels = self._cluster(X)
return labels
def _cluster(self, X):
# 实现具体的聚类算法
return labels
步骤 2: 在配置文件中使用
HabitatsSegmention:
clustering_mode: two_step
supervoxel:
algorithm: my_clustering
n_clusters: 50
param1: value1
步骤 3: 运行生境分析
habit get-habitat --config config_with_custom_clustering.yaml
示例: 自定义谱聚类算法
import numpy as np
from sklearn.cluster import SpectralClustering
from habit.core.habit_analysis.clustering.base_clustering import BaseClusteringAlgorithm
from habit.core.habit_analysis.clustering.clustering_factory import ClusteringFactory
@ClusteringFactory.register("spectral")
class SpectralClusteringAlgorithm(BaseClusteringAlgorithm):
def __init__(self, n_clusters=3, random_state=None, **kwargs):
super().__init__(n_clusters=n_clusters, random_state=random_state, **kwargs)
self.gamma = kwargs.get('gamma', 1.0)
self.n_neighbors = kwargs.get('n_neighbors', 10)
def fit_predict(self, X, **kwargs):
clustering = SpectralClustering(
n_clusters=self.n_clusters,
gamma=self.gamma,
n_neighbors=self.n_neighbors,
random_state=self.random_state
)
labels = clustering.fit_predict(X)
return labels
自定义模型
步骤 1: 创建自定义模型
from habit.core.machine_learning.models.base import BaseModel
from habit.core.machine_learning.models.factory import ModelFactory
@ModelFactory.register("my_model")
class MyModel(BaseModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.param1 = kwargs.get('param1', default_value)
self.model = None
def fit(self, X, y, **kwargs):
# 实现模型训练逻辑
self.model = self._train(X, y)
return self
def predict(self, X, **kwargs):
# 实现预测逻辑
return self.model.predict(X)
def predict_proba(self, X, **kwargs):
# 实现概率预测逻辑
return self.model.predict_proba(X)
def _train(self, X, y):
# 实现具体的训练算法
return model
步骤 2: 在配置文件中使用
models:
my_model:
params:
param1: value1
步骤 3: 运行机器学习
habit model --config config_with_custom_model.yaml
示例: 自定义神经网络模型
import numpy as np
from sklearn.neural_network import MLPClassifier
from habit.core.machine_learning.models.base import BaseModel
from habit.core.machine_learning.models.factory import ModelFactory
@ModelFactory.register("neural_network")
class NeuralNetworkModel(BaseModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.hidden_layer_sizes = kwargs.get('hidden_layer_sizes', (100,))
self.activation = kwargs.get('activation', 'relu')
self.solver = kwargs.get('solver', 'adam')
self.max_iter = kwargs.get('max_iter', 200)
self.random_state = kwargs.get('random_state', None)
self.model = None
def fit(self, X, y, **kwargs):
self.model = MLPClassifier(
hidden_layer_sizes=self.hidden_layer_sizes,
activation=self.activation,
solver=self.solver,
max_iter=self.max_iter,
random_state=self.random_state
)
self.model.fit(X, y)
return self
def predict(self, X, **kwargs):
return self.model.predict(X)
def predict_proba(self, X, **kwargs):
return self.model.predict_proba(X)
自定义特征选择器
步骤 1: 创建自定义特征选择器
from sklearn.base import BaseEstimator, TransformerMixin
from habit.core.machine_learning.feature_selectors.selector_registry import register_selector
@register_selector('my_selector')
class MyFeatureSelector(BaseEstimator, TransformerMixin):
def __init__(self, param1=default_value, param2=default_value):
self.param1 = param1
self.param2 = param2
self.selected_features_ = None
def fit(self, X, y=None):
# 实现特征选择逻辑
self.selected_features_ = self._select_features(X, y)
return self
def transform(self, X):
# 实现特征转换逻辑
return X[:, self.selected_features_]
def _select_features(self, X, y):
# 实现具体的特征选择算法
return selected_indices
步骤 2: 在配置文件中使用
feature_selection_methods:
- method: my_selector
params:
param1: value1
param2: value2
步骤 3: 运行机器学习
habit model --config config_with_custom_selector.yaml
示例: 自定义互信息特征选择器
import numpy as np
from sklearn.feature_selection import mutual_info_classif
from sklearn.base import BaseEstimator, TransformerMixin
from habit.core.machine_learning.feature_selectors.selector_registry import register_selector
@register_selector('mutual_info')
class MutualInfoSelector(BaseEstimator, TransformerMixin):
def __init__(self, k_features=10, random_state=None):
self.k_features = k_features
self.random_state = random_state
self.selected_features_ = None
self.scores_ = None
def fit(self, X, y):
scores = mutual_info_classif(X, y, random_state=self.random_state)
self.scores_ = scores
self.selected_features_ = np.argsort(scores)[-self.k_features:]
return self
def transform(self, X):
return X[:, self.selected_features_]
最佳实践
1. 命名规范
使用清晰的、描述性的名称
使用小写字母和下划线
避免使用缩写
示例:
# 好的命名
@PreprocessorFactory.register("gaussian_filter")
@register_feature_extractor('local_contrast')
@ClusteringFactory.register("spectral")
# 不好的命名
@PreprocessorFactory.register("gf")
@register_feature_extractor('lc')
@ClusteringFactory.register("spec")
2. 参数验证
对输入参数进行验证,确保参数的有效性。
示例:
def __init__(self, sigma=1.0, **kwargs):
super().__init__(**kwargs)
if sigma <= 0:
raise ValueError("sigma must be positive")
self.sigma = sigma
3. 文档字符串
为自定义组件提供清晰的文档字符串。
示例:
@PreprocessorFactory.register("gaussian_filter")
class GaussianFilterPreprocessor(BasePreprocessor):
"""
高斯滤波预处理器。
对图像应用高斯滤波,平滑图像并减少噪声。
参数
----------
sigma : float, default=1.0
高斯核的标准差。值越大,平滑效果越强。
order : int, default=0
高斯滤波的阶数。0 表示平滑,1 表示一阶导数,2 表示二阶导数。
注意事项
----------
- 高斯滤波会模糊图像细节
- 较大的 sigma 值会导致更强的平滑效果
"""
def __init__(self, keys, allow_missing_keys=False, **kwargs):
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.sigma = kwargs.get('sigma', 1.0)
self.order = kwargs.get('order', 0)
4. 错误处理
提供清晰的错误信息,便于调试。
示例:
def __call__(self, data):
self._check_keys(data)
for key in self.keys:
try:
data[key] = self._process_item(data[key])
except Exception as e:
raise RuntimeError(f"Failed to process {key}: {str(e)}")
return data
5. 测试
为自定义组件编写测试,确保功能的正确性。
示例:
import unittest
import numpy as np
class TestGaussianFilterPreprocessor(unittest.TestCase):
def setUp(self):
self.preprocessor = GaussianFilterPreprocessor(
keys=['image'],
sigma=1.0
)
def test_gaussian_filter(self):
data = {'image': np.random.random((10, 10, 10))}
result = self.preprocessor(data)
self.assertIn('image', result)
self.assertEqual(result['image'].shape, (10, 10, 10))
if __name__ == '__main__':
unittest.main()
常见问题
Q1: 如何调试自定义组件?
A: 可以使用以下方法: 1. 使用 debug 模式启用详细日志 2. 在代码中添加 print 语句 3. 使用 Python 调试器(pdb) 4. 编写单元测试
Q2: 如何分享自定义组件?
A: 可以通过以下方式分享: 1. 将代码分享给其他研究者 2. 创建 GitHub 仓库 3. 提交到 HABIT 项目 4. 编写文档和示例
Q3: 如何优化自定义组件的性能?
A: 可以尝试以下方法: 1. 使用向量化操作 2. 使用并行计算 3. 使用 C/C++ 扩展 4. 优化算法
Q4: 如何确保自定义组件的正确性?
A: 可以通过以下方法验证: 1. 编写单元测试 2. 与已知结果对比 3. 使用可视化验证 4. 进行交叉验证
Q5: 如何处理自定义组件的依赖?
A: 可以通过以下方式处理: 1. 在文档中说明依赖 2. 提供安装说明 3. 使用虚拟环境 4. 提供依赖文件(requirements.txt)
下一步
自定义扩展完成后,您可以:
配置参考: 了解配置文件的详细说明
HABIT 包的设计哲学和理念: 了解 HABIT 的设计哲学
CLI 参考文档: 了解 CLI 命令的详细说明