197 lines
7.2 KiB
Python
Executable File
197 lines
7.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import sys
|
|
import cv2
|
|
import argparse
|
|
import time
|
|
from tqdm import tqdm
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib as mpl
|
|
|
|
# 设置中文字体支持
|
|
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'sans-serif']
|
|
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
# 添加项目根目录到路径
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
project_root = os.path.dirname(os.path.dirname(current_dir))
|
|
sys.path.append(project_root)
|
|
|
|
from utils.decode_arrow import detect_arrow_direction, visualize_arrow_detection
|
|
|
|
def batch_test_arrows(data_dir="res/arrows", save_dir="res/arrows/test", show_results=False):
|
|
"""
|
|
批量测试箭头方向检测算法
|
|
|
|
参数:
|
|
data_dir: 包含箭头图像的目录
|
|
save_dir: 保存结果的目录
|
|
show_results: 是否显示结果
|
|
|
|
返回:
|
|
results_df: 包含测试结果的DataFrame
|
|
"""
|
|
# 确保保存目录存在
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
# 保存结果的列表
|
|
results = []
|
|
|
|
# 处理左右箭头子目录
|
|
for direction in ["left", "right"]:
|
|
dir_path = os.path.join(data_dir, direction)
|
|
if not os.path.exists(dir_path):
|
|
print(f"警告: 目录 '{dir_path}' 不存在")
|
|
continue
|
|
|
|
# 获取该方向的所有图像文件
|
|
image_files = [f for f in os.listdir(dir_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
|
|
|
|
print(f"处理 {direction} 方向的 {len(image_files)} 个图像...")
|
|
|
|
# 处理每个图像
|
|
for img_file in tqdm(image_files):
|
|
img_path = os.path.join(dir_path, img_file)
|
|
|
|
# 读取图像
|
|
img = cv2.imread(img_path)
|
|
if img is None:
|
|
print(f"错误: 无法加载图像 '{img_path}'")
|
|
continue
|
|
|
|
# 开始计时
|
|
start_time = time.time()
|
|
|
|
# 检测箭头方向
|
|
detected_direction = detect_arrow_direction(img)
|
|
|
|
# 结束计时
|
|
end_time = time.time()
|
|
processing_time = end_time - start_time
|
|
|
|
# 确定检测是否正确
|
|
is_correct = detected_direction == direction
|
|
|
|
# 保存可视化结果
|
|
result_filename = f"{direction}_{img_file.split('.')[0]}_result.jpg"
|
|
result_path = os.path.join(save_dir, result_filename)
|
|
# visualize_arrow_detection(img, result_path)
|
|
|
|
# 保存结果
|
|
results.append({
|
|
"图像文件": img_file,
|
|
"真实方向": direction,
|
|
"检测方向": detected_direction,
|
|
"是否正确": is_correct,
|
|
"处理时间(秒)": processing_time,
|
|
"结果文件": result_filename
|
|
})
|
|
|
|
# 创建结果DataFrame
|
|
results_df = pd.DataFrame(results)
|
|
|
|
# 保存结果到CSV
|
|
csv_path = os.path.join(save_dir, "arrow_detection_results.csv")
|
|
results_df.to_csv(csv_path, index=False, encoding='utf-8-sig')
|
|
|
|
# 生成统计报告
|
|
generate_report(results_df, save_dir)
|
|
|
|
return results_df
|
|
|
|
def generate_report(results_df, save_dir):
|
|
"""生成统计报告和可视化"""
|
|
# 计算总体准确率
|
|
accuracy = results_df["是否正确"].mean() * 100
|
|
|
|
# 按箭头方向分组计算准确率
|
|
direction_accuracy = results_df.groupby("真实方向")["是否正确"].mean() * 100
|
|
|
|
# 计算平均处理时间
|
|
avg_time = results_df["处理时间(秒)"].mean() * 1000 # 转换为毫秒
|
|
|
|
# 创建报告文件
|
|
report_path = os.path.join(save_dir, "arrow_detection_report.txt")
|
|
|
|
with open(report_path, "w", encoding="utf-8") as f:
|
|
f.write("箭头方向检测 - 测试报告\n")
|
|
f.write("=======================\n\n")
|
|
f.write(f"测试日期: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
f.write(f"测试图像总数: {len(results_df)}\n\n")
|
|
f.write(f"总体准确率: {accuracy:.2f}%\n")
|
|
f.write("各方向准确率:\n")
|
|
for direction, acc in direction_accuracy.items():
|
|
f.write(f" - {direction}: {acc:.2f}%\n")
|
|
f.write(f"\n平均处理时间: {avg_time:.2f} 毫秒\n\n")
|
|
|
|
# 错误案例分析
|
|
if not results_df["是否正确"].all():
|
|
f.write("错误检测案例:\n")
|
|
error_cases = results_df[~results_df["是否正确"]]
|
|
for _, row in error_cases.iterrows():
|
|
f.write(f" - 文件: {row['图像文件']}, 真实方向: {row['真实方向']}, 错误检测为: {row['检测方向']}\n")
|
|
|
|
# 创建可视化图表
|
|
plt.figure(figsize=(12, 6))
|
|
|
|
# 准确率条形图
|
|
plt.subplot(1, 2, 1)
|
|
# 将中文索引转为英文避免字体问题
|
|
direction_accuracy_en = direction_accuracy.copy()
|
|
direction_accuracy_en.index = direction_accuracy.index.map(lambda x: "Left" if x == "left" else "Right")
|
|
direction_accuracy_en.plot(kind='bar', color=['blue', 'green'])
|
|
plt.title('各方向检测准确率')
|
|
plt.ylabel('准确率 (%)')
|
|
plt.ylim(0, 100)
|
|
plt.grid(True, linestyle='--', alpha=0.7)
|
|
|
|
# 处理时间箱线图
|
|
plt.subplot(1, 2, 2)
|
|
# 将中文列名改为英文再制图,避免字体问题
|
|
temp_df = results_df.copy()
|
|
temp_df.rename(columns={"处理时间(秒)": "processing_time", "真实方向": "direction"}, inplace=True)
|
|
temp_df.boxplot(column=['processing_time'], by='direction')
|
|
plt.title('处理时间分布')
|
|
plt.ylabel('时间 (秒)')
|
|
plt.suptitle('')
|
|
|
|
# 保存图表
|
|
plt.tight_layout()
|
|
plt.savefig(os.path.join(save_dir, "arrow_detection_stats.png"))
|
|
|
|
print(f"测试报告已保存到: {report_path}")
|
|
print(f"统计图表已保存到: {os.path.join(save_dir, 'arrow_detection_stats.png')}")
|
|
|
|
def main():
|
|
# 创建参数解析器
|
|
parser = argparse.ArgumentParser(description='箭头方向检测批量测试')
|
|
parser.add_argument('--data-dir', default="res/arrows",
|
|
help='箭头图像数据目录 (默认: res/arrows)')
|
|
parser.add_argument('--save-dir', default="res/arrows/test",
|
|
help='保存结果的目录 (默认: res/arrows/test)')
|
|
parser.add_argument('--show', action='store_true',
|
|
help='显示结果图像')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# 运行批量测试
|
|
results = batch_test_arrows(args.data_dir, args.save_dir, args.show)
|
|
|
|
# 输出总体结果
|
|
correct = results["是否正确"].sum()
|
|
total = len(results)
|
|
print(f"\n测试完成! 总共测试了 {total} 张图像,正确检测了 {correct} 张")
|
|
print(f"总体准确率: {(correct/total*100):.2f}%")
|
|
|
|
# 按真实方向打印准确率
|
|
for direction in ["left", "right"]:
|
|
dir_results = results[results["真实方向"] == direction]
|
|
if len(dir_results) > 0:
|
|
dir_correct = dir_results["是否正确"].sum()
|
|
dir_total = len(dir_results)
|
|
print(f"{direction} 方向准确率: {(dir_correct/dir_total*100):.2f}% ({dir_correct}/{dir_total})")
|
|
|
|
if __name__ == "__main__":
|
|
main() |