178 lines
7.0 KiB
Python
178 lines
7.0 KiB
Python
import cv2
|
||
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
|
||
# 添加父目录到系统路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
project_root = os.path.dirname(os.path.dirname(current_dir))
|
||
sys.path.append(project_root)
|
||
|
||
from utils.detect_track import detect_yellow_track, visualize_track_detection, detect_horizontal_track_edge, visualize_horizontal_track_edge
|
||
|
||
def process_image(image_path, save_dir=None, show_steps=False):
|
||
"""处理单张图像"""
|
||
print(f"处理图像: {image_path}")
|
||
|
||
# 检测赛道并估算距离
|
||
start_time = time.time()
|
||
edge_point, edge_info = detect_horizontal_track_edge(image_path, observe=show_steps)
|
||
processing_time = time.time() - start_time
|
||
|
||
# 输出结果
|
||
if edge_point is not None and edge_info is not None:
|
||
print(f"处理时间: {processing_time:.3f}秒")
|
||
print(f"边缘点: ({edge_point[0]}, {edge_point[1]})")
|
||
print(f"到中线距离: {edge_info['distance_to_center']}像素")
|
||
print(f"边缘斜率: {edge_info['slope']:.4f}")
|
||
print(f"是否水平: {edge_info['is_horizontal']}")
|
||
print(f"点数量: {edge_info['points_count']}")
|
||
print(f"中线交点: ({edge_info['intersection_point'][0]}, {edge_info['intersection_point'][1]})")
|
||
print(f"交点到底部距离: {edge_info['distance_to_bottom']:.1f}像素")
|
||
print(f"注意: 中线交点是垂直中线与边缘横线的交点")
|
||
print("-" * 30)
|
||
|
||
# 如果指定了保存目录,保存结果
|
||
if save_dir:
|
||
if not os.path.exists(save_dir):
|
||
os.makedirs(save_dir)
|
||
|
||
base_name = os.path.basename(image_path)
|
||
save_path = os.path.join(save_dir, f"result_{base_name}")
|
||
|
||
# 可视化并保存
|
||
visualize_horizontal_track_edge(image_path, save_path=save_path, observe=show_steps)
|
||
print(f"结果已保存到: {save_path}")
|
||
else:
|
||
print("未能检测到黄色赛道")
|
||
|
||
return edge_point, edge_info
|
||
|
||
def process_video(video_path, save_path=None, show_output=True):
|
||
"""处理视频"""
|
||
print(f"处理视频: {video_path}")
|
||
|
||
# 打开视频文件
|
||
cap = cv2.VideoCapture(video_path)
|
||
if not cap.isOpened():
|
||
print("无法打开视频文件")
|
||
return
|
||
|
||
# 获取视频基本信息
|
||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
|
||
print(f"视频信息: {width}x{height}, {fps} FPS, 总帧数: {total_frames}")
|
||
|
||
# 如果需要保存,创建VideoWriter
|
||
if save_path:
|
||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||
out = cv2.VideoWriter(save_path, fourcc, fps, (width, height))
|
||
|
||
# 处理计数器
|
||
frame_count = 0
|
||
processed_count = 0
|
||
start_time = time.time()
|
||
|
||
while True:
|
||
ret, frame = cap.read()
|
||
if not ret:
|
||
break
|
||
|
||
frame_count += 1
|
||
|
||
# 每5帧处理一次以提高性能
|
||
if frame_count % 5 == 0:
|
||
processed_count += 1
|
||
|
||
# 检测赛道
|
||
edge_point, edge_info = detect_horizontal_track_edge(frame, observe=False)
|
||
|
||
# 创建结果图像(使用visualize_horizontal_track_edge函数)
|
||
if edge_point is not None and edge_info is not None:
|
||
result_frame = visualize_horizontal_track_edge(frame, observe=False)
|
||
else:
|
||
# 如果未检测到赛道,显示原始帧并添加警告
|
||
result_frame = frame.copy()
|
||
cv2.putText(result_frame, "未检测到横向赛道", (width//4, height//2),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
|
||
|
||
# 添加帧计数
|
||
cv2.putText(result_frame, f"帧: {frame_count}/{total_frames}",
|
||
(width - 200, height - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||
|
||
# 保存或显示结果
|
||
if save_path:
|
||
out.write(result_frame)
|
||
|
||
if show_output:
|
||
# 将交点信息显示在窗口标题上
|
||
if edge_point is not None and edge_info is not None:
|
||
intersection_point = edge_info['intersection_point']
|
||
distance_to_bottom = edge_info['distance_to_bottom']
|
||
title = f'赛道检测 - 交点:({intersection_point[0]},{intersection_point[1]}) 底距:{distance_to_bottom:.1f}px'
|
||
else:
|
||
title = '赛道检测 - 未检测到交点'
|
||
|
||
cv2.imshow(title, result_frame)
|
||
key = cv2.waitKey(1) & 0xFF
|
||
if key == 27: # ESC键退出
|
||
break
|
||
|
||
# 每秒更新一次处理进度
|
||
if frame_count % int(fps) == 0:
|
||
elapsed = time.time() - start_time
|
||
percent = frame_count / total_frames * 100
|
||
print(f"进度: {percent:.1f}% ({frame_count}/{total_frames}), 已用时间: {elapsed:.1f}秒")
|
||
|
||
# 清理
|
||
cap.release()
|
||
if save_path:
|
||
out.release()
|
||
cv2.destroyAllWindows()
|
||
|
||
# 输出统计信息
|
||
total_time = time.time() - start_time
|
||
print(f"视频处理完成,总时间: {total_time:.2f}秒")
|
||
print(f"实际处理帧数: {processed_count}/{frame_count}")
|
||
if processed_count > 0:
|
||
print(f"平均每帧处理时间: {total_time/processed_count:.3f}秒")
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description='黄色赛道检测演示程序')
|
||
parser.add_argument('--input', type=str, default='res/path/image_20250513_162556.png', help='输入图像或视频的路径')
|
||
parser.add_argument('--output', type=str, default='res/path/test/result_image_20250513_162556.png', help='输出结果的保存路径')
|
||
parser.add_argument('--type', type=str, choices=['image', 'video'], help='输入类型,不指定会自动检测')
|
||
parser.add_argument('--show', default=True, action='store_true', help='显示处理步骤')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 检查输入路径
|
||
if not os.path.exists(args.input):
|
||
print(f"错误:文件 '{args.input}' 不存在")
|
||
return
|
||
|
||
# 如果未指定类型,根据文件扩展名判断
|
||
if args.type is None:
|
||
ext = os.path.splitext(args.input)[1].lower()
|
||
if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
|
||
args.type = 'image'
|
||
elif ext in ['.mp4', '.avi', '.mov']:
|
||
args.type = 'video'
|
||
else:
|
||
print(f"错误:无法确定文件类型 '{ext}'")
|
||
return
|
||
|
||
# 根据类型处理
|
||
if args.type == 'image':
|
||
# 获取输出目录
|
||
output_dir = os.path.dirname(args.output)
|
||
process_image(args.input, output_dir, args.show)
|
||
else: # video
|
||
process_video(args.input, args.output, args.show)
|
||
|
||
if __name__ == "__main__":
|
||
main() |