mi-task/test/test_detect_track.py

159 lines
5.6 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import os
import cv2
import numpy as np
import time
import matplotlib.pyplot as plt
import argparse
from matplotlib.patches import Rectangle
# 添加项目根目录到 Python 路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)
# 导入待测试模块
from utils.detect_track import detect_dual_track_lines
def display_image_with_results(image, center_info, left_info, right_info, save_path=None):
"""
显示检测结果并标记轨道线
"""
# 创建图像副本
result_img = image.copy()
height, width = image.shape[:2]
center_x = width // 2
# 如果检测到轨道线
if center_info is not None and left_info is not None and right_info is not None:
# 获取线段坐标
left_x1, left_y1, left_x2, left_y2 = left_info["line"]
right_x1, right_y1, right_x2, right_y2 = right_info["line"]
# 计算中心线坐标
center_line_x1 = (left_x1 + right_x1) // 2
center_line_y1 = (left_y1 + right_y1) // 2
center_line_x2 = (left_x2 + right_x2) // 2
center_line_y2 = (left_y2 + right_y2) // 2
# 绘制左右轨道线
cv2.line(result_img, (left_x1, left_y1), (left_x2, left_y2), (255, 0, 0), 3) # 蓝色
cv2.line(result_img, (right_x1, right_y1), (right_x2, right_y2), (0, 0, 255), 3) # 红色
# 绘制中心线
cv2.line(result_img, (center_line_x1, center_line_y1), (center_line_x2, center_line_y2), (0, 255, 0), 3) # 绿色
# 绘制中线与图像底部的交点
bottom_x = center_info["point"][0]
cv2.circle(result_img, (bottom_x, height), 12, (255, 0, 255), -1) # 紫色
# 绘制图像中心线
cv2.line(result_img, (center_x, 0), (center_x, height), (0, 0, 255), 1) # 红色
# 计算并显示偏差
deviation = center_info["deviation"]
track_width = center_info["track_width"]
# 添加文本信息
cv2.putText(result_img, f"偏差: {deviation}px", (20, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
cv2.putText(result_img, f"轨道宽度: {track_width:.1f}px", (20, 70),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
# 添加方向指示
if deviation > 0:
direction = "向右偏"
elif deviation < 0:
direction = "向左偏"
else:
direction = "居中"
cv2.putText(result_img, f"方向: {direction}", (20, 110),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
else:
# 如果未检测到轨道线,显示错误信息
cv2.putText(result_img, "未检测到双轨道线", (width // 3, height // 2),
cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3)
# 如果指定了保存路径,保存图像
if save_path:
cv2.imwrite(save_path, result_img)
print(f"结果已保存至: {save_path}")
return result_img
def process_image(image_path, save_dir=None, show_steps=False):
"""
处理单张图像,检测双轨道线
"""
print(f"\n正在处理图像: {image_path}")
# 读取图像
try:
image = cv2.imread(image_path)
if image is None:
print(f"无法读取图像: {image_path}")
return None, None, None
except Exception as e:
print(f"读取图像出错: {e}")
return None, None, None
# 获取图像尺寸
height, width = image.shape[:2]
# 测试双轨道线检测
print("\n执行双轨道线检测...")
start_time = time.time()
center_info, left_info, right_info = detect_dual_track_lines(image, observe=show_steps, save_log=True)
processing_time = time.time() - start_time
print(f"处理时间: {processing_time:.4f}")
# 输出结果
if center_info is not None:
print(f"检测成功 - 中心偏差: {center_info['deviation']}像素")
print(f"轨道宽度: {center_info['track_width']:.1f}像素")
print(f"中心线斜率: {center_info['slope']:.4f}")
print(f"是否垂直: {'' if center_info['is_vertical'] else ''}")
print("-" * 30)
else:
print("未检测到双轨道线")
# 如果指定了保存目录,保存结果
if save_dir and center_info is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 生成结果文件名
base_name = os.path.basename(image_path)
result_name = f"result_{base_name}"
result_path = os.path.join(save_dir, result_name)
# 显示并保存结果
result_img = display_image_with_results(image, center_info, left_info, right_info, result_path)
return center_info, left_info, right_info
def main():
parser = argparse.ArgumentParser(description='双轨道线检测演示程序')
parser.add_argument('--input', type=str, default='res/path/image_20250514_024347.png', help='输入图像路径')
parser.add_argument('--output', type=str, default='res/path/test', help='输出结果的保存目录')
parser.add_argument('--show', default=True, action='store_true', help='显示处理步骤')
args = parser.parse_args()
# 检查输入路径
if not os.path.exists(args.input):
print(f"错误:文件 '{args.input}' 不存在")
return
# 处理图像
process_image(args.input, args.output, args.show)
print("\n处理完成!")
if __name__ == "__main__":
main()