border_Inspection/server/init_algorithms.py
2025-08-06 14:45:07 +08:00

122 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import json
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.core.database import get_db
from server.models.algorithm import Algorithm
from models.base import Base
import os
from core.database import engine
# Create the database tables registered on Base.metadata (SQLAlchemy's
# create_all skips tables that already exist). This runs at import time,
# i.e. whenever this module is imported, not only when run as a script.
# NOTE(review): Base is imported as `models.base` while Algorithm is imported
# as `server.models.algorithm`; if that model registers on a differently
# imported Base (e.g. `server.models.base`), its table would not be part of
# this metadata. Likewise `engine` (core.database) and `get_db`
# (server.core.database) may come from two module objects — confirm.
Base.metadata.create_all(bind=engine)
# Seed data for the Algorithm table, keyed by the folder name under
# <project root>/algorithm/. Each folder is expected to hold weights/best.pt.
_ALGORITHM_CONFIGS = {
    "berthing": {
        "name": "靠泊检测",
        "description": "检测船舶靠泊过程中的关键行为",
        "version": "1.0.0",
        "detection_classes": ["person", "ship"],
        "input_size": "640x640",
        "tags": ["靠泊", "船舶", "安全"],
    },
    "unberthing": {
        "name": "离泊检测",
        "description": "检测船舶离泊过程中的关键行为",
        "version": "1.0.0",
        "detection_classes": ["person", "ship"],
        "input_size": "640x640",
        "tags": ["离泊", "船舶", "安全"],
    },
    "board_ship": {
        "name": "登轮检测",
        "description": "检测人员登轮行为",
        "version": "1.0.0",
        "detection_classes": ["person"],
        "input_size": "640x640",
        "tags": ["登轮", "人员", "安全"],
    },
    "leave_ship": {
        "name": "离轮检测",
        "description": "检测人员离轮行为",
        "version": "1.0.0",
        "detection_classes": ["person"],
        "input_size": "640x640",
        "tags": ["离轮", "人员", "安全"],
    },
    "bullet_frame": {
        "name": "弹窗检测",
        "description": "检测系统弹窗和异常界面",
        "version": "1.0.0",
        "detection_classes": ["window", "alert"],
        "input_size": "640x640",
        "tags": ["弹窗", "界面", "异常"],
    },
}


def _resolve_model_path(algorithm_dir, folder):
    """Return the project-relative weights path for *folder*, or None.

    None is returned when ``<algorithm_dir>/<folder>/weights/best.pt`` is not
    an existing file; callers treat that as "no model" (status "inactive").
    """
    weights_path = algorithm_dir / folder / "weights" / "best.pt"
    if not weights_path.is_file():
        return None
    # Stored relative to the project root with forward slashes, independent
    # of the host OS path separator.
    return f"algorithm/{folder}/weights/best.pt"


def _sync_algorithm(db, config, model_path, *, create=True):
    """Create or update the Algorithm row whose name equals ``config["name"]``.

    Args:
        db: Open database session (uses ``query``/``add``).
        config: One entry of ``_ALGORITHM_CONFIGS``.
        model_path: Project-relative weights path, or None if no weights exist.
        create: When False, a missing row is not created (used for algorithm
            folders that no longer exist on disk).
    """
    status = "active" if model_path else "inactive"
    detection_classes = json.dumps(config["detection_classes"])
    tags = json.dumps(config["tags"])

    existing = db.query(Algorithm).filter(Algorithm.name == config["name"]).first()
    if existing is not None:
        # Refresh every config-derived field so the row matches the seed data;
        # description/version were previously left stale on update.
        existing.description = config["description"]
        existing.version = config["version"]
        existing.model_path = model_path
        existing.detection_classes = detection_classes
        existing.input_size = config["input_size"]
        existing.tags = tags
        existing.status = status
        print(f"更新算法: {config['name']}")
        return

    if not create:
        print(f"跳过算法(目录不存在): {config['name']}")
        return

    db.add(
        Algorithm(
            name=config["name"],
            description=config["description"],
            version=config["version"],
            model_path=model_path,
            detection_classes=detection_classes,
            input_size=config["input_size"],
            tags=tags,
            status=status,
            accuracy=0.0,  # default placeholder
            inference_time=0.0,  # default placeholder
            is_enabled=True,
            creator="system",
        )
    )
    print(f"创建算法: {config['name']}")


def init_algorithms():
    """Synchronise the Algorithm table with the folders under algorithm/.

    For every entry in ``_ALGORITHM_CONFIGS``:

    * folder present: create or update the row. Status is "active" when
      ``weights/best.pt`` exists, otherwise "inactive" with model_path None.
    * folder missing: an existing row is updated to "inactive" with
      model_path None, so it no longer points at weights that are gone.
      No row is created.

    All changes are committed in one transaction. On any error the session is
    rolled back and the exception is re-raised.
    """
    # This file lives in <project root>/server/, so parent.parent is the root.
    algorithm_dir = Path(__file__).parent.parent / "algorithm"

    # Keep a reference to what get_db() returns. With `next(get_db())` the
    # temporary (presumably a generator dependency) is finalised in CPython as
    # soon as it is dropped, which can run its cleanup before we use the session.
    db_source = get_db()
    db = next(db_source)
    try:
        for folder, config in _ALGORITHM_CONFIGS.items():
            if (algorithm_dir / folder).is_dir():
                _sync_algorithm(db, config, _resolve_model_path(algorithm_dir, folder))
            else:
                _sync_algorithm(db, config, None, create=False)
        db.commit()
        print("算法数据初始化完成!")
    except Exception as e:
        db.rollback()
        print(f"初始化算法数据失败: {e}")
        raise
    finally:
        db.close()
        # Let get_db() run its own cleanup now, if it is a generator.
        close_source = getattr(db_source, "close", None)
        if close_source is not None:
            close_source()
def _main():
    """Script entry point: seed or refresh the algorithm table."""
    init_algorithms()


if __name__ == "__main__":
    _main()