mi-task/test/task-path-track/yellow_track_demo.py

178 lines
7.0 KiB
Python
Raw Normal View History

2025-05-14 12:42:01 +08:00
import cv2
import os
import sys
import time
import argparse
# 添加父目录到系统路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(project_root)
from utils.detect_track import detect_yellow_track, visualize_track_detection, detect_horizontal_track_edge, visualize_horizontal_track_edge
def process_image(image_path, save_dir=None, show_steps=False):
"""处理单张图像"""
print(f"处理图像: {image_path}")
# 检测赛道并估算距离
start_time = time.time()
edge_point, edge_info = detect_horizontal_track_edge(image_path, observe=show_steps)
processing_time = time.time() - start_time
# 输出结果
if edge_point is not None and edge_info is not None:
print(f"处理时间: {processing_time:.3f}")
print(f"边缘点: ({edge_point[0]}, {edge_point[1]})")
print(f"到中线距离: {edge_info['distance_to_center']}像素")
print(f"边缘斜率: {edge_info['slope']:.4f}")
print(f"是否水平: {edge_info['is_horizontal']}")
print(f"点数量: {edge_info['points_count']}")
print(f"中线交点: ({edge_info['intersection_point'][0]}, {edge_info['intersection_point'][1]})")
print(f"交点到底部距离: {edge_info['distance_to_bottom']:.1f}像素")
print(f"注意: 中线交点是垂直中线与边缘横线的交点")
print("-" * 30)
# 如果指定了保存目录,保存结果
if save_dir:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
base_name = os.path.basename(image_path)
save_path = os.path.join(save_dir, f"result_{base_name}")
# 可视化并保存
visualize_horizontal_track_edge(image_path, save_path=save_path, observe=show_steps)
print(f"结果已保存到: {save_path}")
else:
print("未能检测到黄色赛道")
return edge_point, edge_info
def process_video(video_path, save_path=None, show_output=True):
"""处理视频"""
print(f"处理视频: {video_path}")
# 打开视频文件
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("无法打开视频文件")
return
# 获取视频基本信息
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"视频信息: {width}x{height}, {fps} FPS, 总帧数: {total_frames}")
# 如果需要保存创建VideoWriter
if save_path:
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(save_path, fourcc, fps, (width, height))
# 处理计数器
frame_count = 0
processed_count = 0
start_time = time.time()
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 每5帧处理一次以提高性能
if frame_count % 5 == 0:
processed_count += 1
# 检测赛道
edge_point, edge_info = detect_horizontal_track_edge(frame, observe=False)
# 创建结果图像使用visualize_horizontal_track_edge函数
if edge_point is not None and edge_info is not None:
result_frame = visualize_horizontal_track_edge(frame, observe=False)
else:
# 如果未检测到赛道,显示原始帧并添加警告
result_frame = frame.copy()
cv2.putText(result_frame, "未检测到横向赛道", (width//4, height//2),
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
# 添加帧计数
cv2.putText(result_frame, f"帧: {frame_count}/{total_frames}",
(width - 200, height - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# 保存或显示结果
if save_path:
out.write(result_frame)
if show_output:
# 将交点信息显示在窗口标题上
if edge_point is not None and edge_info is not None:
intersection_point = edge_info['intersection_point']
distance_to_bottom = edge_info['distance_to_bottom']
title = f'赛道检测 - 交点:({intersection_point[0]},{intersection_point[1]}) 底距:{distance_to_bottom:.1f}px'
else:
title = '赛道检测 - 未检测到交点'
cv2.imshow(title, result_frame)
key = cv2.waitKey(1) & 0xFF
if key == 27: # ESC键退出
break
# 每秒更新一次处理进度
if frame_count % int(fps) == 0:
elapsed = time.time() - start_time
percent = frame_count / total_frames * 100
print(f"进度: {percent:.1f}% ({frame_count}/{total_frames}), 已用时间: {elapsed:.1f}")
# 清理
cap.release()
if save_path:
out.release()
cv2.destroyAllWindows()
# 输出统计信息
total_time = time.time() - start_time
print(f"视频处理完成,总时间: {total_time:.2f}")
print(f"实际处理帧数: {processed_count}/{frame_count}")
if processed_count > 0:
print(f"平均每帧处理时间: {total_time/processed_count:.3f}")
def main():
parser = argparse.ArgumentParser(description='黄色赛道检测演示程序')
parser.add_argument('--input', type=str, default='res/path/image_20250513_162556.png', help='输入图像或视频的路径')
parser.add_argument('--output', type=str, default='res/path/test/result_image_20250513_162556.png', help='输出结果的保存路径')
parser.add_argument('--type', type=str, choices=['image', 'video'], help='输入类型,不指定会自动检测')
parser.add_argument('--show', default=True, action='store_true', help='显示处理步骤')
args = parser.parse_args()
# 检查输入路径
if not os.path.exists(args.input):
print(f"错误:文件 '{args.input}' 不存在")
return
# 如果未指定类型,根据文件扩展名判断
if args.type is None:
ext = os.path.splitext(args.input)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
args.type = 'image'
elif ext in ['.mp4', '.avi', '.mov']:
args.type = 'video'
else:
print(f"错误:无法确定文件类型 '{ext}'")
return
# 根据类型处理
if args.type == 'image':
# 获取输出目录
output_dir = os.path.dirname(args.output)
process_image(args.input, output_dir, args.show)
else: # video
process_video(args.input, args.output, args.show)
if __name__ == "__main__":
main()