diff --git a/server/.gitignore b/.gitignore similarity index 100% rename from server/.gitignore rename to .gitignore diff --git a/algorithm/__init__.py b/algorithm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm/detection.py b/algorithm/detection.py new file mode 100644 index 0000000..6a7b7b2 --- /dev/null +++ b/algorithm/detection.py @@ -0,0 +1,543 @@ +import cv2 +import supervision as sv +from ultralytics import YOLO +import os +import time +import numpy as np +from tqdm import tqdm +from PIL import Image, ImageDraw, ImageFont +import io +from typing import List, Tuple, Optional, Dict, Any +import json + + +class DetectionProcessor: + """ + 目标检测处理器类 + 用于处理视频中的目标检测任务 + """ + + def __init__(self, model_path: str = "yolov11n.pt"): + """ + 初始化检测处理器 + + Args: + model_path: YOLO模型路径 + """ + self.model_path = model_path + self.model = None + self._load_model() + + def _load_model(self): + """加载YOLO模型""" + try: + print(f"加载检测模型: {self.model_path}") + self.model = YOLO(self.model_path) + except Exception as e: + print(f"模型加载失败: {e}") + raise + + def _put_chinese_text(self, img, text, position, font_size=20, color=(0, 0, 255)): + """在图像上添加中文文本""" + img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(img_pil) + + try: + font = ImageFont.truetype("msyh.ttc", font_size) + except: + font = ImageFont.load_default() + + draw.text(position, text, font=font, fill=color) + return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR) + + def process_video( + self, + input_video_path: str, + output_video_path: str, + confidence_threshold: float = 0.5, + classes: Optional[List[int]] = None, + show_live: bool = False, + save_annotated: bool = True, + warning_zones: Optional[List[List[Tuple[float, float]]]] = None, + ) -> Dict[str, Any]: + """ + 处理视频进行目标检测 + + Args: + input_video_path: 输入视频路径 + output_video_path: 输出视频路径 + confidence_threshold: 置信度阈值 + classes: 要检测的类别ID列表 + show_live: 是否实时显示 + save_annotated: 是否保存标注视频 + warning_zones: 警告区域多边形点列表 + + Returns: + 处理结果字典 + """ + if not os.path.exists(input_video_path): + raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}") + + # 初始化视频读取器 + video_info = sv.VideoInfo.from_video_path(input_video_path) + cap = cv2.VideoCapture(input_video_path) + + if not cap.isOpened(): + raise RuntimeError(f"无法打开视频文件: {input_video_path}") + + # 初始化视频写入器 + writer = None + if save_annotated: + output_dir = os.path.dirname(output_video_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + writer = cv2.VideoWriter( + output_video_path, + cv2.VideoWriter_fourcc(*"mp4v"), + video_info.fps, + (video_info.width, video_info.height), + ) + + # 初始化Supervision工具 + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator() + + # 处理统计 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_count = 0 + processing_times = [] + detection_results = [] + warning_events = [] + + pbar = tqdm(total=total_frames, desc="处理视频帧") + + while cap.isOpened(): + start_time = time.time() + + ret, frame = cap.read() + if not ret: + break + + frame_count += 1 + + # 目标检测 + results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0] + + detections = sv.Detections.from_ultralytics(results) + + # 准备标注 + labels = [ + f"{results.names[class_id]} {confidence:.2f}" + for class_id, confidence in zip(detections.class_id, detections.confidence) + ] + + # 标注边界框 + annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections) + + # 标注标签 + annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) + + # 处理警告区域 + if warning_zones: + for i, zone in enumerate(warning_zones): + pts = [(int(x), int(y)) for x, y in zone] + color = (0, 255, 0) if i == 0 else (0, 0, 255) + cv2.polylines(annotated_frame, [np.array(pts)], isClosed=True, color=color, thickness=2) + + # 检测目标是否在警告区域内 + polygon = np.array(pts, np.int32) + for bbox in detections.xyxy: + x1, y1, x2, y2 = bbox + center_x = int((x1 + x2) / 2) + center_y = int((y1 + y2) / 2) + + distance = cv2.pointPolygonTest(polygon, (center_x, center_y), False) + if distance >= 0: + warning_text = f"警告: 目标进入区域{i+1}!" + annotated_frame = self._put_chinese_text( + annotated_frame, warning_text, (50, 50 + i * 30), font_size=20, color=(255, 0, 0) + ) + warning_events.append( + {"frame": frame_count, "zone": i + 1, "position": (center_x, center_y)} + ) + + # 记录检测结果 + frame_detections = [] + for i, (bbox, class_id, confidence) in enumerate( + zip(detections.xyxy, detections.class_id, detections.confidence) + ): + frame_detections.append( + { + "bbox": bbox.tolist(), + "class_id": int(class_id), + "class_name": results.names[int(class_id)], + "confidence": float(confidence), + } + ) + + detection_results.append({"frame": frame_count, "detections": frame_detections}) + + # 计算处理时间 + end_time = time.time() + processing_time = end_time - start_time + processing_times.append(processing_time) + + # 实时显示 + if show_live: + cv2.imshow("Detection", annotated_frame) + if cv2.waitKey(1) == 27: + break + + # 保存帧 + if save_annotated and writer: + writer.write(annotated_frame) + + pbar.update(1) + + # 清理资源 + cap.release() + if writer: + writer.release() + if show_live: + cv2.destroyAllWindows() + pbar.close() + + # 计算统计信息 + avg_time = sum(processing_times) / len(processing_times) if processing_times else 0 + fps = 1 / avg_time if avg_time > 0 else 0 + + return { + "success": True, + "total_frames": frame_count, + "avg_processing_time": avg_time, + "fps": fps, + "detection_results": detection_results, + "warning_events": warning_events, + "output_video_path": output_video_path if save_annotated else None, + } + + def process_image( + self, + image_path: str, + confidence_threshold: float = 0.5, + classes: Optional[List[int]] = None, + warning_zones: Optional[List[List[Tuple[float, float]]]] = None, + ) -> Dict[str, Any]: + """ + 处理单张图像进行目标检测 + + Args: + image_path: 输入图像路径 + confidence_threshold: 置信度阈值 + classes: 要检测的类别ID列表 + warning_zones: 警告区域多边形点列表 + + Returns: + 检测结果字典 + """ + if not os.path.exists(image_path): + raise FileNotFoundError(f"输入图像文件不存在: {image_path}") + + # 读取图像 + frame = cv2.imread(image_path) + if frame is None: + raise RuntimeError(f"无法读取图像文件: {image_path}") + + # 目标检测 + results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0] + + detections = sv.Detections.from_ultralytics(results) + + # 准备标注 + labels = [ + f"{results.names[class_id]} {confidence:.2f}" + for class_id, confidence in zip(detections.class_id, detections.confidence) + ] + + # 标注边界框 + annotated_frame = sv.BoxAnnotator().annotate(scene=frame.copy(), detections=detections) + + # 标注标签 + annotated_frame = sv.LabelAnnotator().annotate(scene=annotated_frame, detections=detections, labels=labels) + + # 处理警告区域 + warning_events = [] + if warning_zones: + for i, zone in enumerate(warning_zones): + pts = [(int(x), int(y)) for x, y in zone] + color = (0, 255, 0) if i == 0 else (0, 0, 255) + cv2.polylines(annotated_frame, [np.array(pts)], isClosed=True, color=color, thickness=2) + + polygon = np.array(pts, np.int32) + for bbox in detections.xyxy: + x1, y1, x2, y2 = bbox + center_x = int((x1 + x2) / 2) + center_y = int((y1 + y2) / 2) + + distance = cv2.pointPolygonTest(polygon, (center_x, center_y), False) + if distance >= 0: + warning_text = f"警告: 目标进入区域{i+1}!" + annotated_frame = self._put_chinese_text( + annotated_frame, warning_text, (50, 50 + i * 30), font_size=20, color=(255, 0, 0) + ) + warning_events.append({"zone": i + 1, "position": (center_x, center_y)}) + + # 记录检测结果 + detection_results = [] + for bbox, class_id, confidence in zip(detections.xyxy, detections.class_id, detections.confidence): + detection_results.append( + { + "bbox": bbox.tolist(), + "class_id": int(class_id), + "class_name": results.names[int(class_id)], + "confidence": float(confidence), + } + ) + + return { + "success": True, + "detection_results": detection_results, + "warning_events": warning_events, + "annotated_image": annotated_frame, + } + + +class SegmentationProcessor: + """ + 图像分割处理器类 + 用于处理视频中的图像分割任务 + """ + + def __init__(self, model_path: str = "yolov8n-seg.pt"): + """ + 初始化分割处理器 + + Args: + model_path: YOLO分割模型路径 + """ + self.model_path = model_path + self.model = None + self._load_model() + + def _load_model(self): + """加载YOLO分割模型""" + try: + print(f"加载分割模型: {self.model_path}") + self.model = YOLO(self.model_path) + except Exception as e: + print(f"模型加载失败: {e}") + raise + + def process_video( + self, + input_video_path: str, + output_video_path: str, + confidence_threshold: float = 0.5, + classes: Optional[List[int]] = None, + show_live: bool = False, + save_annotated: bool = True, + ) -> Dict[str, Any]: + """ + 处理视频进行图像分割 + + Args: + input_video_path: 输入视频路径 + output_video_path: 输出视频路径 + confidence_threshold: 置信度阈值 + classes: 要检测的类别ID列表 + show_live: 是否实时显示 + save_annotated: 是否保存标注视频 + + Returns: + 处理结果字典 + """ + if not os.path.exists(input_video_path): + raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}") + + # 初始化视频读取器 + video_info = sv.VideoInfo.from_video_path(input_video_path) + cap = cv2.VideoCapture(input_video_path) + + if not cap.isOpened(): + raise RuntimeError(f"无法打开视频文件: {input_video_path}") + + # 初始化视频写入器 + writer = None + if save_annotated: + output_dir = os.path.dirname(output_video_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + writer = cv2.VideoWriter( + output_video_path, + cv2.VideoWriter_fourcc(*"mp4v"), + video_info.fps, + (video_info.width, video_info.height), + ) + + # 初始化Supervision工具 + mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5) + + # 处理统计 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_count = 0 + processing_times = [] + segmentation_results = [] + + pbar = tqdm(total=total_frames, desc="处理视频帧") + + while cap.isOpened(): + start_time = time.time() + + ret, frame = cap.read() + if not ret: + break + + frame_count += 1 + + # 图像分割 + results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0] + + detections = sv.Detections.from_ultralytics(results) + + # 应用分割掩码 + segmented_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections) + + # 记录分割结果 + frame_segmentations = [] + for i, (bbox, class_id, confidence) in enumerate( + zip(detections.xyxy, detections.class_id, detections.confidence) + ): + frame_segmentations.append( + { + "bbox": bbox.tolist(), + "class_id": int(class_id), + "class_name": results.names[int(class_id)], + "confidence": float(confidence), + } + ) + + segmentation_results.append({"frame": frame_count, "segmentations": frame_segmentations}) + + # 计算处理时间 + end_time = time.time() + processing_time = end_time - start_time + processing_times.append(processing_time) + + # 实时显示 + if show_live: + cv2.imshow("Segmentation", segmented_frame) + if cv2.waitKey(1) == 27: + break + + # 保存帧 + if save_annotated and writer: + writer.write(segmented_frame) + + pbar.update(1) + + # 清理资源 + cap.release() + if writer: + writer.release() + if show_live: + cv2.destroyAllWindows() + pbar.close() + + # 计算统计信息 + avg_time = sum(processing_times) / len(processing_times) if processing_times else 0 + fps = 1 / avg_time if avg_time > 0 else 0 + + return { + "success": True, + "total_frames": frame_count, + "avg_processing_time": avg_time, + "fps": fps, + "segmentation_results": segmentation_results, + "output_video_path": output_video_path if save_annotated else None, + } + + def process_image( + self, image_path: str, confidence_threshold: float = 0.5, classes: Optional[List[int]] = None + ) -> Dict[str, Any]: + """ + 处理单张图像进行图像分割 + + Args: + image_path: 输入图像路径 + confidence_threshold: 置信度阈值 + classes: 要检测的类别ID列表 + + Returns: + 分割结果字典 + """ + if not os.path.exists(image_path): + raise FileNotFoundError(f"输入图像文件不存在: {image_path}") + + # 读取图像 + frame = cv2.imread(image_path) + if frame is None: + raise RuntimeError(f"无法读取图像文件: {image_path}") + + # 图像分割 + results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0] + + detections = sv.Detections.from_ultralytics(results) + + # 应用分割掩码 + mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5) + + segmented_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections) + + # 记录分割结果 + segmentation_results = [] + for bbox, class_id, confidence in zip(detections.xyxy, detections.class_id, detections.confidence): + segmentation_results.append( + { + "bbox": bbox.tolist(), + "class_id": int(class_id), + "class_name": results.names[int(class_id)], + "confidence": float(confidence), + } + ) + + return {"success": True, "segmentation_results": segmentation_results, "segmented_image": segmented_frame} + + +# 示例用法 +if __name__ == "__main__": + # 目标检测示例 + detector = DetectionProcessor(model_path="yolov11n.pt") + + # 定义警告区域 + warning_zones = [ + [(1153.11, 273.86), (1146.09, 370.77), (1217.71, 352.51), (1220.52, 257.01)], + [(1140.47, 506.99), (1214.90, 504.19), (1212.09, 592.66), (1136.25, 594.07)], + ] + + # 处理视频 + result = detector.process_video( + input_video_path="input_video.mp4", + output_video_path="output_detection.mp4", + confidence_threshold=0.5, + classes=[0], # 只检测人 + warning_zones=warning_zones, + save_annotated=True, + ) + + print("检测结果:", result) + + # 图像分割示例 + segmenter = SegmentationProcessor(model_path="yolov8n-seg.pt") + + # 处理视频 + seg_result = segmenter.process_video( + input_video_path="input_video.mp4", + output_video_path="output_segmentation.mp4", + confidence_threshold=0.3, + classes=[0], # 只检测人 + save_annotated=True, + ) + + print("分割结果:", seg_result) diff --git a/server/border_inspection.db b/server/border_inspection.db new file mode 100644 index 0000000..ee43dbe Binary files /dev/null and b/server/border_inspection.db differ diff --git a/server/init_algorithms.py b/server/init_algorithms.py new file mode 100644 index 0000000..20e39f3 --- /dev/null +++ b/server/init_algorithms.py @@ -0,0 +1,122 @@ +import os +import sys +import json +from pathlib import Path + +# 添加项目根目录到Python路径 +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from server.core.database import get_db +from server.models.algorithm import Algorithm +from models.base import Base +import os +from core.database import engine + +# 创建数据库表 +Base.metadata.create_all(bind=engine) + +def init_algorithms(): + """初始化算法数据,与algorithm目录保持一致""" + + # 算法配置映射 + algorithm_configs = { + "berthing": { + "name": "靠泊检测", + "description": "检测船舶靠泊过程中的关键行为", + "version": "1.0.0", + "detection_classes": ["person", "ship"], + "input_size": "640x640", + "tags": ["靠泊", "船舶", "安全"] + }, + "unberthing": { + "name": "离泊检测", + "description": "检测船舶离泊过程中的关键行为", + "version": "1.0.0", + "detection_classes": ["person", "ship"], + "input_size": "640x640", + "tags": ["离泊", "船舶", "安全"] + }, + "board_ship": { + "name": "登轮检测", + "description": "检测人员登轮行为", + "version": "1.0.0", + "detection_classes": ["person"], + "input_size": "640x640", + "tags": ["登轮", "人员", "安全"] + }, + "leave_ship": { + "name": "离轮检测", + "description": "检测人员离轮行为", + "version": "1.0.0", + "detection_classes": ["person"], + "input_size": "640x640", + "tags": ["离轮", "人员", "安全"] + }, + "bullet_frame": { + "name": "弹窗检测", + "description": "检测系统弹窗和异常界面", + "version": "1.0.0", + "detection_classes": ["window", "alert"], + "input_size": "640x640", + "tags": ["弹窗", "界面", "异常"] + } + } + + db = next(get_db()) + + try: + # 获取algorithm目录的绝对路径 + algorithm_dir = Path(__file__).parent.parent / "algorithm" + + for algorithm_folder, config in algorithm_configs.items(): + folder_path = algorithm_dir / algorithm_folder + + if folder_path.exists(): + # 检查模型文件 + weights_path = folder_path / "weights" / "best.pt" + model_path = f"algorithm/{algorithm_folder}/weights/best.pt" if weights_path.exists() else None + + # 检查是否已存在该算法 + existing_algorithm = db.query(Algorithm).filter( + Algorithm.name == config["name"] + ).first() + + if existing_algorithm: + # 更新现有算法 + existing_algorithm.model_path = model_path + existing_algorithm.detection_classes = json.dumps(config["detection_classes"]) + existing_algorithm.input_size = config["input_size"] + existing_algorithm.tags = json.dumps(config["tags"]) + existing_algorithm.status = "active" if model_path else "inactive" + print(f"更新算法: {config['name']}") + else: + # 创建新算法 + algorithm = Algorithm( + name=config["name"], + description=config["description"], + version=config["version"], + model_path=model_path, + detection_classes=json.dumps(config["detection_classes"]), + input_size=config["input_size"], + tags=json.dumps(config["tags"]), + status="active" if model_path else "inactive", + accuracy=0.0, # 默认值 + inference_time=0.0, # 默认值 + is_enabled=True, + creator="system" + ) + db.add(algorithm) + print(f"创建算法: {config['name']}") + + db.commit() + print("算法数据初始化完成!") + + except Exception as e: + db.rollback() + print(f"初始化算法数据失败: {e}") + raise + finally: + db.close() + +if __name__ == "__main__": + init_algorithms() \ No newline at end of file diff --git a/server/init_data.py b/server/init_data.py index 9a1bd15..03c09c7 100644 --- a/server/init_data.py +++ b/server/init_data.py @@ -19,61 +19,6 @@ def init_sample_data(): db = SessionLocal() try: - # 创建示例算法 - algorithms = [ - { - "name": "YOLOv11n人员检测", - "description": "基于YOLOv11n的人员检测算法,适用于边检场景", - "version": "1.0.0", - "model_path": "/models/yolo11n.pt", - "config_path": "/configs/yolo11n.yaml", - "status": "active", - "accuracy": 0.95, - "detection_classes": json.dumps(["person"]), - "input_size": "640x640", - "inference_time": 15.5, - "is_enabled": True, - "creator": "admin", - "tags": json.dumps(["person", "detection", "yolo"]) - }, - { - "name": "车辆检测算法", - "description": "专门用于车辆检测的深度学习算法", - "version": "2.1.0", - "model_path": "/models/vehicle_detection.pt", - "config_path": "/configs/vehicle.yaml", - "status": "active", - "accuracy": 0.92, - "detection_classes": json.dumps(["car", "truck", "bus", "motorcycle"]), - "input_size": "640x640", - "inference_time": 18.2, - "is_enabled": True, - "creator": "admin", - "tags": json.dumps(["vehicle", "detection"]) - }, - { - "name": "人脸识别算法", - "description": "高精度人脸识别算法", - "version": "1.5.0", - "model_path": "/models/face_recognition.pt", - "config_path": "/configs/face.yaml", - "status": "inactive", - "accuracy": 0.98, - "detection_classes": json.dumps(["face"]), - "input_size": "512x512", - "inference_time": 25.0, - "is_enabled": False, - "creator": "admin", - "tags": json.dumps(["face", "recognition"]) - } - ] - - for alg_data in algorithms: - algorithm = Algorithm(**alg_data) - db.add(algorithm) - - db.commit() - print("✅ 算法数据初始化完成") # 创建示例设备 devices = [ diff --git a/server/models/device.py b/server/models/device.py index 0aa036e..02a5465 100644 --- a/server/models/device.py +++ b/server/models/device.py @@ -23,4 +23,11 @@ class Device(BaseModel): description = Column(Text, comment="设备描述") manufacturer = Column(String(100), comment="制造商") model = Column(String(100), comment="设备型号") - serial_number = Column(String(100), comment="序列号") \ No newline at end of file + serial_number = Column(String(100), comment="序列号") + + # 新增字段:视频处理相关 + demo_video_path = Column(String(500), comment="演示视频路径") + processed_video_path = Column(String(500), comment="处理结果视频路径") + processing_status = Column(String(20), default="idle", comment="处理状态: idle, processing, completed, failed") + processing_result = Column(Text, comment="处理结果,JSON格式") + last_processed_at = Column(DateTime, comment="最后处理时间") \ No newline at end of file diff --git a/server/routers/algorithms.py b/server/routers/algorithms.py index bf2c5fe..e2ff144 100644 --- a/server/routers/algorithms.py +++ b/server/routers/algorithms.py @@ -128,4 +128,5 @@ async def toggle_algorithm_enabled( db.commit() db.refresh(algorithm) - return algorithm.to_dict() \ No newline at end of file + return algorithm.to_dict() + diff --git a/server/routers/dashboard.py b/server/routers/dashboard.py index c8842cd..c6e0329 100644 --- a/server/routers/dashboard.py +++ b/server/routers/dashboard.py @@ -46,32 +46,47 @@ async def get_dashboard_kpi(db: Session = Depends(get_db)): raise HTTPException(status_code=500, detail=f"获取KPI数据失败: {str(e)}") @router.get("/alarm-trend", summary="获取告警趋势统计") -async def get_alarm_trend( - days: int = Query(7, ge=1, le=30, description="统计天数"), - db: Session = Depends(get_db) -): - """获取告警趋势统计数据""" +async def get_alarm_trend(db: Session = Depends(get_db)): + """获取告警趋势统计数据 - 最近5个月的P0-P3警告统计""" try: - # TODO: 实现真实的告警趋势统计 - # 当前返回模拟数据 - end_date = datetime.now().date() - start_date = end_date - timedelta(days=days-1) + # 获取最近5个月的数据 + current_date = datetime.now() + months = [] + p0_warnings = [] + p1_warnings = [] + p2_warnings = [] + p3_warnings = [] - dates = [] - alarms = [] - resolved = [] + for i in range(5): + # 计算月份 + month_date = current_date - timedelta(days=30 * i) + month_str = month_date.strftime("%Y-%m") + months.append(month_str) + + # 模拟各级别警告数据 + p0_count = 15 + (i * 3) % 25 # 15-40之间的随机数据 + p1_count = 25 + (i * 2) % 35 # 25-60之间的随机数据 + p2_count = 35 + (i * 4) % 45 # 35-80之间的随机数据 + p3_count = 45 + (i * 3) % 55 # 45-100之间的随机数据 + + p0_warnings.append(p0_count) + p1_warnings.append(p1_count) + p2_warnings.append(p2_count) + p3_warnings.append(p3_count) - for i in range(days): - current_date = start_date + timedelta(days=i) - dates.append(current_date.strftime("%Y-%m-%d")) - # 模拟数据 - alarms.append(10 + (i * 2) % 20) - resolved.append(8 + (i * 2) % 15) + # 反转数组,让时间从早到晚 + months.reverse() + p0_warnings.reverse() + p1_warnings.reverse() + p2_warnings.reverse() + p3_warnings.reverse() return { - "dates": dates, - "alarms": alarms, - "resolved": resolved + "months": months, + "p0_warnings": p0_warnings, + "p1_warnings": p1_warnings, + "p2_warnings": p2_warnings, + "p3_warnings": p3_warnings } except Exception as e: raise HTTPException(status_code=500, detail=f"获取告警趋势数据失败: {str(e)}") diff --git a/server/routers/devices.py b/server/routers/devices.py index 971d1b1..223c5d4 100644 --- a/server/routers/devices.py +++ b/server/routers/devices.py @@ -1,8 +1,13 @@ -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, BackgroundTasks from sqlalchemy.orm import Session from typing import List, Optional, Dict, Any +import os +import shutil +from datetime import datetime +from pathlib import Path from core.database import get_db from models.device import Device +from services.video_processor import video_processor router = APIRouter() @@ -75,6 +80,10 @@ async def get_devices( "status": device.status, "is_enabled": device.is_enabled, "description": device.description, + "demo_video_path": device.demo_video_path, + "processed_video_path": device.processed_video_path, + "processing_status": device.processing_status, + "last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None, "created_at": device.created_at, "updated_at": device.updated_at }) @@ -108,6 +117,11 @@ async def get_device( "status": device.status, "is_enabled": device.is_enabled, "description": device.description, + "demo_video_path": device.demo_video_path, + "processed_video_path": device.processed_video_path, + "processing_status": device.processing_status, + "processing_result": device.processing_result, + "last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None, "created_at": device.created_at, "updated_at": device.updated_at } @@ -222,6 +236,125 @@ async def toggle_device_enabled( "updated_at": device.updated_at } +@router.post("/{device_id}/upload-demo-video", summary="上传演示视频") +async def upload_demo_video( + device_id: int, + video_file: UploadFile = File(...), + background_tasks: BackgroundTasks = None, + db: Session = Depends(get_db) +): + """上传设备的演示视频并开始处理""" + # 检查设备是否存在 + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + raise HTTPException(status_code=404, detail="设备不存在") + + # 检查设备是否关联了算法 + if not device.algorithm_id: + raise HTTPException(status_code=400, detail="设备未关联算法,无法处理视频") + + # 检查文件类型 + if not video_file.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')): + raise HTTPException(status_code=400, detail="只支持视频文件格式") + + # 创建上传目录 + uploads_dir = Path(__file__).parent.parent / "uploads" / "videos" + uploads_dir.mkdir(parents=True, exist_ok=True) + + # 生成文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"demo_{device_id}_{timestamp}_{video_file.filename}" + file_path = uploads_dir / filename + + try: + # 保存上传的文件 + with open(file_path, "wb") as buffer: + shutil.copyfileobj(video_file.file, buffer) + + # 更新设备记录 + device.demo_video_path = str(file_path) + device.processing_status = "idle" + device.processed_video_path = None + device.processing_result = None + db.commit() + + # 后台处理视频 + if background_tasks: + background_tasks.add_task( + video_processor.process_device_video, + device_id, + str(file_path) + ) + + return { + "success": True, + "message": "演示视频上传成功,开始处理", + "device_id": device_id, + "demo_video_path": str(file_path), + "processing_status": "processing" + } + + except Exception as e: + # 清理已上传的文件 + if file_path.exists(): + file_path.unlink() + raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") + +@router.get("/{device_id}/processing-status", summary="获取视频处理状态") +async def get_video_processing_status( + device_id: int, + db: Session = Depends(get_db) +): + """获取设备的视频处理状态""" + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + raise HTTPException(status_code=404, detail="设备不存在") + + return { + "device_id": device_id, + "processing_status": device.processing_status, + "demo_video_path": device.demo_video_path, + "processed_video_path": device.processed_video_path, + "last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None, + "processing_result": device.processing_result + } + +@router.post("/{device_id}/process-video", summary="手动触发视频处理") +async def process_device_video( + device_id: int, + background_tasks: BackgroundTasks = None, + db: Session = Depends(get_db) +): + """手动触发设备的视频处理""" + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + raise HTTPException(status_code=404, detail="设备不存在") + + if not device.demo_video_path: + raise HTTPException(status_code=400, detail="设备没有上传演示视频") + + if not device.algorithm_id: + raise HTTPException(status_code=400, detail="设备未关联算法") + + # 检查演示视频文件是否存在 + if not os.path.exists(device.demo_video_path): + raise HTTPException(status_code=400, detail="演示视频文件不存在") + + # 后台处理视频 + if background_tasks: + background_tasks.add_task( + video_processor.process_device_video, + device_id, + device.demo_video_path + ) + + return { + "success": True, + "message": "视频处理已开始", + "device_id": device_id, + "processing_status": "processing" + } + @router.get("/types/list", summary="获取设备类型列表") async def get_device_types(): """获取所有设备类型""" diff --git a/server/services/__init__.py b/server/services/__init__.py new file mode 100644 index 0000000..9983534 --- /dev/null +++ b/server/services/__init__.py @@ -0,0 +1 @@ +# Services package \ No newline at end of file diff --git a/server/services/video_processor.py b/server/services/video_processor.py new file mode 100644 index 0000000..5410aa6 --- /dev/null +++ b/server/services/video_processor.py @@ -0,0 +1,182 @@ +import os +import sys +import json +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional +import logging + +# 添加algorithm目录到Python路径 +sys.path.append(str(Path(__file__).parent.parent.parent / "algorithm")) + +from detection import DetectionProcessor, SegmentationProcessor +from core.database import get_db +from models.device import Device +from models.algorithm import Algorithm + +logger = logging.getLogger(__name__) + +class VideoProcessor: + """视频处理服务""" + + def __init__(self): + self.detection_processor = None + self.segmentation_processor = None + + def _get_processor(self, algorithm_id: int) -> Optional[DetectionProcessor]: + """根据算法ID获取对应的处理器""" + db = next(get_db()) + try: + algorithm = db.query(Algorithm).filter(Algorithm.id == algorithm_id).first() + if not algorithm or not algorithm.model_path: + logger.error(f"算法 {algorithm_id} 不存在或模型路径为空") + return None + + # 检查模型文件是否存在 + if not os.path.exists(algorithm.model_path): + logger.error(f"模型文件不存在: {algorithm.model_path}") + return None + + # 根据算法名称判断使用检测还是分割 + if "分割" in algorithm.name or "segmentation" in algorithm.name.lower(): + if not self.segmentation_processor: + self.segmentation_processor = SegmentationProcessor(model_path=algorithm.model_path) + return self.segmentation_processor + else: + if not self.detection_processor: + self.detection_processor = DetectionProcessor(model_path=algorithm.model_path) + return self.detection_processor + + except Exception as e: + logger.error(f"获取处理器失败: {e}") + return None + finally: + db.close() + + async def process_device_video(self, device_id: int, demo_video_path: str) -> Dict[str, Any]: + """ + 处理设备的演示视频 + + Args: + device_id: 设备ID + demo_video_path: 演示视频路径 + + Returns: + 处理结果字典 + """ + db = next(get_db()) + + try: + # 获取设备信息 + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + return {"success": False, "error": "设备不存在"} + + if not device.algorithm_id: + return {"success": False, "error": "设备未关联算法"} + + # 检查视频文件是否存在 + if not os.path.exists(demo_video_path): + return {"success": False, "error": "演示视频文件不存在"} + + # 更新设备状态为处理中 + device.processing_status = "processing" + device.last_processed_at = datetime.now() + db.commit() + + # 获取处理器 + processor = self._get_processor(device.algorithm_id) + if not processor: + device.processing_status = "failed" + device.processing_result = json.dumps({"error": "无法获取处理器"}) + db.commit() + return {"success": False, "error": "无法获取处理器"} + + # 生成输出视频路径 + uploads_dir = Path(__file__).parent.parent / "uploads" / "results" + uploads_dir.mkdir(parents=True, exist_ok=True) + + output_filename = f"processed_{device_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4" + output_video_path = str(uploads_dir / output_filename) + + # 获取算法配置 + algorithm = db.query(Algorithm).filter(Algorithm.id == device.algorithm_id).first() + detection_classes = json.loads(algorithm.detection_classes) if algorithm.detection_classes else None + + # 处理视频 + logger.info(f"开始处理设备 {device_id} 的视频: {demo_video_path}") + + if isinstance(processor, DetectionProcessor): + # 目标检测处理 + result = processor.process_video( + input_video_path=demo_video_path, + output_video_path=output_video_path, + confidence_threshold=0.5, + classes=[0] if detection_classes and "person" in detection_classes else None, + show_live=False, + save_annotated=True + ) + else: + # 图像分割处理 + result = processor.process_video( + input_video_path=demo_video_path, + output_video_path=output_video_path, + confidence_threshold=0.3, + classes=[0] if detection_classes and "person" in detection_classes else None, + show_live=False, + save_annotated=True + ) + + # 更新设备状态和结果 + if result.get("success"): + device.processing_status = "completed" + device.processed_video_path = output_video_path + device.processing_result = json.dumps(result) + device.last_processed_at = datetime.now() + logger.info(f"设备 {device_id} 视频处理完成: {output_video_path}") + else: + device.processing_status = "failed" + device.processing_result = json.dumps(result) + logger.error(f"设备 {device_id} 视频处理失败: {result.get('error', '未知错误')}") + + db.commit() + return result + + except Exception as e: + logger.error(f"处理设备 {device_id} 视频时发生错误: {e}") + try: + device.processing_status = "failed" + device.processing_result = json.dumps({"error": str(e)}) + db.commit() + except: + pass + return {"success": False, "error": str(e)} + finally: + db.close() + + def get_processing_status(self, device_id: int) -> Dict[str, Any]: + """获取设备处理状态""" + db = next(get_db()) + try: + device = db.query(Device).filter(Device.id == device_id).first() + if not device: + return {"success": False, "error": "设备不存在"} + + return { + "success": True, + "device_id": device_id, + "processing_status": device.processing_status, + "demo_video_path": device.demo_video_path, + "processed_video_path": device.processed_video_path, + "last_processed_at": device.last_processed_at.isoformat() if device.last_processed_at else None, + "processing_result": json.loads(device.processing_result) if device.processing_result else None + } + except Exception as e: + logger.error(f"获取设备 {device_id} 处理状态时发生错误: {e}") + return {"success": False, "error": str(e)} + finally: + db.close() + +# 全局视频处理器实例 +video_processor = VideoProcessor() \ No newline at end of file diff --git a/server/test_api.py b/server/test_api.py deleted file mode 100644 index 70c475d..0000000 --- a/server/test_api.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 -""" -API测试脚本 -用于验证新创建的接口是否正常工作 -""" - -import requests -import json -from datetime import datetime - -# API基础URL -BASE_URL = "http://localhost:8000/api" - -def test_dashboard_apis(): - """测试仪表板相关接口""" - print("=== 测试仪表板接口 ===") - - # 测试KPI接口 - try: - response = requests.get(f"{BASE_URL}/dashboard/kpi") - print(f"KPI接口: {response.status_code}") - if response.status_code == 200: - print(f"KPI数据: {response.json()}") - except Exception as e: - print(f"KPI接口错误: {e}") - - # 测试告警趋势接口 - try: - response = requests.get(f"{BASE_URL}/dashboard/alarm-trend?days=7") - print(f"告警趋势接口: {response.status_code}") - if response.status_code == 200: - print(f"告警趋势数据: {response.json()}") - except Exception as e: - print(f"告警趋势接口错误: {e}") - - # 测试摄像头统计接口 - try: - response = requests.get(f"{BASE_URL}/dashboard/camera-stats") - print(f"摄像头统计接口: {response.status_code}") - if response.status_code == 200: - print(f"摄像头统计数据: {response.json()}") - except Exception as e: - print(f"摄像头统计接口错误: {e}") - -def test_monitor_apis(): - """测试监控相关接口""" - print("\n=== 测试监控接口 ===") - - # 测试监控列表接口 - try: - response = requests.get(f"{BASE_URL}/monitors?page=1&size=10") - print(f"监控列表接口: {response.status_code}") - if response.status_code == 200: - data = response.json() - print(f"监控列表: 总数={data.get('total', 0)}") - except Exception as e: - print(f"监控列表接口错误: {e}") - -def test_alarm_apis(): - """测试告警相关接口""" - print("\n=== 测试告警接口 ===") - - # 测试告警列表接口 - try: - response = requests.get(f"{BASE_URL}/alarms?page=1&size=10") - print(f"告警列表接口: {response.status_code}") - if response.status_code == 200: - data = response.json() - print(f"告警列表: 总数={data.get('total', 0)}") - except Exception as e: - print(f"告警列表接口错误: {e}") - - # 测试告警统计接口 - try: - response = requests.get(f"{BASE_URL}/alarms/stats") - print(f"告警统计接口: {response.status_code}") - if response.status_code == 200: - print(f"告警统计数据: {response.json()}") - except Exception as e: - print(f"告警统计接口错误: {e}") - -def test_scene_apis(): - """测试场景相关接口""" - print("\n=== 测试场景接口 ===") - - # 测试场景列表接口 - try: - response = requests.get(f"{BASE_URL}/scenes") - print(f"场景列表接口: {response.status_code}") - if response.status_code == 200: - data = response.json() - print(f"场景列表: 数量={len(data.get('scenes', []))}") - except Exception as e: - print(f"场景列表接口错误: {e}") - -def test_auth_apis(): - """测试认证相关接口""" - print("\n=== 测试认证接口 ===") - - # 测试登录接口 - try: - login_data = { - "username": "admin", - "password": "admin123" - } - response = requests.post(f"{BASE_URL}/auth/login", data=login_data) - print(f"登录接口: {response.status_code}") - if response.status_code == 200: - print("登录成功") - else: - print(f"登录失败: {response.text}") - except Exception as e: - print(f"登录接口错误: {e}") - -def test_upload_apis(): - """测试上传相关接口""" - print("\n=== 测试上传接口 ===") - - # 测试上传统计接口 - try: - response = requests.get(f"{BASE_URL}/upload/stats") - print(f"上传统计接口: {response.status_code}") - if response.status_code == 200: - print(f"上传统计数据: {response.json()}") - except Exception as e: - print(f"上传统计接口错误: {e}") - -def main(): - """主测试函数""" - print("开始API测试...") - print(f"测试时间: {datetime.now()}") - print(f"API基础URL: {BASE_URL}") - - # 测试各个模块的接口 - test_dashboard_apis() - test_monitor_apis() - test_alarm_apis() - test_scene_apis() - test_auth_apis() - test_upload_apis() - - print("\n=== 测试完成 ===") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/server/todo.md b/server/todo.md index 0a75b65..56afcef 100644 --- a/server/todo.md +++ b/server/todo.md @@ -1,362 +1,15 @@ -# 前端接口开发计划 -## 概述 -根据前端代码分析,当前后端已实现基础的CRUD接口,但前端需要更多统计、监控、告警等功能的接口。以下是缺失接口的开发计划。 -## 已实现接口 -- ✅ 设备管理:CRUD操作 -- ✅ 算法管理:CRUD操作 -- ✅ 事件管理:CRUD操作 +* [x] 初始化的算法数据和algorithm中的算法保持一致,保证能被detection.py 的类使用 -## 缺失接口清单 +* [x] 设备支持上传演示视频,上传视频后处理完成后放到 uploads/results,并在表中记录这两个视频地址 -### 1. 仪表板统计接口 (优先级:高) +* [ ] 添加视频处理结果的下载接口 -#### 1.1 主要KPI指标 -``` -GET /api/dashboard/kpi -响应: -{ - "total_devices": 156, - "online_devices": 142, - "total_algorithms": 8, - "active_algorithms": 6, - "total_events": 1247, - "today_events": 89, - "alert_events": 23, - "resolved_events": 66 -} -``` +* [ ] 优化视频处理性能,支持大文件处理 -#### 1.2 告警趋势统计 -``` -GET /api/dashboard/alarm-trend -参数: -- days: 7 (默认7天) -响应: -{ - "dates": ["2024-01-01", "2024-01-02", ...], - "alarms": [12, 15, 8, 23, 18, 25, 20], - "resolved": [10, 12, 7, 19, 15, 22, 18] -} -``` +* [ ] 添加视频处理进度实时反馈 -#### 1.3 摄像头统计 -``` -GET /api/dashboard/camera-stats -响应: -{ - "total_cameras": 156, - "online_cameras": 142, - "offline_cameras": 14, - "by_location": [ - {"location": "港口区", "total": 45, "online": 42}, - {"location": "码头区", "total": 38, "online": 35}, - {"location": "办公区", "total": 23, "online": 21} - ] -} -``` +* [ ] 实现视频处理失败重试机制 -#### 1.4 算法统计 -``` -GET /api/dashboard/algorithm-stats -响应: -{ - "total_algorithms": 8, - "active_algorithms": 6, - "by_type": [ - {"type": "目标检测", "count": 3, "accuracy": 95.2}, - {"type": "行为识别", "count": 2, "accuracy": 88.7}, - {"type": "越界检测", "count": 3, "accuracy": 92.1} - ] -} -``` - -#### 1.5 事件热点统计 -``` -GET /api/dashboard/event-hotspots -响应: -{ - "hotspots": [ - { - "location": "港口A区", - "event_count": 45, - "severity": "high", - "coordinates": {"lat": 31.2304, "lng": 121.4737} - } - ] -} -``` - -### 2. 监控管理接口 (优先级:高) - -#### 2.1 监控列表 -``` -GET /api/monitors -参数: -- page: 1 -- size: 20 -- status: online/offline -- location: 位置筛选 -响应: -{ - "monitors": [ - { - "id": 1, - "name": "港口区监控1", - "location": "港口A区", - "status": "online", - "video_url": "/videos/port-1.mp4", - "detections": [ - {"type": "person", "x": 25, "y": 35, "width": 40, "height": 80} - ] - } - ], - "total": 156, - "page": 1, - "size": 20 -} -``` - -#### 2.2 监控详情 -``` -GET /api/monitors/{monitor_id} -响应: -{ - "id": 1, - "name": "港口区主监控", - "location": "港口区", - "status": "online", - "video_url": "/videos/port-main.mp4", - "detections": [...], - "events": [...], - "algorithms": [...] -} -``` - -### 3. 告警管理接口 (优先级:中) - -#### 3.1 告警列表 -``` -GET /api/alarms -参数: -- page: 1 -- size: 20 -- severity: high/medium/low -- status: pending/resolved -- start_time: 2024-01-01 -- end_time: 2024-01-31 -响应: -{ - "alarms": [ - { - "id": 1, - "type": "船舶靠泊", - "severity": "high", - "status": "pending", - "device": "港口区监控1", - "created_at": "2024-01-15T10:30:00Z", - "description": "检测到船舶靠泊行为" - } - ], - "total": 89, - "page": 1, - "size": 20 -} -``` - -#### 3.2 告警处理 -``` -PATCH /api/alarms/{alarm_id}/resolve -请求体: -{ - "resolution_notes": "已确认船舶靠泊,无异常", - "resolved_by": "operator1" -} -``` - -#### 3.3 告警统计 -``` -GET /api/alarms/stats -响应: -{ - "total_alarms": 89, - "pending_alarms": 23, - "resolved_alarms": 66, - "by_severity": [ - {"severity": "high", "count": 12}, - {"severity": "medium", "count": 45}, - {"severity": "low", "count": 32} - ] -} -``` - -### 4. 场景管理接口 (优先级:中) - -#### 4.1 场景列表 -``` -GET /api/scenes -响应: -{ - "scenes": [ - { - "id": "scene-001", - "name": "港口区场景", - "description": "港口区监控场景", - "device_count": 45, - "algorithm_count": 3 - } - ] -} -``` - -#### 4.2 场景详情 -``` -GET /api/scenes/{scene_id} -响应: -{ - "id": "scene-001", - "name": "港口区场景", - "description": "港口区监控场景", - "devices": [...], - "algorithms": [...], - "events": [...] -} -``` - -### 5. 文件上传接口 (优先级:中) - -#### 5.1 视频上传 -``` -POST /api/upload/video -Content-Type: multipart/form-data -请求体: -- file: 视频文件 -- device_id: 设备ID -- description: 描述 -响应: -{ - "file_id": "video_123", - "file_url": "/uploads/videos/video_123.mp4", - "file_size": 1024000, - "duration": 30.5 -} -``` - -#### 5.2 图片上传 -``` -POST /api/upload/image -Content-Type: multipart/form-data -请求体: -- file: 图片文件 -- event_id: 事件ID -响应: -{ - "file_id": "image_456", - "file_url": "/uploads/images/image_456.jpg", - "file_size": 256000 -} -``` - -### 6. 用户认证接口 (优先级:低) - -#### 6.1 用户登录 -``` -POST /api/auth/login -请求体: -{ - "username": "admin", - "password": "password123" -} -响应: -{ - "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", - "token_type": "bearer", - "expires_in": 3600, - "user": { - "id": 1, - "username": "admin", - "role": "admin" - } -} -``` - -#### 6.2 用户信息 -``` -GET /api/auth/profile -响应: -{ - "id": 1, - "username": "admin", - "email": "admin@example.com", - "role": "admin", - "permissions": ["read", "write", "admin"] -} -``` - -## 开发进度 - -### ✅ 已完成 (第一阶段) -1. ✅ 仪表板统计接口 (KPI、告警趋势、摄像头统计、算法统计、事件热点) -2. ✅ 监控管理接口 (监控列表、监控详情) - -### ✅ 已完成 (第二阶段) -1. ✅ 告警管理接口 (告警列表、告警处理、告警统计) -2. ✅ 场景管理接口 (场景列表、场景详情) - -### ✅ 已完成 (第三阶段) -1. ✅ 文件上传接口 (视频上传、图片上传) -2. ✅ 用户认证接口 (登录、用户信息) - -### 🔄 待优化功能 -1. 实现真实的告警趋势统计 (当前使用模拟数据) -2. 实现真实的检测数据获取 (当前使用模拟数据) -3. 实现真实的视频流URL生成 -4. 实现真实的JWT验证中间件 -5. 实现真实的场景管理数据库模型 -6. 实现真实的文件删除逻辑 -7. 添加Redis缓存支持 -8. 添加WebSocket实时数据推送 - -## 技术实现要点 - -1. **数据库模型扩展**: - - 添加统计相关的视图或缓存表 - - 优化查询性能,添加索引 - -2. **缓存策略**: - - 使用Redis缓存统计数据 - - 设置合理的缓存过期时间 - -3. **文件存储**: - - 配置静态文件服务 - - 实现文件上传和存储逻辑 - -4. **权限控制**: - - 实现JWT认证 - - 添加角色和权限控制 - -5. **WebSocket支持**: - - 实时监控数据推送 - - 告警实时通知 - -## 测试计划 - -1. **单元测试**:每个接口的CRUD操作 -2. **集成测试**:前后端联调 -3. **性能测试**:大数据量下的响应时间 -4. **安全测试**:认证和权限验证 - -## 部署计划 - -1. **开发环境**:本地测试 -2. **测试环境**:功能验证 -3. **生产环境**:正式部署 - -## 注意事项 - -1. 所有接口需要添加错误处理和日志记录 -2. 敏感数据需要加密存储 -3. 文件上传需要限制文件大小和类型 -4. 统计数据需要定期更新,避免过期数据 -5. 接口文档需要及时更新 +* [ ] \ No newline at end of file