import os import sys import json from pathlib import Path # 添加项目根目录到Python路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from server.core.database import get_db from server.models.algorithm import Algorithm from models.base import Base import os from core.database import engine # 创建数据库表 Base.metadata.create_all(bind=engine) def init_algorithms(): """初始化算法数据,与algorithm目录保持一致""" # 算法配置映射 algorithm_configs = { "berthing": { "name": "靠泊检测", "description": "检测船舶靠泊过程中的关键行为", "version": "1.0.0", "detection_classes": ["person", "ship"], "input_size": "640x640", "tags": ["靠泊", "船舶", "安全"] }, "unberthing": { "name": "离泊检测", "description": "检测船舶离泊过程中的关键行为", "version": "1.0.0", "detection_classes": ["person", "ship"], "input_size": "640x640", "tags": ["离泊", "船舶", "安全"] }, "board_ship": { "name": "登轮检测", "description": "检测人员登轮行为", "version": "1.0.0", "detection_classes": ["person"], "input_size": "640x640", "tags": ["登轮", "人员", "安全"] }, "leave_ship": { "name": "离轮检测", "description": "检测人员离轮行为", "version": "1.0.0", "detection_classes": ["person"], "input_size": "640x640", "tags": ["离轮", "人员", "安全"] }, "bullet_frame": { "name": "弹窗检测", "description": "检测系统弹窗和异常界面", "version": "1.0.0", "detection_classes": ["window", "alert"], "input_size": "640x640", "tags": ["弹窗", "界面", "异常"] } } db = next(get_db()) try: # 获取algorithm目录的绝对路径 algorithm_dir = Path(__file__).parent.parent / "algorithm" for algorithm_folder, config in algorithm_configs.items(): folder_path = algorithm_dir / algorithm_folder if folder_path.exists(): # 检查模型文件 weights_path = folder_path / "weights" / "best.pt" model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None # 检查是否已存在该算法 existing_algorithm = db.query(Algorithm).filter( Algorithm.name == config["name"] ).first() if existing_algorithm: # 更新现有算法 existing_algorithm.model_path = model_path existing_algorithm.detection_classes = json.dumps(config["detection_classes"]) existing_algorithm.input_size = config["input_size"] existing_algorithm.tags = json.dumps(config["tags"]) existing_algorithm.status = "active" if model_path else "inactive" print(f"更新算法: {config['name']}") else: # 创建新算法 algorithm = Algorithm( name=config["name"], description=config["description"], version=config["version"], model_path=model_path, detection_classes=json.dumps(config["detection_classes"]), input_size=config["input_size"], tags=json.dumps(config["tags"]), status="active" if model_path else "inactive", accuracy=0.0, # 默认值 inference_time=0.0, # 默认值 is_enabled=True, creator="system" ) db.add(algorithm) print(f"创建算法: {config['name']}") db.commit() print("算法数据初始化完成!") except Exception as e: db.rollback() print(f"初始化算法数据失败: {e}") raise finally: db.close() if __name__ == "__main__": init_algorithms()