"""
Model Factory
Factory class for creating model instances
"""
import importlib
import logging
import os
import sys
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List, Union, Tuple

import numpy as np
import pandas as pd

from .base import BaseModel
class ModelFactory:
    """Registry-backed factory that creates model instances by name.

    Model classes are added to a class-level registry through the
    :meth:`register` decorator. When an unknown name is requested, the
    sibling modules in this file's directory are imported (presumably so
    that their ``register`` decorators run) before the lookup is retried.
    """

    # Maps registered model name -> model class. Class-level, so the registry
    # is shared by every user of the factory in the process.
    _registry: Dict[str, Any] = {}

    # Package used for relative imports when this module's own package cannot
    # be derived from ``__name__`` (e.g. when it is loaded as a top-level
    # module or executed as a script).
    _DEFAULT_PACKAGE = "habit.core.machine_learning.models"

    _logger = logging.getLogger(__name__)

    @classmethod
    def register(cls, name: str):
        """
        Build a class decorator that registers a model class under ``name``.

        Args:
            name: Model name, used later as ``model_name`` in
                :meth:`create_model`.

        Returns:
            Decorator function that stores the decorated class in the
            registry and returns it unchanged. Registering a name that is
            already present replaces the previous entry.
        """
        def decorator(model_class):
            cls._registry[name] = model_class
            return model_class

        return decorator

    @classmethod
    def create_model(
        cls,
        model_name: str,
        config: Optional[Dict[str, Any]] = None,
    ) -> "BaseModel":
        """
        Create a model instance.

        Args:
            model_name: Name of the model to create.
            config: Configuration dictionary passed as the single positional
                argument to the model class. ``None`` (or any other falsy
                value) is replaced by a fresh empty dict.

        Returns:
            BaseModel: Instance of the class registered under ``model_name``.

        Raises:
            ValueError: If the name is still not registered after model
                discovery has been attempted.
        """
        if model_name not in cls._registry:
            # The model's module may simply not have been imported yet.
            # Discovery handles per-module ImportError itself, so no extra
            # guard is needed here.
            cls._discover_models()

        if model_name not in cls._registry:
            raise ValueError(f"Model '{model_name}' not registered. Available models: {list(cls._registry.keys())}")

        return cls._registry[model_name](config or {})

    @classmethod
    def get_available_models(cls) -> List[str]:
        """
        Get the list of available model names.

        Runs model discovery first so that models whose modules have not
        been imported yet are included.

        Returns:
            List[str]: Registered model names.
        """
        cls._discover_models()
        return list(cls._registry)

    @classmethod
    def _discover_models(cls) -> None:
        """Import every model module found next to this file.

        Every ``*.py`` file in this file's directory is imported, except
        files starting with ``__`` and ``factory.py``. Modules that fail
        with ``ImportError`` are logged at WARNING level and skipped, so one
        broken module does not prevent the others from loading.
        """
        models_dir = os.path.dirname(os.path.abspath(__file__))

        # Resolve siblings relative to the package this module actually lives
        # in, so the imports match the directory being scanned; fall back to
        # the historical package name when there is no parent package.
        package = __name__.rpartition(".")[0] or cls._DEFAULT_PACKAGE

        # Sorted so that import (and therefore registration) order does not
        # depend on the filesystem's directory ordering.
        for filename in sorted(os.listdir(models_dir)):
            if not filename.endswith(".py"):
                continue
            if filename.startswith("__") or filename == "factory.py":
                continue

            module_name = filename[:-3]  # strip the ".py" extension
            try:
                importlib.import_module(f".{module_name}", package=package)
            except ImportError as exc:
                cls._logger.warning("Failed to import %s: %s", module_name, exc)
            else:
                cls._logger.debug("Successfully imported model module: %s", module_name)