""" 服务启动自动初始化模块 提供统一的服务启动初始化框架,支持: 1. 顺序初始化:按依赖顺序执行初始化任务 2. 并发初始化:无依赖的任务可并发执行 3. 失败重试:初始化失败可配置重试策略 4. 健康检查:初始化完成后进行健康检查 """ import asyncio import logging from abc import ABC, abstractmethod from typing import List, Optional, Callable, Any, Dict from dataclasses import dataclass, field from enum import Enum import traceback logger = logging.getLogger(__name__) class InitStatus(Enum): """初始化状态""" PENDING = "pending" RUNNING = "running" SUCCESS = "success" FAILED = "failed" SKIPPED = "skipped" @dataclass class InitResult: """初始化结果""" name: str status: InitStatus message: str = "" error: Optional[Exception] = None duration_ms: float = 0.0 def __str__(self): status_emoji = { InitStatus.SUCCESS: "✓", InitStatus.FAILED: "✗", InitStatus.SKIPPED: "○", InitStatus.RUNNING: "⟳", }.get(self.status, "?") return f"{status_emoji} {self.name}: {self.message}" @dataclass class InitTask: """初始化任务配置""" name: str func: Callable dependencies: List[str] = field(default_factory=list) enabled: bool = True retry_times: int = 0 retry_delay: float = 1.0 critical: bool = True # 失败是否阻止服务启动 def __post_init__(self): if not self.name: raise ValueError("Task name cannot be empty") class AppInitializer: """ 应用初始化管理器 使用示例: initializer = AppInitializer() # 添加初始化任务 @initializer.task("init_database") async def init_database(): # 数据库初始化逻辑 pass @initializer.task("init_cache", dependencies=["init_database"]) async def init_cache(): # 缓存初始化逻辑(依赖数据库先初始化) pass # 执行初始化 results = await initializer.initialize() """ def __init__(self, app_name: str = "FastAPI App"): self.app_name = app_name self.tasks: Dict[str, InitTask] = {} self.results: List[InitResult] = [] self._before_hooks: List[Callable] = [] self._after_hooks: List[Callable] = [] def task( self, name: str, dependencies: Optional[List[str]] = None, enabled: bool = True, retry_times: int = 0, retry_delay: float = 1.0, critical: bool = True ) -> Callable: """ 装饰器:注册初始化任务 Args: name: 任务名称 dependencies: 依赖的任务名称列表 enabled: 是否启用 retry_times: 失败重试次数 retry_delay: 重试延迟(秒) critical: 是否为关键任务(失败则阻止启动) """ def decorator(func: Callable) -> Callable: self.tasks[name] = InitTask( name=name, func=func, dependencies=dependencies or [], enabled=enabled, retry_times=retry_times, retry_delay=retry_delay, critical=critical ) return func return decorator def add_task( self, name: str, func: Callable, dependencies: Optional[List[str]] = None, enabled: bool = True, retry_times: int = 0, retry_delay: float = 1.0, critical: bool = True ) -> None: """手动添加初始化任务""" self.tasks[name] = InitTask( name=name, func=func, dependencies=dependencies or [], enabled=enabled, retry_times=retry_times, retry_delay=retry_delay, critical=critical ) def before_start(self, func: Callable) -> Callable: """装饰器:添加初始化前执行的钩子""" self._before_hooks.append(func) return func def after_start(self, func: Callable) -> Callable: """装饰器:添加初始化后执行的钩子""" self._after_hooks.append(func) return func async def initialize(self) -> List[InitResult]: """ 执行所有初始化任务 Returns: List[InitResult]: 初始化结果列表 """ import time start_time = time.time() logger.info(f"{'='*60}") logger.info(f"开始初始化 {self.app_name}") logger.info(f"{'='*60}") # 执行前置钩子 for hook in self._before_hooks: try: if asyncio.iscoroutinefunction(hook): await hook() else: hook() except Exception as e: logger.error(f"前置钩子执行失败: {e}") self.results = [] executed = set() failed_critical = False # 按拓扑顺序执行任务 while len(executed) < len(self.tasks): # 找出所有可以执行的任务(依赖已满足或未启用) ready_tasks = [ task for name, task in self.tasks.items() if name not in executed and task.enabled and all(dep in executed or dep not in self.tasks for dep in task.dependencies) ] if not ready_tasks: # 检查是否有未执行的任务 remaining = [name for name in self.tasks if name not in executed and self.tasks[name].enabled] if remaining: logger.error(f"存在循环依赖或未满足的依赖: {remaining}") break # 并发执行所有就绪的任务 results = await asyncio.gather( *[self._execute_task(task) for task in ready_tasks], return_exceptions=True ) for result in results: if isinstance(result, Exception): logger.error(f"任务执行异常: {result}") continue self.results.append(result) executed.add(result.name) if result.status == InitStatus.FAILED and self.tasks[result.name].critical: failed_critical = True logger.error(f"关键任务 {result.name} 初始化失败,停止后续初始化") # 执行后置钩子 if not failed_critical: for hook in self._after_hooks: try: if asyncio.iscoroutinefunction(hook): await hook() else: hook() except Exception as e: logger.error(f"后置钩子执行失败: {e}") # 打印摘要 duration = time.time() - start_time self._print_summary(duration, failed_critical) return self.results async def _execute_task(self, task: InitTask) -> InitResult: """执行单个初始化任务""" import time start_time = time.time() result = InitResult(name=task.name, status=InitStatus.PENDING) for attempt in range(task.retry_times + 1): try: result.status = InitStatus.RUNNING logger.info(f"正在执行: {task.name}" + (f" (重试 {attempt}/{task.retry_times})" if attempt > 0 else "")) # 执行任务 if asyncio.iscoroutinefunction(task.func): await task.func() else: task.func() result.status = InitStatus.SUCCESS result.message = "初始化成功" break except Exception as e: if attempt < task.retry_times: logger.warning(f"{task.name} 失败,{task.retry_delay}秒后重试: {e}") await asyncio.sleep(task.retry_delay) else: result.status = InitStatus.FAILED result.message = f"初始化失败: {str(e)}" result.error = e logger.error(f"{task.name} 失败: {e}\n{traceback.format_exc()}") result.duration_ms = (time.time() - start_time) * 1000 return result def _print_summary(self, duration: float, failed_critical: bool) -> None: """打印初始化摘要""" logger.info(f"{'='*60}") logger.info(f"初始化完成 (耗时: {duration:.2f}秒)") logger.info(f"{'='*60}") # 统计 success_count = sum(1 for r in self.results if r.status == InitStatus.SUCCESS) failed_count = sum(1 for r in self.results if r.status == InitStatus.FAILED) skipped_count = sum(1 for r in self.results if r.status == InitStatus.SKIPPED) logger.info(f"总计: {len(self.results)} 个任务") logger.info(f" 成功: {success_count}") logger.info(f" 失败: {failed_count}") logger.info(f" 跳过: {skipped_count}") # 详细结果 if self.results: logger.info(f"\n详细结果:") for result in self.results: logger.info(f" {result}") if failed_critical: logger.error(f"{'='*60}") logger.error(f"关键任务初始化失败,服务可能无法正常工作!") logger.error(f"{'='*60}") def get_results(self) -> List[InitResult]: """获取初始化结果""" return self.results def is_successful(self) -> bool: """检查是否所有关键任务都初始化成功""" for result in self.results: if result.status == InitStatus.FAILED: task = self.tasks.get(result.name) if task and task.critical: return False return True # 全局初始化器实例 _global_initializer: Optional[AppInitializer] = None def get_initializer() -> AppInitializer: """获取全局初始化器实例""" global _global_initializer if _global_initializer is None: _global_initializer = AppInitializer() return _global_initializer def set_initializer(initializer: AppInitializer) -> None: """设置全局初始化器实例""" global _global_initializer _global_initializer = initializer