122 lines
4.5 KiB
Python
122 lines
4.5 KiB
Python
import os
|
||
import sys
|
||
import json
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from server.core.database import get_db
|
||
from server.models.algorithm import Algorithm
|
||
from models.base import Base
|
||
import os
|
||
from core.database import engine
|
||
|
||
# 创建数据库表
|
||
Base.metadata.create_all(bind=engine)
|
||
|
||
def init_algorithms():
|
||
"""初始化算法数据,与algorithm目录保持一致"""
|
||
|
||
# 算法配置映射
|
||
algorithm_configs = {
|
||
"berthing": {
|
||
"name": "靠泊检测",
|
||
"description": "检测船舶靠泊过程中的关键行为",
|
||
"version": "1.0.0",
|
||
"detection_classes": ["person", "ship"],
|
||
"input_size": "640x640",
|
||
"tags": ["靠泊", "船舶", "安全"]
|
||
},
|
||
"unberthing": {
|
||
"name": "离泊检测",
|
||
"description": "检测船舶离泊过程中的关键行为",
|
||
"version": "1.0.0",
|
||
"detection_classes": ["person", "ship"],
|
||
"input_size": "640x640",
|
||
"tags": ["离泊", "船舶", "安全"]
|
||
},
|
||
"board_ship": {
|
||
"name": "登轮检测",
|
||
"description": "检测人员登轮行为",
|
||
"version": "1.0.0",
|
||
"detection_classes": ["person"],
|
||
"input_size": "640x640",
|
||
"tags": ["登轮", "人员", "安全"]
|
||
},
|
||
"leave_ship": {
|
||
"name": "离轮检测",
|
||
"description": "检测人员离轮行为",
|
||
"version": "1.0.0",
|
||
"detection_classes": ["person"],
|
||
"input_size": "640x640",
|
||
"tags": ["离轮", "人员", "安全"]
|
||
},
|
||
"bullet_frame": {
|
||
"name": "弹窗检测",
|
||
"description": "检测系统弹窗和异常界面",
|
||
"version": "1.0.0",
|
||
"detection_classes": ["window", "alert"],
|
||
"input_size": "640x640",
|
||
"tags": ["弹窗", "界面", "异常"]
|
||
}
|
||
}
|
||
|
||
db = next(get_db())
|
||
|
||
try:
|
||
# 获取algorithm目录的绝对路径
|
||
algorithm_dir = Path(__file__).parent.parent / "algorithm"
|
||
|
||
for algorithm_folder, config in algorithm_configs.items():
|
||
folder_path = algorithm_dir / algorithm_folder
|
||
|
||
if folder_path.exists():
|
||
# 检查模型文件
|
||
weights_path = folder_path / "weights" / "best.pt"
|
||
model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None
|
||
|
||
# 检查是否已存在该算法
|
||
existing_algorithm = db.query(Algorithm).filter(
|
||
Algorithm.name == config["name"]
|
||
).first()
|
||
|
||
if existing_algorithm:
|
||
# 更新现有算法
|
||
existing_algorithm.model_path = model_path
|
||
existing_algorithm.detection_classes = json.dumps(config["detection_classes"])
|
||
existing_algorithm.input_size = config["input_size"]
|
||
existing_algorithm.tags = json.dumps(config["tags"])
|
||
existing_algorithm.status = "active" if model_path else "inactive"
|
||
print(f"更新算法: {config['name']}")
|
||
else:
|
||
# 创建新算法
|
||
algorithm = Algorithm(
|
||
name=config["name"],
|
||
description=config["description"],
|
||
version=config["version"],
|
||
model_path=model_path,
|
||
detection_classes=json.dumps(config["detection_classes"]),
|
||
input_size=config["input_size"],
|
||
tags=json.dumps(config["tags"]),
|
||
status="active" if model_path else "inactive",
|
||
accuracy=0.0, # 默认值
|
||
inference_time=0.0, # 默认值
|
||
is_enabled=True,
|
||
creator="system"
|
||
)
|
||
db.add(algorithm)
|
||
print(f"创建算法: {config['name']}")
|
||
|
||
db.commit()
|
||
print("算法数据初始化完成!")
|
||
|
||
except Exception as e:
|
||
db.rollback()
|
||
print(f"初始化算法数据失败: {e}")
|
||
raise
|
||
finally:
|
||
db.close()
|
||
|
||
if __name__ == "__main__":
|
||
init_algorithms() |