122 lines
4.5 KiB
Python
122 lines
4.5 KiB
Python
![]() |
import os
|
|||
|
import sys
|
|||
|
import json
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
# 添加项目根目录到Python路径
|
|||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
|
|||
|
from server.core.database import get_db
|
|||
|
from server.models.algorithm import Algorithm
|
|||
|
from models.base import Base
|
|||
|
import os
|
|||
|
from core.database import engine
|
|||
|
|
|||
|
# 创建数据库表
|
|||
|
Base.metadata.create_all(bind=engine)
|
|||
|
|
|||
|
def init_algorithms():
|
|||
|
"""初始化算法数据,与algorithm目录保持一致"""
|
|||
|
|
|||
|
# 算法配置映射
|
|||
|
algorithm_configs = {
|
|||
|
"berthing": {
|
|||
|
"name": "靠泊检测",
|
|||
|
"description": "检测船舶靠泊过程中的关键行为",
|
|||
|
"version": "1.0.0",
|
|||
|
"detection_classes": ["person", "ship"],
|
|||
|
"input_size": "640x640",
|
|||
|
"tags": ["靠泊", "船舶", "安全"]
|
|||
|
},
|
|||
|
"unberthing": {
|
|||
|
"name": "离泊检测",
|
|||
|
"description": "检测船舶离泊过程中的关键行为",
|
|||
|
"version": "1.0.0",
|
|||
|
"detection_classes": ["person", "ship"],
|
|||
|
"input_size": "640x640",
|
|||
|
"tags": ["离泊", "船舶", "安全"]
|
|||
|
},
|
|||
|
"board_ship": {
|
|||
|
"name": "登轮检测",
|
|||
|
"description": "检测人员登轮行为",
|
|||
|
"version": "1.0.0",
|
|||
|
"detection_classes": ["person"],
|
|||
|
"input_size": "640x640",
|
|||
|
"tags": ["登轮", "人员", "安全"]
|
|||
|
},
|
|||
|
"leave_ship": {
|
|||
|
"name": "离轮检测",
|
|||
|
"description": "检测人员离轮行为",
|
|||
|
"version": "1.0.0",
|
|||
|
"detection_classes": ["person"],
|
|||
|
"input_size": "640x640",
|
|||
|
"tags": ["离轮", "人员", "安全"]
|
|||
|
},
|
|||
|
"bullet_frame": {
|
|||
|
"name": "弹窗检测",
|
|||
|
"description": "检测系统弹窗和异常界面",
|
|||
|
"version": "1.0.0",
|
|||
|
"detection_classes": ["window", "alert"],
|
|||
|
"input_size": "640x640",
|
|||
|
"tags": ["弹窗", "界面", "异常"]
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
db = next(get_db())
|
|||
|
|
|||
|
try:
|
|||
|
# 获取algorithm目录的绝对路径
|
|||
|
algorithm_dir = Path(__file__).parent.parent / "algorithm"
|
|||
|
|
|||
|
for algorithm_folder, config in algorithm_configs.items():
|
|||
|
folder_path = algorithm_dir / algorithm_folder
|
|||
|
|
|||
|
if folder_path.exists():
|
|||
|
# 检查模型文件
|
|||
|
weights_path = folder_path / "weights" / "best.pt"
|
|||
|
model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None
|
|||
|
|
|||
|
# 检查是否已存在该算法
|
|||
|
existing_algorithm = db.query(Algorithm).filter(
|
|||
|
Algorithm.name == config["name"]
|
|||
|
).first()
|
|||
|
|
|||
|
if existing_algorithm:
|
|||
|
# 更新现有算法
|
|||
|
existing_algorithm.model_path = model_path
|
|||
|
existing_algorithm.detection_classes = json.dumps(config["detection_classes"])
|
|||
|
existing_algorithm.input_size = config["input_size"]
|
|||
|
existing_algorithm.tags = json.dumps(config["tags"])
|
|||
|
existing_algorithm.status = "active" if model_path else "inactive"
|
|||
|
print(f"更新算法: {config['name']}")
|
|||
|
else:
|
|||
|
# 创建新算法
|
|||
|
algorithm = Algorithm(
|
|||
|
name=config["name"],
|
|||
|
description=config["description"],
|
|||
|
version=config["version"],
|
|||
|
model_path=model_path,
|
|||
|
detection_classes=json.dumps(config["detection_classes"]),
|
|||
|
input_size=config["input_size"],
|
|||
|
tags=json.dumps(config["tags"]),
|
|||
|
status="active" if model_path else "inactive",
|
|||
|
accuracy=0.0, # 默认值
|
|||
|
inference_time=0.0, # 默认值
|
|||
|
is_enabled=True,
|
|||
|
creator="system"
|
|||
|
)
|
|||
|
db.add(algorithm)
|
|||
|
print(f"创建算法: {config['name']}")
|
|||
|
|
|||
|
db.commit()
|
|||
|
print("算法数据初始化完成!")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
db.rollback()
|
|||
|
print(f"初始化算法数据失败: {e}")
|
|||
|
raise
|
|||
|
finally:
|
|||
|
db.close()
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
init_algorithms()
|