544 lines
18 KiB
Python
544 lines
18 KiB
Python
![]() |
import cv2
|
||
|
import supervision as sv
|
||
|
from ultralytics import YOLO
|
||
|
import os
|
||
|
import time
|
||
|
import numpy as np
|
||
|
from tqdm import tqdm
|
||
|
from PIL import Image, ImageDraw, ImageFont
|
||
|
import io
|
||
|
from typing import List, Tuple, Optional, Dict, Any
|
||
|
import json
|
||
|
|
||
|
|
||
|
class DetectionProcessor:
    """Object-detection processor built on an Ultralytics YOLO model.

    Runs detection on videos or single images, draws boxes and labels with
    Supervision annotators and, optionally, reports detections whose box
    centre lies inside user-supplied polygonal warning zones.
    """

    # NOTE(review): Ultralytics appears to publish its YOLO11 weights as
    # "yolo11n.pt" (without the "v"); confirm that "yolov11n.pt" resolves in
    # the deployment environment. The default is kept for compatibility.
    def __init__(self, model_path: str = "yolov11n.pt"):
        """Store the model path and load the model immediately.

        Args:
            model_path: Path (or name) of the YOLO weights handed to ``YOLO``.
        """
        self.model_path = model_path
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the YOLO model from ``self.model_path`` into ``self.model``.

        Raises:
            Exception: Any error raised by ``YOLO(...)``; it is printed and
                re-raised unchanged.
        """
        try:
            print(f"加载检测模型: {self.model_path}")
            self.model = YOLO(self.model_path)
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    def _put_chinese_text(self, img, text, position, font_size=20, color=(0, 0, 255)):
        """Draw ``text`` on a BGR image through Pillow and return a new image.

        The frame is converted to RGB, drawn on with a TrueType font, and
        converted back to BGR.

        Args:
            img: BGR image array as used by OpenCV.
            text: Text to render.
            position: ``(x, y)`` position passed to ``ImageDraw.text``.
            font_size: Font size used when loading ``msyh.ttc``.
            color: Fill colour. It is applied to the RGB-converted image, so
                the tuple is interpreted in RGB order.

        Returns:
            A new BGR image array containing the rendered text.
        """
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)

        try:
            font = ImageFont.truetype("msyh.ttc", font_size)
        except OSError:
            # The font file is missing or unreadable: fall back to Pillow's
            # built-in font instead of swallowing every possible exception.
            # NOTE(review): the built-in font may not cover CJK glyphs on
            # every Pillow version - confirm on non-Windows hosts.
            font = ImageFont.load_default()

        draw.text(position, text, font=font, fill=color)
        return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _open_writer(output_video_path, video_info):
        """Create the output directory (if any) and open an mp4v video writer.

        Args:
            output_video_path: Destination file for the annotated video.
            video_info: Object exposing ``fps``, ``width`` and ``height``.

        Returns:
            An opened ``cv2.VideoWriter``.

        Raises:
            RuntimeError: If the writer could not be opened.
        """
        output_dir = os.path.dirname(output_video_path)
        if output_dir:
            # exist_ok avoids the check-then-create race of the old code.
            os.makedirs(output_dir, exist_ok=True)

        writer = cv2.VideoWriter(
            output_video_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            video_info.fps,
            (video_info.width, video_info.height),
        )
        if not writer.isOpened():
            writer.release()
            raise RuntimeError(f"无法创建输出视频文件: {output_video_path}")
        return writer

    @staticmethod
    def _detections_to_records(detections, names):
        """Convert detections into plain-Python dictionaries.

        Args:
            detections: Object exposing parallel ``xyxy``, ``class_id`` and
                ``confidence`` sequences.
            names: Mapping from integer class id to class name.

        Returns:
            A list of dicts with keys ``bbox``, ``class_id``, ``class_name``
            and ``confidence``.
        """
        return [
            {
                "bbox": bbox.tolist(),
                "class_id": int(class_id),
                "class_name": names[int(class_id)],
                "confidence": float(confidence),
            }
            for bbox, class_id, confidence in zip(
                detections.xyxy, detections.class_id, detections.confidence
            )
        ]

    def _annotate_warning_zones(self, annotated_frame, detections, warning_zones):
        """Draw the warning zones and collect detections centred inside them.

        The first zone is outlined in (0, 255, 0), every other zone in
        (0, 0, 255) (BGR). A detection counts as inside a zone when the
        centre of its box is inside or on the edge of the polygon.

        Args:
            annotated_frame: BGR frame to draw on.
            detections: Object exposing an ``xyxy`` sequence of boxes.
            warning_zones: List of polygons, each a list of ``(x, y)`` points.

        Returns:
            ``(frame, hits)`` where ``hits`` is a list of
            ``{"zone": 1-based zone index, "position": (cx, cy)}`` dicts,
            ordered by zone and then by detection.
        """
        hits = []
        for i, zone in enumerate(warning_zones):
            # OpenCV drawing functions need int32 points; a bare np.array()
            # of Python ints is int64 on most non-Windows platforms.
            polygon = np.array([(int(x), int(y)) for x, y in zone], np.int32)
            color = (0, 255, 0) if i == 0 else (0, 0, 255)
            cv2.polylines(annotated_frame, [polygon], isClosed=True, color=color, thickness=2)

            centers_inside = []
            for x1, y1, x2, y2 in detections.xyxy:
                center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
                if cv2.pointPolygonTest(polygon, center, False) >= 0:
                    centers_inside.append(center)

            if centers_inside:
                # The banner is identical for every hit in a zone, so it is
                # rendered once per zone rather than once per detection.
                warning_text = f"警告: 目标进入区域{i+1}!"
                annotated_frame = self._put_chinese_text(
                    annotated_frame, warning_text, (50, 50 + i * 30), font_size=20, color=(255, 0, 0)
                )
            hits.extend({"zone": i + 1, "position": center} for center in centers_inside)

        return annotated_frame, hits

    def _detect_and_annotate(
        self, frame, confidence_threshold, classes, warning_zones, box_annotator, label_annotator
    ):
        """Run the model on one frame and produce its annotations and records.

        Args:
            frame: BGR frame to analyse; it is copied before drawing.
            confidence_threshold: Passed to the model as ``conf``.
            classes: Passed to the model as ``classes``.
            warning_zones: Optional list of warning polygons.
            box_annotator: Annotator used to draw the bounding boxes.
            label_annotator: Annotator used to draw the labels.

        Returns:
            ``(annotated_frame, records, zone_hits)``.
        """
        results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)

        labels = [
            f"{results.names[class_id]} {confidence:.2f}"
            for class_id, confidence in zip(detections.class_id, detections.confidence)
        ]
        annotated_frame = box_annotator.annotate(scene=frame.copy(), detections=detections)
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=detections, labels=labels
        )

        zone_hits = []
        if warning_zones:
            annotated_frame, zone_hits = self._annotate_warning_zones(
                annotated_frame, detections, warning_zones
            )

        records = self._detections_to_records(detections, results.names)
        return annotated_frame, records, zone_hits

    def process_video(
        self,
        input_video_path: str,
        output_video_path: str,
        confidence_threshold: float = 0.5,
        classes: Optional[List[int]] = None,
        show_live: bool = False,
        save_annotated: bool = True,
        warning_zones: Optional[List[List[Tuple[float, float]]]] = None,
    ) -> Dict[str, Any]:
        """Run object detection on every frame of a video.

        Args:
            input_video_path: Path of the video to read.
            output_video_path: Path of the annotated video to write.
            confidence_threshold: Minimum confidence passed to the model.
            classes: Class ids to detect, or ``None`` for the model default.
            show_live: Show each annotated frame in a window; ESC stops early.
            save_annotated: Write the annotated frames to ``output_video_path``.
            warning_zones: Optional polygons (lists of ``(x, y)`` points).

        Returns:
            Dict with keys ``success``, ``total_frames``,
            ``avg_processing_time``, ``fps``, ``detection_results``,
            ``warning_events`` and ``output_video_path`` (``None`` when
            ``save_annotated`` is false).

        Raises:
            FileNotFoundError: If ``input_video_path`` does not exist.
            RuntimeError: If the input video or the output writer cannot be
                opened.
        """
        if not os.path.exists(input_video_path):
            raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}")

        video_info = sv.VideoInfo.from_video_path(input_video_path)
        cap = cv2.VideoCapture(input_video_path)

        writer = None
        pbar = None
        frame_count = 0
        processing_times = []
        detection_results = []
        warning_events = []

        # Everything that can raise runs inside try/finally so the capture,
        # writer, progress bar and preview window are always released.
        try:
            if not cap.isOpened():
                raise RuntimeError(f"无法打开视频文件: {input_video_path}")

            if save_annotated:
                writer = self._open_writer(output_video_path, video_info)

            box_annotator = sv.BoxAnnotator()
            label_annotator = sv.LabelAnnotator()

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            pbar = tqdm(total=total_frames, desc="处理视频帧")

            while cap.isOpened():
                start_time = time.time()

                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                annotated_frame, frame_detections, zone_hits = self._detect_and_annotate(
                    frame, confidence_threshold, classes, warning_zones, box_annotator, label_annotator
                )
                warning_events.extend({"frame": frame_count, **hit} for hit in zone_hits)
                detection_results.append({"frame": frame_count, "detections": frame_detections})

                # Timing covers reading, inference and annotation only.
                processing_times.append(time.time() - start_time)

                if show_live:
                    cv2.imshow("Detection", annotated_frame)
                    # Mask to the low byte: some backends set higher bits.
                    if cv2.waitKey(1) & 0xFF == 27:
                        break

                if writer is not None:
                    writer.write(annotated_frame)

                pbar.update(1)
        finally:
            cap.release()
            if writer is not None:
                writer.release()
            if show_live:
                cv2.destroyAllWindows()
            if pbar is not None:
                pbar.close()

        avg_time = sum(processing_times) / len(processing_times) if processing_times else 0
        fps = 1 / avg_time if avg_time > 0 else 0

        return {
            "success": True,
            "total_frames": frame_count,
            "avg_processing_time": avg_time,
            "fps": fps,
            "detection_results": detection_results,
            "warning_events": warning_events,
            "output_video_path": output_video_path if save_annotated else None,
        }

    def process_image(
        self,
        image_path: str,
        confidence_threshold: float = 0.5,
        classes: Optional[List[int]] = None,
        warning_zones: Optional[List[List[Tuple[float, float]]]] = None,
    ) -> Dict[str, Any]:
        """Run object detection on a single image.

        Args:
            image_path: Path of the image to read.
            confidence_threshold: Minimum confidence passed to the model.
            classes: Class ids to detect, or ``None`` for the model default.
            warning_zones: Optional polygons (lists of ``(x, y)`` points).

        Returns:
            Dict with keys ``success``, ``detection_results``,
            ``warning_events`` and ``annotated_image`` (the annotated frame).

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If the image cannot be decoded.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像文件不存在: {image_path}")

        frame = cv2.imread(image_path)
        if frame is None:
            raise RuntimeError(f"无法读取图像文件: {image_path}")

        annotated_frame, detection_results, warning_events = self._detect_and_annotate(
            frame, confidence_threshold, classes, warning_zones, sv.BoxAnnotator(), sv.LabelAnnotator()
        )

        return {
            "success": True,
            "detection_results": detection_results,
            "warning_events": warning_events,
            "annotated_image": annotated_frame,
        }
|
||
|
|
||
|
|
||
|
class SegmentationProcessor:
    """Image-segmentation processor built on an Ultralytics YOLO model.

    Runs segmentation on videos or single images and overlays the results
    with a Supervision mask annotator.
    """

    def __init__(self, model_path: str = "yolov8n-seg.pt"):
        """Store the model path and load the model immediately.

        Args:
            model_path: Path (or name) of the YOLO segmentation weights.
        """
        self.model_path = model_path
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the YOLO model from ``self.model_path`` into ``self.model``.

        Raises:
            Exception: Any error raised by ``YOLO(...)``; it is printed and
                re-raised unchanged.
        """
        try:
            print(f"加载分割模型: {self.model_path}")
            self.model = YOLO(self.model_path)
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def _open_writer(output_video_path, video_info):
        """Create the output directory (if any) and open an mp4v video writer.

        Args:
            output_video_path: Destination file for the annotated video.
            video_info: Object exposing ``fps``, ``width`` and ``height``.

        Returns:
            An opened ``cv2.VideoWriter``.

        Raises:
            RuntimeError: If the writer could not be opened.
        """
        output_dir = os.path.dirname(output_video_path)
        if output_dir:
            # exist_ok avoids the check-then-create race of the old code.
            os.makedirs(output_dir, exist_ok=True)

        writer = cv2.VideoWriter(
            output_video_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            video_info.fps,
            (video_info.width, video_info.height),
        )
        if not writer.isOpened():
            writer.release()
            raise RuntimeError(f"无法创建输出视频文件: {output_video_path}")
        return writer

    def _segment_frame(self, frame, confidence_threshold, classes, mask_annotator):
        """Run the model on one frame and overlay the result.

        Args:
            frame: BGR frame to analyse; it is copied before drawing.
            confidence_threshold: Passed to the model as ``conf``.
            classes: Passed to the model as ``classes``.
            mask_annotator: Annotator used to draw onto the copied frame.

        Returns:
            ``(segmented_frame, records)`` where ``records`` is a list of
            dicts with keys ``bbox``, ``class_id``, ``class_name`` and
            ``confidence``.
        """
        results = self.model(frame, conf=confidence_threshold, classes=classes, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        segmented_frame = mask_annotator.annotate(scene=frame.copy(), detections=detections)

        records = [
            {
                "bbox": bbox.tolist(),
                "class_id": int(class_id),
                "class_name": results.names[int(class_id)],
                "confidence": float(confidence),
            }
            for bbox, class_id, confidence in zip(
                detections.xyxy, detections.class_id, detections.confidence
            )
        ]
        return segmented_frame, records

    def process_video(
        self,
        input_video_path: str,
        output_video_path: str,
        confidence_threshold: float = 0.5,
        classes: Optional[List[int]] = None,
        show_live: bool = False,
        save_annotated: bool = True,
    ) -> Dict[str, Any]:
        """Run segmentation on every frame of a video.

        Args:
            input_video_path: Path of the video to read.
            output_video_path: Path of the annotated video to write.
            confidence_threshold: Minimum confidence passed to the model.
            classes: Class ids to detect, or ``None`` for the model default.
            show_live: Show each annotated frame in a window; ESC stops early.
            save_annotated: Write the annotated frames to ``output_video_path``.

        Returns:
            Dict with keys ``success``, ``total_frames``,
            ``avg_processing_time``, ``fps``, ``segmentation_results`` and
            ``output_video_path`` (``None`` when ``save_annotated`` is false).

        Raises:
            FileNotFoundError: If ``input_video_path`` does not exist.
            RuntimeError: If the input video or the output writer cannot be
                opened.
        """
        if not os.path.exists(input_video_path):
            raise FileNotFoundError(f"输入视频文件不存在: {input_video_path}")

        video_info = sv.VideoInfo.from_video_path(input_video_path)
        cap = cv2.VideoCapture(input_video_path)

        writer = None
        pbar = None
        frame_count = 0
        processing_times = []
        segmentation_results = []

        # Everything that can raise runs inside try/finally so the capture,
        # writer, progress bar and preview window are always released.
        try:
            if not cap.isOpened():
                raise RuntimeError(f"无法打开视频文件: {input_video_path}")

            if save_annotated:
                writer = self._open_writer(output_video_path, video_info)

            mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5)

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            pbar = tqdm(total=total_frames, desc="处理视频帧")

            while cap.isOpened():
                start_time = time.time()

                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                segmented_frame, frame_segmentations = self._segment_frame(
                    frame, confidence_threshold, classes, mask_annotator
                )
                segmentation_results.append(
                    {"frame": frame_count, "segmentations": frame_segmentations}
                )

                # Timing covers reading, inference and annotation only.
                processing_times.append(time.time() - start_time)

                if show_live:
                    cv2.imshow("Segmentation", segmented_frame)
                    # Mask to the low byte: some backends set higher bits.
                    if cv2.waitKey(1) & 0xFF == 27:
                        break

                if writer is not None:
                    writer.write(segmented_frame)

                pbar.update(1)
        finally:
            cap.release()
            if writer is not None:
                writer.release()
            if show_live:
                cv2.destroyAllWindows()
            if pbar is not None:
                pbar.close()

        avg_time = sum(processing_times) / len(processing_times) if processing_times else 0
        fps = 1 / avg_time if avg_time > 0 else 0

        return {
            "success": True,
            "total_frames": frame_count,
            "avg_processing_time": avg_time,
            "fps": fps,
            "segmentation_results": segmentation_results,
            "output_video_path": output_video_path if save_annotated else None,
        }

    def process_image(
        self, image_path: str, confidence_threshold: float = 0.5, classes: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """Run segmentation on a single image.

        Args:
            image_path: Path of the image to read.
            confidence_threshold: Minimum confidence passed to the model.
            classes: Class ids to detect, or ``None`` for the model default.

        Returns:
            Dict with keys ``success``, ``segmentation_results`` and
            ``segmented_image`` (the annotated frame).

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If the image cannot be decoded.
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像文件不存在: {image_path}")

        frame = cv2.imread(image_path)
        if frame is None:
            raise RuntimeError(f"无法读取图像文件: {image_path}")

        mask_annotator = sv.MaskAnnotator(color=sv.Color(r=0, g=255, b=0), opacity=0.5)
        segmented_frame, segmentation_results = self._segment_frame(
            frame, confidence_threshold, classes, mask_annotator
        )

        return {"success": True, "segmentation_results": segmentation_results, "segmented_image": segmented_frame}
|
||
|
|
||
|
|
||
|
# Example usage: run detection with warning zones, then segmentation, on the
# same input video and print both result dictionaries.
if __name__ == "__main__":
    # --- Object detection example ---
    detector = DetectionProcessor(model_path="yolov11n.pt")

    # Two quadrilateral warning zones, given as (x, y) points.
    first_zone = [(1153.11, 273.86), (1146.09, 370.77), (1217.71, 352.51), (1220.52, 257.01)]
    second_zone = [(1140.47, 506.99), (1214.90, 504.19), (1212.09, 592.66), (1136.25, 594.07)]
    warning_zones = [first_zone, second_zone]

    detection_options = {
        "input_video_path": "input_video.mp4",
        "output_video_path": "output_detection.mp4",
        "confidence_threshold": 0.5,
        "classes": [0],  # detect people only
        "warning_zones": warning_zones,
        "save_annotated": True,
    }
    result = detector.process_video(**detection_options)
    print("检测结果:", result)

    # --- Image segmentation example ---
    segmenter = SegmentationProcessor(model_path="yolov8n-seg.pt")

    segmentation_options = {
        "input_video_path": "input_video.mp4",
        "output_video_path": "output_segmentation.mp4",
        "confidence_threshold": 0.3,
        "classes": [0],  # detect people only
        "save_annotated": True,
    }
    seg_result = segmenter.process_video(**segmentation_options)
    print("分割结果:", seg_result)
|