mi-task/test/test_track_detection.py

138 lines
4.1 KiB
Python
Raw Normal View History

import cv2
import os
import sys
import numpy as np
# 添加正确的项目路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir))
if os.path.exists(os.path.join(project_root, "task")):
sys.path.append(project_root)
else:
# 如果已经在task目录中
parent_dir = os.path.dirname(project_root)
sys.path.append(parent_dir)
# 创建简化的日志函数以避免依赖
def info(msg, tag="INFO"):
print(f"[{tag}] {msg}")
def debug(msg, tag="DEBUG"):
print(f"[{tag}] {msg}")
def warning(msg, tag="WARNING"):
print(f"[{tag}] {msg}")
def error(msg, tag="ERROR"):
print(f"[{tag}] {msg}")
def success(msg, tag="SUCCESS"):
print(f"[{tag}] {msg}")
def timing(msg, tag="TIMING"):
print(f"[{tag}] {msg}")
def section(msg):
print(f"\n{'='*10} {msg} {'='*10}\n")
# 修补utils.log_helper模块
sys.modules['utils.log_helper'] = type('MockLogHelper', (), {
'LogHelper': type('MockLogHelperClass', (), {}),
'get_logger': lambda: None,
'info': info,
'debug': debug,
'warning': warning,
'error': error,
'success': success,
'timing': timing,
'section': section
})
# 直接导入检测函数
try:
from task.utils.detect_dual_track_lines import detect_dual_track_lines
except ImportError:
# 如果task.utils导入失败尝试直接从utils导入
sys.path.append(os.path.join(project_root, "task"))
from utils.detect_dual_track_lines import detect_dual_track_lines
def test_track_detection(image_path, observe=True, save_log=True):
"""
测试轨道线检测函数
参数:
image_path: 图像路径
observe: 是否显示可视化结果
save_log: 是否保存日志
"""
print(f"测试图像: {os.path.basename(image_path)}")
# 默认模式检测
print("- 默认模式检测:")
center_info, left_track, right_track = detect_dual_track_lines(
image_path,
observe=observe,
delay=0 if observe else 1, # 如果观察模式等待按键否则1ms
save_log=save_log,
stone_path_mode=False
)
if center_info:
print(f" 中心点: {center_info['point']}")
print(f" 偏差: {center_info['deviation']:.2f}")
print(f" 轨道宽度: {center_info['track_width']:.2f}")
else:
print(" 检测失败")
# 石板路模式检测
print("- 石板路模式检测:")
center_info, left_track, right_track = detect_dual_track_lines(
image_path,
observe=observe,
delay=0 if observe else 1,
save_log=save_log,
stone_path_mode=True
)
if center_info:
print(f" 中心点: {center_info['point']}")
print(f" 偏差: {center_info['deviation']:.2f}")
print(f" 轨道宽度: {center_info['track_width']:.2f}")
else:
print(" 检测失败")
print("\n" + "-"*50 + "\n")
if __name__ == "__main__":
# 设置观察模式和保存日志
observe = True
save_log = True
if len(sys.argv) > 1:
# 如果提供了图像路径,只测试该图像
image_path = sys.argv[1]
test_track_detection(image_path, observe, save_log)
else:
# 否则测试logs/image目录下的所有原始图像
image_dir = "logs/image"
if not os.path.exists(image_dir):
image_dir = "task/logs/image"
orig_images = [f for f in os.listdir(image_dir) if f.startswith("dual_track_orig_")]
if not orig_images:
print(f"未找到原始图像文件在 {image_dir} 目录下")
sys.exit(1)
for img_file in orig_images:
img_path = os.path.join(image_dir, img_file)
test_track_detection(img_path, observe, save_log)
# 如果是观察模式,每张图片测试后等待按键
if observe:
print("按任意键继续下一张图片测试...")
cv2.waitKey(0)
# 如果有打开的窗口,关闭它们
if observe:
cv2.destroyAllWindows()