算法接口

This commit is contained in:
zlgecc 2025-08-06 14:45:07 +08:00
parent 2fdfd63370
commit b71a9b398d
14 changed files with 1035 additions and 578 deletions

View File

0
algorithm/__init__.py Normal file
View File

543
algorithm/detection.py Normal file
View File

@ -0,0 +1,543 @@
import cv2
import supervision as sv
from ultralytics import YOLO
import os
import time
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import io
from typing import List, Tuple, Optional, Dict, Any
import json
class DetectionProcessor:
"""
目标检测处理器类
用于处理视频中的目标检测任务
"""
def __init__(self, model_path: str = "yolov11n.pt"):
"""
初始化检测处理器
Args:
model_path: YOLO模型路径
"""
self.model_path = model_path
self.model = None
self._load_model()
def _load_model(self):
"""加载YOLO模型"""
try:
print(f"加载检测模型: {self.model_path}")
self.model = YOLO(self.model_path)
except Exception as e:
print(f"模型加载失败: {e}")
raise
def _put_chinese_text(self, img, text, position, font_size=20, color=(0, 0, 255)):
"""在图像上添加中文文本"""
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img_pil)
try:
font = ImageFont.truetype("msyh.ttc", font_size)
except:
font = ImageFont.load_default()
draw.text(position, text, font=font, fill=color)
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
def process_video(
self,
input_video_path: str,
output_video_path: str,
confidence_threshold: float = 0.5,
classes: Optional[List[int]] = None,
show_live: bool = False,
save_annotated: bool = True,
warning_zones: Optional[List[List[Tuple[float, float]]]] = None,
) -> Dict[str, Any]:
"""
处理视频进行目标检测
Args:
input_video_path: 输入视频路径
output_video_path: 输出视频路径
confidence_threshold: 置信度阈值
classes: 要检测的类别ID列表
show_live: 是否实时显示
save_annotated: 是否保存标注视频
warning_zones: 警告区域多边形点列表
Returns:
处理结果字典
"""
if not os.path.exists(input_video_path):
raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}")
# 初始化视频读取器
video_info = sv.VideoInfo.from_video_path(input_video_path)
cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
raise RuntimeError(f"无法打开视频文件: {input_video_path}")
# 初始化视频写入器
writer = None
if save_annotated:
output_dir = os.path.dirname(output_video_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
writer = cv2.VideoWriter(
output_video_path,
cv2.VideoWriter_fourcc(*"mp4v"),
video_info.fps,
(video_info.width, video_info.height),
)
# 初始化Supervision工具
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
# 处理统计
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_count = 0
processing_times = []
detection_results = []
warning_events = []
pbar = tqdm(total=total_frames, desc="处理视频帧")
while cap.isOpened():
start_time = time.time()
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 目标检测
results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
detections = sv.Detections.from_ultralytics(results)
# 准备标注
labels = [
f"{results.names[class_id]} {confidence:.2f}"
for class_id, confidence in zip(detections.class_id, detections.confidence)
]
# 标注边界框
annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
# 标注标签
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
# 处理警告区域
if warning_zones:
for i, zone in enumerate(warning_zones):
pts = [(int(x), int(y)) for x, y in zone]
color = (0, 255, 0) if i == 0 else (0, 0, 255)
cv2.polylines(annotated_frame, [np.array(pts)], isClosed=True, color=color, thickness=2)
# 检测目标是否在警告区域内
polygon = np.array(pts, np.int32)
for bbox in detections.xyxy:
x1, y1, x2, y2 = bbox
center_x = int((x1 + x2) / 2)
center_y = int((y1 + y2) / 2)
distance = cv2.pointPolygonTest(polygon, (center_x, center_y), False)
if distance >= 0:
warning_text = f"警告: 目标进入区域{i+1}!"
annotated_frame = self._put_chinese_text(
annotated_frame, warning_text, (50, 50 + i * 30), font_size=20, color=(255, 0, 0)
)
warning_events.append(
{"frame": frame_count, "zone": i + 1, "position": (center_x, center_y)}
)
# 记录检测结果
frame_detections = []
for i, (bbox, class_id, confidence) in enumerate(
zip(detections.xyxy, detections.class_id, detections.confidence)
):
frame_detections.append(
{
"bbox": bbox.tolist(),
"class_id": int(class_id),
"class_name": results.names[int(class_id)],
"confidence": float(confidence),
}
)
detection_results.append({"frame": frame_count, "detections": frame_detections})
# 计算处理时间
end_time = time.time()
processing_time = end_time - start_time
processing_times.append(processing_time)
# 实时显示
if show_live:
cv2.imshow("Detection", annotated_frame)
if cv2.waitKey(1) == 27:
break
# 保存帧
if save_annotated and writer:
writer.write(annotated_frame)
pbar.update(1)
# 清理资源
cap.release()
if writer:
writer.release()
if show_live:
cv2.destroyAllWindows()
pbar.close()
# 计算统计信息
avg_time = sum(processing_times) / len(processing_times) if processing_times else 0
fps = 1 / avg_time if avg_time > 0 else 0
return {
"success": True,
"total_frames": frame_count,
"avg_processing_time": avg_time,
"fps": fps,
"detection_results": detection_results,
"warning_events": warning_events,
"output_video_path": output_video_path if save_annotated else None,
}
def process_image(
self,
image_path: str,
confidence_threshold: float = 0.5,
classes: Optional[List[int]] = None,
warning_zones: Optional[List[List[Tuple[float, float]]]] = None,
) -> Dict[str, Any]:
"""
处理单张图像进行目标检测
Args:
image_path: 输入图像路径
confidence_threshold: 置信度阈值
classes: 要检测的类别ID列表
warning_zones: 警告区域多边形点列表
Returns:
检测结果字典
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像文件不存在: {image_path}")
# 读取图像
frame = cv2.imread(image_path)
if frame is None:
raise RuntimeError(f"无法读取图像文件: {image_path}")
# 目标检测
results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
detections = sv.Detections.from_ultralytics(results)
# 准备标注
labels = [
f"{results.names[class_id]} {confidence:.2f}"
for class_id, confidence in zip(detections.class_id, detections.confidence)
]
# 标注边界框
annotated_frame = sv.BoxAnnotator().annotate(scene=frame.copy(), detections=detections)
# 标注标签
annotated_frame = sv.LabelAnnotator().annotate(scene=annotated_frame, detections=detections, labels=labels)
# 处理警告区域
warning_events = []
if warning_zones:
for i, zone in enumerate(warning_zones):
pts = [(int(x), int(y)) for x, y in zone]
color = (0, 255, 0) if i == 0 else (0, 0, 255)
cv2.polylines(annotated_frame, [np.array(pts)], isClosed=True, color=color, thickness=2)
polygon = np.array(pts, np.int32)
for bbox in detections.xyxy:
x1, y1, x2, y2 = bbox
center_x = int((x1 + x2) / 2)
center_y = int((y1 + y2) / 2)
distance = cv2.pointPolygonTest(polygon, (center_x, center_y), False)
if distance >= 0:
warning_text = f"警告: 目标进入区域{i+1}!"
annotated_frame = self._put_chinese_text(
annotated_frame, warning_text, (50, 50 + i * 30), font_size=20, color=(255, 0, 0)
)
warning_events.append({"zone": i + 1, "position": (center_x, center_y)})
# 记录检测结果
detection_results = []
for bbox, class_id, confidence in zip(detections.xyxy, detections.class_id, detections.confidence):
detection_results.append(
{
"bbox": bbox.tolist(),
"class_id": int(class_id),
"class_name": results.names[int(class_id)],
"confidence": float(confidence),
}
)
return {
"success": True,
"detection_results": detection_results,
"warning_events": warning_events,
"annotated_image": annotated_frame,
}
class SegmentationProcessor:
"""
图像分割处理器类
用于处理视频中的图像分割任务
"""
def __init__(self, model_path: str = "yolov8n-seg.pt"):
"""
初始化分割处理器
Args:
model_path: YOLO分割模型路径
"""
self.model_path = model_path
self.model = None
self._load_model()
def _load_model(self):
"""加载YOLO分割模型"""
try:
print(f"加载分割模型: {self.model_path}")
self.model = YOLO(self.model_path)
except Exception as e:
print(f"模型加载失败: {e}")
raise
def process_video(
self,
input_video_path: str,
output_video_path: str,
confidence_threshold: float = 0.5,
classes: Optional[List[int]] = None,
show_live: bool = False,
save_annotated: bool = True,
) -> Dict[str, Any]:
"""
处理视频进行图像分割
Args:
input_video_path: 输入视频路径
output_video_path: 输出视频路径
confidence_threshold: 置信度阈值
classes: 要检测的类别ID列表
show_live: 是否实时显示
save_annotated: 是否保存标注视频
Returns:
处理结果字典
"""
if not os.path.exists(input_video_path):
raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}")
# 初始化视频读取器
video_info = sv.VideoInfo.from_video_path(input_video_path)
cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
raise RuntimeError(f"无法打开视频文件: {input_video_path}")
# 初始化视频写入器
writer = None
if save_annotated:
output_dir = os.path.dirname(output_video_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
writer = cv2.VideoWriter(
output_video_path,
cv2.VideoWriter_fourcc(*"mp4v"),
video_info.fps,
(video_info.width, video_info.height),
)
# 初始化Supervision工具
mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5)
# 处理统计
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_count = 0
processing_times = []
segmentation_results = []
pbar = tqdm(total=total_frames, desc="处理视频帧")
while cap.isOpened():
start_time = time.time()
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 图像分割
results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
detections = sv.Detections.from_ultralytics(results)
# 应用分割掩码
segmented_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections)
# 记录分割结果
frame_segmentations = []
for i, (bbox, class_id, confidence) in enumerate(
zip(detections.xyxy, detections.class_id, detections.confidence)
):
frame_segmentations.append(
{
"bbox": bbox.tolist(),
"class_id": int(class_id),
"class_name": results.names[int(class_id)],
"confidence": float(confidence),
}
)
segmentation_results.append({"frame": frame_count, "segmentations": frame_segmentations})
# 计算处理时间
end_time = time.time()
processing_time = end_time - start_time
processing_times.append(processing_time)
# 实时显示
if show_live:
cv2.imshow("Segmentation", segmented_frame)
if cv2.waitKey(1) == 27:
break
# 保存帧
if save_annotated and writer:
writer.write(segmented_frame)
pbar.update(1)
# 清理资源
cap.release()
if writer:
writer.release()
if show_live:
cv2.destroyAllWindows()
pbar.close()
# 计算统计信息
avg_time = sum(processing_times) / len(processing_times) if processing_times else 0
fps = 1 / avg_time if avg_time > 0 else 0
return {
"success": True,
"total_frames": frame_count,
"avg_processing_time": avg_time,
"fps": fps,
"segmentation_results": segmentation_results,
"output_video_path": output_video_path if save_annotated else None,
}
def process_image(
self, image_path: str, confidence_threshold: float = 0.5, classes: Optional[List[int]] = None
) -> Dict[str, Any]:
"""
处理单张图像进行图像分割
Args:
image_path: 输入图像路径
confidence_threshold: 置信度阈值
classes: 要检测的类别ID列表
Returns:
分割结果字典
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像文件不存在: {image_path}")
# 读取图像
frame = cv2.imread(image_path)
if frame is None:
raise RuntimeError(f"无法读取图像文件: {image_path}")
# 图像分割
results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
detections = sv.Detections.from_ultralytics(results)
# 应用分割掩码
mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5)
segmented_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections)
# 记录分割结果
segmentation_results = []
for bbox, class_id, confidence in zip(detections.xyxy, detections.class_id, detections.confidence):
segmentation_results.append(
{
"bbox": bbox.tolist(),
"class_id": int(class_id),
"class_name": results.names[int(class_id)],
"confidence": float(confidence),
}
)
return {"success": True, "segmentation_results": segmentation_results, "segmented_image": segmented_frame}
# 示例用法
if __name__ == "__main__":
# 目标检测示例
detector = DetectionProcessor(model_path="yolov11n.pt")
# 定义警告区域
warning_zones = [
[(1153.11, 273.86), (1146.09, 370.77), (1217.71, 352.51), (1220.52, 257.01)],
[(1140.47, 506.99), (1214.90, 504.19), (1212.09, 592.66), (1136.25, 594.07)],
]
# 处理视频
result = detector.process_video(
input_video_path="input_video.mp4",
output_video_path="output_detection.mp4",
confidence_threshold=0.5,
classes=[0], # 只检测人
warning_zones=warning_zones,
save_annotated=True,
)
print("检测结果:", result)
# 图像分割示例
segmenter = SegmentationProcessor(model_path="yolov8n-seg.pt")
# 处理视频
seg_result = segmenter.process_video(
input_video_path="input_video.mp4",
output_video_path="output_segmentation.mp4",
confidence_threshold=0.3,
classes=[0], # 只检测人
save_annotated=True,
)
print("分割结果:", seg_result)

BIN
server/border_inspection.db Normal file

Binary file not shown.

122
server/init_algorithms.py Normal file
View File

@ -0,0 +1,122 @@
import os
import sys
import json
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.core.database import get_db
from server.models.algorithm import Algorithm
from models.base import Base
import os
from core.database import engine
# 创建数据库表
Base.metadata.create_all(bind=engine)
def init_algorithms():
"""初始化算法数据与algorithm目录保持一致"""
# 算法配置映射
algorithm_configs = {
"berthing": {
"name": "靠泊检测",
"description": "检测船舶靠泊过程中的关键行为",
"version": "1.0.0",
"detection_classes": ["person", "ship"],
"input_size": "640x640",
"tags": ["靠泊", "船舶", "安全"]
},
"unberthing": {
"name": "离泊检测",
"description": "检测船舶离泊过程中的关键行为",
"version": "1.0.0",
"detection_classes": ["person", "ship"],
"input_size": "640x640",
"tags": ["离泊", "船舶", "安全"]
},
"board_ship": {
"name": "登轮检测",
"description": "检测人员登轮行为",
"version": "1.0.0",
"detection_classes": ["person"],
"input_size": "640x640",
"tags": ["登轮", "人员", "安全"]
},
"leave_ship": {
"name": "离轮检测",
"description": "检测人员离轮行为",
"version": "1.0.0",
"detection_classes": ["person"],
"input_size": "640x640",
"tags": ["离轮", "人员", "安全"]
},
"bullet_frame": {
"name": "弹窗检测",
"description": "检测系统弹窗和异常界面",
"version": "1.0.0",
"detection_classes": ["window", "alert"],
"input_size": "640x640",
"tags": ["弹窗", "界面", "异常"]
}
}
db = next(get_db())
try:
# 获取algorithm目录的绝对路径
algorithm_dir = Path(__file__).parent.parent / "algorithm"
for algorithm_folder, config in algorithm_configs.items():
folder_path = algorithm_dir / algorithm_folder
if folder_path.exists():
# 检查模型文件
weights_path = folder_path / "weights" / "best.pt"
model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None
# 检查是否已存在该算法
existing_algorithm = db.query(Algorithm).filter(
Algorithm.name == config["name"]
).first()
if existing_algorithm:
# 更新现有算法
existing_algorithm.model_path = model_path
existing_algorithm.detection_classes = json.dumps(config["detection_classes"])
existing_algorithm.input_size = config["input_size"]
existing_algorithm.tags = json.dumps(config["tags"])
existing_algorithm.status = "active" if model_path else "inactive"
print(f"更新算法: {config['name']}")
else:
# 创建新算法
algorithm = Algorithm(
name=config["name"],
description=config["description"],
version=config["version"],
model_path=model_path,
detection_classes=json.dumps(config["detection_classes"]),
input_size=config["input_size"],
tags=json.dumps(config["tags"]),
status="active" if model_path else "inactive",
accuracy=0.0, # 默认值
inference_time=0.0, # 默认值
is_enabled=True,
creator="system"
)
db.add(algorithm)
print(f"创建算法: {config['name']}")
db.commit()
print("算法数据初始化完成!")
except Exception as e:
db.rollback()
print(f"初始化算法数据失败: {e}")
raise
finally:
db.close()
if __name__ == "__main__":
init_algorithms()

View File

@ -19,61 +19,6 @@ def init_sample_data():
db = SessionLocal()
try:
# 创建示例算法
algorithms = [
{
"name": "YOLOv11n人员检测",
"description": "基于YOLOv11n的人员检测算法适用于边检场景",
"version": "1.0.0",
"model_path": "/models/yolo11n.pt",
"config_path": "/configs/yolo11n.yaml",
"status": "active",
"accuracy": 0.95,
"detection_classes": json.dumps(["person"]),
"input_size": "640x640",
"inference_time": 15.5,
"is_enabled": True,
"creator": "admin",
"tags": json.dumps(["person", "detection", "yolo"])
},
{
"name": "车辆检测算法",
"description": "专门用于车辆检测的深度学习算法",
"version": "2.1.0",
"model_path": "/models/vehicle_detection.pt",
"config_path": "/configs/vehicle.yaml",
"status": "active",
"accuracy": 0.92,
"detection_classes": json.dumps(["car", "truck", "bus", "motorcycle"]),
"input_size": "640x640",
"inference_time": 18.2,
"is_enabled": True,
"creator": "admin",
"tags": json.dumps(["vehicle", "detection"])
},
{
"name": "人脸识别算法",
"description": "高精度人脸识别算法",
"version": "1.5.0",
"model_path": "/models/face_recognition.pt",
"config_path": "/configs/face.yaml",
"status": "inactive",
"accuracy": 0.98,
"detection_classes": json.dumps(["face"]),
"input_size": "512x512",
"inference_time": 25.0,
"is_enabled": False,
"creator": "admin",
"tags": json.dumps(["face", "recognition"])
}
]
for alg_data in algorithms:
algorithm = Algorithm(**alg_data)
db.add(algorithm)
db.commit()
print("✅ 算法数据初始化完成")
# 创建示例设备
devices = [

View File

@ -24,3 +24,10 @@ class Device(BaseModel):
manufacturer = Column(String(100), comment="制造商")
model = Column(String(100), comment="设备型号")
serial_number = Column(String(100), comment="序列号")
# 新增字段:视频处理相关
demo_video_path = Column(String(500), comment="演示视频路径")
processed_video_path = Column(String(500), comment="处理结果视频路径")
processing_status = Column(String(20), default="idle", comment="处理状态: idle, processing, completed, failed")
processing_result = Column(Text, comment="处理结果JSON格式")
last_processed_at = Column(DateTime, comment="最后处理时间")

View File

@ -129,3 +129,4 @@ async def toggle_algorithm_enabled(
db.refresh(algorithm)
return algorithm.to_dict()

View File

@ -46,32 +46,47 @@ async def get_dashboard_kpi(db: Session = Depends(get_db)):
raise HTTPException(status_code=500, detail=f"获取KPI数据失败: {str(e)}")
@router.get("/alarm-trend", summary="获取告警趋势统计")
async def get_alarm_trend(
days: int = Query(7, ge=1, le=30, description="统计天数"),
db: Session = Depends(get_db)
):
"""获取告警趋势统计数据"""
async def get_alarm_trend(db: Session = Depends(get_db)):
"""获取告警趋势统计数据 - 最近5个月的P0-P3警告统计"""
try:
# TODO: 实现真实的告警趋势统计
# 当前返回模拟数据
end_date = datetime.now().date()
start_date = end_date - timedelta(days=days-1)
# 获取最近5个月的数据
current_date = datetime.now()
months = []
p0_warnings = []
p1_warnings = []
p2_warnings = []
p3_warnings = []
dates = []
alarms = []
resolved = []
for i in range(5):
# 计算月份
month_date = current_date - timedelta(days=30 * i)
month_str = month_date.strftime("%Y-%m")
months.append(month_str)
for i in range(days):
current_date = start_date + timedelta(days=i)
dates.append(current_date.strftime("%Y-%m-%d"))
# 模拟数据
alarms.append(10 + (i * 2) % 20)
resolved.append(8 + (i * 2) % 15)
# 模拟各级别警告数据
p0_count = 15 + (i * 3) % 25 # 15-40之间的随机数据
p1_count = 25 + (i * 2) % 35 # 25-60之间的随机数据
p2_count = 35 + (i * 4) % 45 # 35-80之间的随机数据
p3_count = 45 + (i * 3) % 55 # 45-100之间的随机数据
p0_warnings.append(p0_count)
p1_warnings.append(p1_count)
p2_warnings.append(p2_count)
p3_warnings.append(p3_count)
# 反转数组,让时间从早到晚
months.reverse()
p0_warnings.reverse()
p1_warnings.reverse()
p2_warnings.reverse()
p3_warnings.reverse()
return {
"dates": dates,
"alarms": alarms,
"resolved": resolved
"months": months,
"p0_warnings": p0_warnings,
"p1_warnings": p1_warnings,
"p2_warnings": p2_warnings,
"p3_warnings": p3_warnings
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取告警趋势数据失败: {str(e)}")

View File

@ -1,8 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import os
import shutil
from datetime import datetime
from pathlib import Path
from core.database import get_db
from models.device import Device
from services.video_processor import video_processor
router = APIRouter()
@ -75,6 +80,10 @@ async def get_devices(
"status": device.status,
"is_enabled": device.is_enabled,
"description": device.description,
"demo_video_path": device.demo_video_path,
"processed_video_path": device.processed_video_path,
"processing_status": device.processing_status,
"last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None,
"created_at": device.created_at,
"updated_at": device.updated_at
})
@ -108,6 +117,11 @@ async def get_device(
"status": device.status,
"is_enabled": device.is_enabled,
"description": device.description,
"demo_video_path": device.demo_video_path,
"processed_video_path": device.processed_video_path,
"processing_status": device.processing_status,
"processing_result": device.processing_result,
"last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None,
"created_at": device.created_at,
"updated_at": device.updated_at
}
@ -222,6 +236,125 @@ async def toggle_device_enabled(
"updated_at": device.updated_at
}
@router.post("/{device_id}/upload-demo-video", summary="上传演示视频")
async def upload_demo_video(
device_id: int,
video_file: UploadFile = File(...),
background_tasks: BackgroundTasks = None,
db: Session = Depends(get_db)
):
"""上传设备的演示视频并开始处理"""
# 检查设备是否存在
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
raise HTTPException(status_code=404, detail="设备不存在")
# 检查设备是否关联了算法
if not device.algorithm_id:
raise HTTPException(status_code=400, detail="设备未关联算法,无法处理视频")
# 检查文件类型
if not video_file.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
raise HTTPException(status_code=400, detail="只支持视频文件格式")
# 创建上传目录
uploads_dir = Path(__file__).parent.parent / "uploads" / "videos"
uploads_dir.mkdir(parents=True, exist_ok=True)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"demo_{device_id}_{timestamp}_{video_file.filename}"
file_path = uploads_dir / filename
try:
# 保存上传的文件
with open(file_path, "wb") as buffer:
shutil.copyfileobj(video_file.file, buffer)
# 更新设备记录
device.demo_video_path = str(file_path)
device.processing_status = "idle"
device.processed_video_path = None
device.processing_result = None
db.commit()
# 后台处理视频
if background_tasks:
background_tasks.add_task(
video_processor.process_device_video,
device_id,
str(file_path)
)
return {
"success": True,
"message": "演示视频上传成功,开始处理",
"device_id": device_id,
"demo_video_path": str(file_path),
"processing_status": "processing"
}
except Exception as e:
# 清理已上传的文件
if file_path.exists():
file_path.unlink()
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.get("/{device_id}/processing-status", summary="获取视频处理状态")
async def get_video_processing_status(
device_id: int,
db: Session = Depends(get_db)
):
"""获取设备的视频处理状态"""
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
raise HTTPException(status_code=404, detail="设备不存在")
return {
"device_id": device_id,
"processing_status": device.processing_status,
"demo_video_path": device.demo_video_path,
"processed_video_path": device.processed_video_path,
"last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None,
"processing_result": device.processing_result
}
@router.post("/{device_id}/process-video", summary="手动触发视频处理")
async def process_device_video(
device_id: int,
background_tasks: BackgroundTasks = None,
db: Session = Depends(get_db)
):
"""手动触发设备的视频处理"""
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
raise HTTPException(status_code=404, detail="设备不存在")
if not device.demo_video_path:
raise HTTPException(status_code=400, detail="设备没有上传演示视频")
if not device.algorithm_id:
raise HTTPException(status_code=400, detail="设备未关联算法")
# 检查演示视频文件是否存在
if not os.path.exists(device.demo_video_path):
raise HTTPException(status_code=400, detail="演示视频文件不存在")
# 后台处理视频
if background_tasks:
background_tasks.add_task(
video_processor.process_device_video,
device_id,
device.demo_video_path
)
return {
"success": True,
"message": "视频处理已开始",
"device_id": device_id,
"processing_status": "processing"
}
@router.get("/types/list", summary="获取设备类型列表")
async def get_device_types():
"""获取所有设备类型"""

View File

@ -0,0 +1 @@
# Services package

View File

@ -0,0 +1,182 @@
import os
import sys
import json
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import logging
# 添加algorithm目录到Python路径
sys.path.append(str(Path(__file__).parent.parent.parent / "algorithm"))
from detection import DetectionProcessor, SegmentationProcessor
from core.database import get_db
from models.device import Device
from models.algorithm import Algorithm
logger = logging.getLogger(__name__)
class VideoProcessor:
"""视频处理服务"""
def __init__(self):
self.detection_processor = None
self.segmentation_processor = None
def _get_processor(self, algorithm_id: int) -> Optional[DetectionProcessor]:
"""根据算法ID获取对应的处理器"""
db = next(get_db())
try:
algorithm = db.query(Algorithm).filter(Algorithm.id == algorithm_id).first()
if not algorithm or not algorithm.model_path:
logger.error(f"算法 {algorithm_id} 不存在或模型路径为空")
return None
# 检查模型文件是否存在
if not os.path.exists(algorithm.model_path):
logger.error(f"模型文件不存在: {algorithm.model_path}")
return None
# 根据算法名称判断使用检测还是分割
if "分割" in algorithm.name or "segmentation" in algorithm.name.lower():
if not self.segmentation_processor:
self.segmentation_processor = SegmentationProcessor(model_path=algorithm.model_path)
return self.segmentation_processor
else:
if not self.detection_processor:
self.detection_processor = DetectionProcessor(model_path=algorithm.model_path)
return self.detection_processor
except Exception as e:
logger.error(f"获取处理器失败: {e}")
return None
finally:
db.close()
async def process_device_video(self, device_id: int, demo_video_path: str) -> Dict[str, Any]:
"""
处理设备的演示视频
Args:
device_id: 设备ID
demo_video_path: 演示视频路径
Returns:
处理结果字典
"""
db = next(get_db())
try:
# 获取设备信息
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
return {"success": False, "error": "设备不存在"}
if not device.algorithm_id:
return {"success": False, "error": "设备未关联算法"}
# 检查视频文件是否存在
if not os.path.exists(demo_video_path):
return {"success": False, "error": "演示视频文件不存在"}
# 更新设备状态为处理中
device.processing_status = "processing"
device.last_processed_at = datetime.now()
db.commit()
# 获取处理器
processor = self._get_processor(device.algorithm_id)
if not processor:
device.processing_status = "failed"
device.processing_result = json.dumps({"error": "无法获取处理器"})
db.commit()
return {"success": False, "error": "无法获取处理器"}
# 生成输出视频路径
uploads_dir = Path(__file__).parent.parent / "uploads" / "results"
uploads_dir.mkdir(parents=True, exist_ok=True)
output_filename = f"processed_{device_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
output_video_path = str(uploads_dir / output_filename)
# 获取算法配置
algorithm = db.query(Algorithm).filter(Algorithm.id == device.algorithm_id).first()
detection_classes = json.loads(algorithm.detection_classes) if algorithm.detection_classes else None
# 处理视频
logger.info(f"开始处理设备 {device_id} 的视频: {demo_video_path}")
if isinstance(processor, DetectionProcessor):
# 目标检测处理
result = processor.process_video(
input_video_path=demo_video_path,
output_video_path=output_video_path,
confidence_threshold=0.5,
classes=[0] if detection_classes and "person" in detection_classes else None,
show_live=False,
save_annotated=True
)
else:
# 图像分割处理
result = processor.process_video(
input_video_path=demo_video_path,
output_video_path=output_video_path,
confidence_threshold=0.3,
classes=[0] if detection_classes and "person" in detection_classes else None,
show_live=False,
save_annotated=True
)
# 更新设备状态和结果
if result.get("success"):
device.processing_status = "completed"
device.processed_video_path = output_video_path
device.processing_result = json.dumps(result)
device.last_processed_at = datetime.now()
logger.info(f"设备 {device_id} 视频处理完成: {output_video_path}")
else:
device.processing_status = "failed"
device.processing_result = json.dumps(result)
logger.error(f"设备 {device_id} 视频处理失败: {result.get('error', '未知错误')}")
db.commit()
return result
except Exception as e:
logger.error(f"处理设备 {device_id} 视频时发生错误: {e}")
try:
device.processing_status = "failed"
device.processing_result = json.dumps({"error": str(e)})
db.commit()
except:
pass
return {"success": False, "error": str(e)}
finally:
db.close()
def get_processing_status(self, device_id: int) -> Dict[str, Any]:
"""获取设备处理状态"""
db = next(get_db())
try:
device = db.query(Device).filter(Device.id == device_id).first()
if not device:
return {"success": False, "error": "设备不存在"}
return {
"success": True,
"device_id": device_id,
"processing_status": device.processing_status,
"demo_video_path": device.demo_video_path,
"processed_video_path": device.processed_video_path,
"last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None,
"processing_result": json.loads(device.processing_result) if device.processing_result else None
}
except Exception as e:
logger.error(f"获取设备 {device_id} 处理状态时发生错误: {e}")
return {"success": False, "error": str(e)}
finally:
db.close()
# 全局视频处理器实例
video_processor = VideoProcessor()

View File

@ -1,145 +0,0 @@
#!/usr/bin/env python3
"""
API测试脚本
用于验证新创建的接口是否正常工作
"""
import requests
import json
from datetime import datetime
# API基础URL
BASE_URL = "http://localhost:8000/api"
def test_dashboard_apis():
"""测试仪表板相关接口"""
print("=== 测试仪表板接口 ===")
# 测试KPI接口
try:
response = requests.get(f"{BASE_URL}/dashboard/kpi")
print(f"KPI接口: {response.status_code}")
if response.status_code == 200:
print(f"KPI数据: {response.json()}")
except Exception as e:
print(f"KPI接口错误: {e}")
# 测试告警趋势接口
try:
response = requests.get(f"{BASE_URL}/dashboard/alarm-trend?days=7")
print(f"告警趋势接口: {response.status_code}")
if response.status_code == 200:
print(f"告警趋势数据: {response.json()}")
except Exception as e:
print(f"告警趋势接口错误: {e}")
# 测试摄像头统计接口
try:
response = requests.get(f"{BASE_URL}/dashboard/camera-stats")
print(f"摄像头统计接口: {response.status_code}")
if response.status_code == 200:
print(f"摄像头统计数据: {response.json()}")
except Exception as e:
print(f"摄像头统计接口错误: {e}")
def test_monitor_apis():
"""测试监控相关接口"""
print("\n=== 测试监控接口 ===")
# 测试监控列表接口
try:
response = requests.get(f"{BASE_URL}/monitors?page=1&size=10")
print(f"监控列表接口: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"监控列表: 总数={data.get('total', 0)}")
except Exception as e:
print(f"监控列表接口错误: {e}")
def test_alarm_apis():
"""测试告警相关接口"""
print("\n=== 测试告警接口 ===")
# 测试告警列表接口
try:
response = requests.get(f"{BASE_URL}/alarms?page=1&size=10")
print(f"告警列表接口: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"告警列表: 总数={data.get('total', 0)}")
except Exception as e:
print(f"告警列表接口错误: {e}")
# 测试告警统计接口
try:
response = requests.get(f"{BASE_URL}/alarms/stats")
print(f"告警统计接口: {response.status_code}")
if response.status_code == 200:
print(f"告警统计数据: {response.json()}")
except Exception as e:
print(f"告警统计接口错误: {e}")
def test_scene_apis():
"""测试场景相关接口"""
print("\n=== 测试场景接口 ===")
# 测试场景列表接口
try:
response = requests.get(f"{BASE_URL}/scenes")
print(f"场景列表接口: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"场景列表: 数量={len(data.get('scenes', []))}")
except Exception as e:
print(f"场景列表接口错误: {e}")
def test_auth_apis():
"""测试认证相关接口"""
print("\n=== 测试认证接口 ===")
# 测试登录接口
try:
login_data = {
"username": "admin",
"password": "admin123"
}
response = requests.post(f"{BASE_URL}/auth/login", data=login_data)
print(f"登录接口: {response.status_code}")
if response.status_code == 200:
print("登录成功")
else:
print(f"登录失败: {response.text}")
except Exception as e:
print(f"登录接口错误: {e}")
def test_upload_apis():
"""测试上传相关接口"""
print("\n=== 测试上传接口 ===")
# 测试上传统计接口
try:
response = requests.get(f"{BASE_URL}/upload/stats")
print(f"上传统计接口: {response.status_code}")
if response.status_code == 200:
print(f"上传统计数据: {response.json()}")
except Exception as e:
print(f"上传统计接口错误: {e}")
def main():
"""主测试函数"""
print("开始API测试...")
print(f"测试时间: {datetime.now()}")
print(f"API基础URL: {BASE_URL}")
# 测试各个模块的接口
test_dashboard_apis()
test_monitor_apis()
test_alarm_apis()
test_scene_apis()
test_auth_apis()
test_upload_apis()
print("\n=== 测试完成 ===")
if __name__ == "__main__":
main()

View File

@ -1,362 +1,15 @@
# 前端接口开发计划
## 概述
根据前端代码分析当前后端已实现基础的CRUD接口但前端需要更多统计、监控、告警等功能的接口。以下是缺失接口的开发计划。
## 已实现接口
- ✅ 设备管理CRUD操作
- ✅ 算法管理CRUD操作
- ✅ 事件管理CRUD操作
* [x] 初始化的算法数据和algorithm中的算法保持一致保证能被detection.py 的类使用
## 缺失接口清单
* [x] 设备支持上传演示视频,上传视频后处理完成后放到 uploads/results并在表中记录这两个视频地址
### 1. 仪表板统计接口 (优先级:高)
* [ ] 添加视频处理结果的下载接口
#### 1.1 主要KPI指标
```
GET /api/dashboard/kpi
响应:
{
"total_devices": 156,
"online_devices": 142,
"total_algorithms": 8,
"active_algorithms": 6,
"total_events": 1247,
"today_events": 89,
"alert_events": 23,
"resolved_events": 66
}
```
* [ ] 优化视频处理性能,支持大文件处理
#### 1.2 告警趋势统计
```
GET /api/dashboard/alarm-trend
参数:
- days: 7 (默认7天)
响应:
{
"dates": ["2024-01-01", "2024-01-02", ...],
"alarms": [12, 15, 8, 23, 18, 25, 20],
"resolved": [10, 12, 7, 19, 15, 22, 18]
}
```
* [ ] 添加视频处理进度实时反馈
#### 1.3 摄像头统计
```
GET /api/dashboard/camera-stats
响应:
{
"total_cameras": 156,
"online_cameras": 142,
"offline_cameras": 14,
"by_location": [
{"location": "港口区", "total": 45, "online": 42},
{"location": "码头区", "total": 38, "online": 35},
{"location": "办公区", "total": 23, "online": 21}
]
}
```
* [ ] 实现视频处理失败重试机制
#### 1.4 算法统计
```
GET /api/dashboard/algorithm-stats
响应:
{
"total_algorithms": 8,
"active_algorithms": 6,
"by_type": [
{"type": "目标检测", "count": 3, "accuracy": 95.2},
{"type": "行为识别", "count": 2, "accuracy": 88.7},
{"type": "越界检测", "count": 3, "accuracy": 92.1}
]
}
```
#### 1.5 事件热点统计
```
GET /api/dashboard/event-hotspots
响应:
{
"hotspots": [
{
"location": "港口A区",
"event_count": 45,
"severity": "high",
"coordinates": {"lat": 31.2304, "lng": 121.4737}
}
]
}
```
### 2. 监控管理接口 (优先级:高)
#### 2.1 监控列表
```
GET /api/monitors
参数:
- page: 1
- size: 20
- status: online/offline
- location: 位置筛选
响应:
{
"monitors": [
{
"id": 1,
"name": "港口区监控1",
"location": "港口A区",
"status": "online",
"video_url": "/videos/port-1.mp4",
"detections": [
{"type": "person", "x": 25, "y": 35, "width": 40, "height": 80}
]
}
],
"total": 156,
"page": 1,
"size": 20
}
```
#### 2.2 监控详情
```
GET /api/monitors/{monitor_id}
响应:
{
"id": 1,
"name": "港口区主监控",
"location": "港口区",
"status": "online",
"video_url": "/videos/port-main.mp4",
"detections": [...],
"events": [...],
"algorithms": [...]
}
```
### 3. 告警管理接口 (优先级:中)
#### 3.1 告警列表
```
GET /api/alarms
参数:
- page: 1
- size: 20
- severity: high/medium/low
- status: pending/resolved
- start_time: 2024-01-01
- end_time: 2024-01-31
响应:
{
"alarms": [
{
"id": 1,
"type": "船舶靠泊",
"severity": "high",
"status": "pending",
"device": "港口区监控1",
"created_at": "2024-01-15T10:30:00Z",
"description": "检测到船舶靠泊行为"
}
],
"total": 89,
"page": 1,
"size": 20
}
```
#### 3.2 告警处理
```
PATCH /api/alarms/{alarm_id}/resolve
请求体:
{
"resolution_notes": "已确认船舶靠泊,无异常",
"resolved_by": "operator1"
}
```
#### 3.3 告警统计
```
GET /api/alarms/stats
响应:
{
"total_alarms": 89,
"pending_alarms": 23,
"resolved_alarms": 66,
"by_severity": [
{"severity": "high", "count": 12},
{"severity": "medium", "count": 45},
{"severity": "low", "count": 32}
]
}
```
### 4. 场景管理接口 (优先级:中)
#### 4.1 场景列表
```
GET /api/scenes
响应:
{
"scenes": [
{
"id": "scene-001",
"name": "港口区场景",
"description": "港口区监控场景",
"device_count": 45,
"algorithm_count": 3
}
]
}
```
#### 4.2 场景详情
```
GET /api/scenes/{scene_id}
响应:
{
"id": "scene-001",
"name": "港口区场景",
"description": "港口区监控场景",
"devices": [...],
"algorithms": [...],
"events": [...]
}
```
### 5. 文件上传接口 (优先级:中)
#### 5.1 视频上传
```
POST /api/upload/video
Content-Type: multipart/form-data
请求体:
- file: 视频文件
- device_id: 设备ID
- description: 描述
响应:
{
"file_id": "video_123",
"file_url": "/uploads/videos/video_123.mp4",
"file_size": 1024000,
"duration": 30.5
}
```
#### 5.2 图片上传
```
POST /api/upload/image
Content-Type: multipart/form-data
请求体:
- file: 图片文件
- event_id: 事件ID
响应:
{
"file_id": "image_456",
"file_url": "/uploads/images/image_456.jpg",
"file_size": 256000
}
```
### 6. 用户认证接口 (优先级:低)
#### 6.1 用户登录
```
POST /api/auth/login
请求体:
{
"username": "admin",
"password": "password123"
}
响应:
{
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
"token_type": "bearer",
"expires_in": 3600,
"user": {
"id": 1,
"username": "admin",
"role": "admin"
}
}
```
#### 6.2 用户信息
```
GET /api/auth/profile
响应:
{
"id": 1,
"username": "admin",
"email": "admin@example.com",
"role": "admin",
"permissions": ["read", "write", "admin"]
}
```
## 开发进度
### ✅ 已完成 (第一阶段)
1. ✅ 仪表板统计接口 (KPI、告警趋势、摄像头统计、算法统计、事件热点)
2. ✅ 监控管理接口 (监控列表、监控详情)
### ✅ 已完成 (第二阶段)
1. ✅ 告警管理接口 (告警列表、告警处理、告警统计)
2. ✅ 场景管理接口 (场景列表、场景详情)
### ✅ 已完成 (第三阶段)
1. ✅ 文件上传接口 (视频上传、图片上传)
2. ✅ 用户认证接口 (登录、用户信息)
### 🔄 待优化功能
1. 实现真实的告警趋势统计 (当前使用模拟数据)
2. 实现真实的检测数据获取 (当前使用模拟数据)
3. 实现真实的视频流URL生成
4. 实现真实的JWT验证中间件
5. 实现真实的场景管理数据库模型
6. 实现真实的文件删除逻辑
7. 添加Redis缓存支持
8. 添加WebSocket实时数据推送
## 技术实现要点
1. **数据库模型扩展**
- 添加统计相关的视图或缓存表
- 优化查询性能,添加索引
2. **缓存策略**
- 使用Redis缓存统计数据
- 设置合理的缓存过期时间
3. **文件存储**
- 配置静态文件服务
- 实现文件上传和存储逻辑
4. **权限控制**
- 实现JWT认证
- 添加角色和权限控制
5. **WebSocket支持**
- 实时监控数据推送
- 告警实时通知
## 测试计划
1. **单元测试**每个接口的CRUD操作
2. **集成测试**:前后端联调
3. **性能测试**:大数据量下的响应时间
4. **安全测试**:认证和权限验证
## 部署计划
1. **开发环境**:本地测试
2. **测试环境**:功能验证
3. **生产环境**:正式部署
## 注意事项
1. 所有接口需要添加错误处理和日志记录
2. 敏感数据需要加密存储
3. 文件上传需要限制文件大小和类型
4. 统计数据需要定期更新,避免过期数据
5. 接口文档需要及时更新
* [ ]