border_Inspection/server/init_algorithms.py

122 lines
4.5 KiB
Python
Raw Normal View History

2025-08-06 14:45:07 +08:00
import os
import sys
import json
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.core.database import get_db
from server.models.algorithm import Algorithm
from models.base import Base
import os
from core.database import engine
# 创建数据库表
Base.metadata.create_all(bind=engine)
def init_algorithms():
"""初始化算法数据与algorithm目录保持一致"""
# 算法配置映射
algorithm_configs = {
"berthing": {
"name": "靠泊检测",
"description": "检测船舶靠泊过程中的关键行为",
"version": "1.0.0",
"detection_classes": ["person", "ship"],
"input_size": "640x640",
"tags": ["靠泊", "船舶", "安全"]
},
"unberthing": {
"name": "离泊检测",
"description": "检测船舶离泊过程中的关键行为",
"version": "1.0.0",
"detection_classes": ["person", "ship"],
"input_size": "640x640",
"tags": ["离泊", "船舶", "安全"]
},
"board_ship": {
"name": "登轮检测",
"description": "检测人员登轮行为",
"version": "1.0.0",
"detection_classes": ["person"],
"input_size": "640x640",
"tags": ["登轮", "人员", "安全"]
},
"leave_ship": {
"name": "离轮检测",
"description": "检测人员离轮行为",
"version": "1.0.0",
"detection_classes": ["person"],
"input_size": "640x640",
"tags": ["离轮", "人员", "安全"]
},
"bullet_frame": {
"name": "弹窗检测",
"description": "检测系统弹窗和异常界面",
"version": "1.0.0",
"detection_classes": ["window", "alert"],
"input_size": "640x640",
"tags": ["弹窗", "界面", "异常"]
}
}
db = next(get_db())
try:
# 获取algorithm目录的绝对路径
algorithm_dir = Path(__file__).parent.parent / "algorithm"
for algorithm_folder, config in algorithm_configs.items():
folder_path = algorithm_dir / algorithm_folder
if folder_path.exists():
# 检查模型文件
weights_path = folder_path / "weights" / "best.pt"
model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None
# 检查是否已存在该算法
existing_algorithm = db.query(Algorithm).filter(
Algorithm.name == config["name"]
).first()
if existing_algorithm:
# 更新现有算法
existing_algorithm.model_path = model_path
existing_algorithm.detection_classes = json.dumps(config["detection_classes"])
existing_algorithm.input_size = config["input_size"]
existing_algorithm.tags = json.dumps(config["tags"])
existing_algorithm.status = "active" if model_path else "inactive"
print(f"更新算法: {config['name']}")
else:
# 创建新算法
algorithm = Algorithm(
name=config["name"],
description=config["description"],
version=config["version"],
model_path=model_path,
detection_classes=json.dumps(config["detection_classes"]),
input_size=config["input_size"],
tags=json.dumps(config["tags"]),
status="active" if model_path else "inactive",
accuracy=0.0, # 默认值
inference_time=0.0, # 默认值
is_enabled=True,
creator="system"
)
db.add(algorithm)
print(f"创建算法: {config['name']}")
db.commit()
print("算法数据初始化完成!")
except Exception as e:
db.rollback()
print(f"初始化算法数据失败: {e}")
raise
finally:
db.close()
if __name__ == "__main__":
init_algorithms()