mi-task/test/task-path-track/yellow_track_demo.py
2025-05-14 12:42:01 +08:00

178 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import os
import sys
import time
import argparse
# 添加父目录到系统路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
sys.path.append(project_root)
from utils.detect_track import detect_yellow_track, visualize_track_detection, detect_horizontal_track_edge, visualize_horizontal_track_edge
def process_image(image_path, save_dir=None, show_steps=False):
"""处理单张图像"""
print(f"处理图像: {image_path}")
# 检测赛道并估算距离
start_time = time.time()
edge_point, edge_info = detect_horizontal_track_edge(image_path, observe=show_steps)
processing_time = time.time() - start_time
# 输出结果
if edge_point is not None and edge_info is not None:
print(f"处理时间: {processing_time:.3f}")
print(f"边缘点: ({edge_point[0]}, {edge_point[1]})")
print(f"到中线距离: {edge_info['distance_to_center']}像素")
print(f"边缘斜率: {edge_info['slope']:.4f}")
print(f"是否水平: {edge_info['is_horizontal']}")
print(f"点数量: {edge_info['points_count']}")
print(f"中线交点: ({edge_info['intersection_point'][0]}, {edge_info['intersection_point'][1]})")
print(f"交点到底部距离: {edge_info['distance_to_bottom']:.1f}像素")
print(f"注意: 中线交点是垂直中线与边缘横线的交点")
print("-" * 30)
# 如果指定了保存目录,保存结果
if save_dir:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
base_name = os.path.basename(image_path)
save_path = os.path.join(save_dir, f"result_{base_name}")
# 可视化并保存
visualize_horizontal_track_edge(image_path, save_path=save_path, observe=show_steps)
print(f"结果已保存到: {save_path}")
else:
print("未能检测到黄色赛道")
return edge_point, edge_info
def process_video(video_path, save_path=None, show_output=True):
"""处理视频"""
print(f"处理视频: {video_path}")
# 打开视频文件
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("无法打开视频文件")
return
# 获取视频基本信息
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"视频信息: {width}x{height}, {fps} FPS, 总帧数: {total_frames}")
# 如果需要保存创建VideoWriter
if save_path:
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(save_path, fourcc, fps, (width, height))
# 处理计数器
frame_count = 0
processed_count = 0
start_time = time.time()
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# 每5帧处理一次以提高性能
if frame_count % 5 == 0:
processed_count += 1
# 检测赛道
edge_point, edge_info = detect_horizontal_track_edge(frame, observe=False)
# 创建结果图像使用visualize_horizontal_track_edge函数
if edge_point is not None and edge_info is not None:
result_frame = visualize_horizontal_track_edge(frame, observe=False)
else:
# 如果未检测到赛道,显示原始帧并添加警告
result_frame = frame.copy()
cv2.putText(result_frame, "未检测到横向赛道", (width//4, height//2),
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
# 添加帧计数
cv2.putText(result_frame, f"帧: {frame_count}/{total_frames}",
(width - 200, height - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# 保存或显示结果
if save_path:
out.write(result_frame)
if show_output:
# 将交点信息显示在窗口标题上
if edge_point is not None and edge_info is not None:
intersection_point = edge_info['intersection_point']
distance_to_bottom = edge_info['distance_to_bottom']
title = f'赛道检测 - 交点:({intersection_point[0]},{intersection_point[1]}) 底距:{distance_to_bottom:.1f}px'
else:
title = '赛道检测 - 未检测到交点'
cv2.imshow(title, result_frame)
key = cv2.waitKey(1) & 0xFF
if key == 27: # ESC键退出
break
# 每秒更新一次处理进度
if frame_count % int(fps) == 0:
elapsed = time.time() - start_time
percent = frame_count / total_frames * 100
print(f"进度: {percent:.1f}% ({frame_count}/{total_frames}), 已用时间: {elapsed:.1f}")
# 清理
cap.release()
if save_path:
out.release()
cv2.destroyAllWindows()
# 输出统计信息
total_time = time.time() - start_time
print(f"视频处理完成,总时间: {total_time:.2f}")
print(f"实际处理帧数: {processed_count}/{frame_count}")
if processed_count > 0:
print(f"平均每帧处理时间: {total_time/processed_count:.3f}")
def main():
parser = argparse.ArgumentParser(description='黄色赛道检测演示程序')
parser.add_argument('--input', type=str, default='res/path/image_20250513_162556.png', help='输入图像或视频的路径')
parser.add_argument('--output', type=str, default='res/path/test/result_image_20250513_162556.png', help='输出结果的保存路径')
parser.add_argument('--type', type=str, choices=['image', 'video'], help='输入类型,不指定会自动检测')
parser.add_argument('--show', default=True, action='store_true', help='显示处理步骤')
args = parser.parse_args()
# 检查输入路径
if not os.path.exists(args.input):
print(f"错误:文件 '{args.input}' 不存在")
return
# 如果未指定类型,根据文件扩展名判断
if args.type is None:
ext = os.path.splitext(args.input)[1].lower()
if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
args.type = 'image'
elif ext in ['.mp4', '.avi', '.mov']:
args.type = 'video'
else:
print(f"错误:无法确定文件类型 '{ext}'")
return
# 根据类型处理
if args.type == 'image':
# 获取输出目录
output_dir = os.path.dirname(args.output)
process_image(args.input, output_dir, args.show)
else: # video
process_video(args.input, args.output, args.show)
if __name__ == "__main__":
main()