mi-task/utils/detect_furthest_intersection.py
2025-08-18 11:06:42 +08:00

592 lines
25 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
import datetime
from sklearn import linear_model
from utils.log_helper import get_logger, debug, info, warning, error, success
# Folder that receives every image written when save_log is enabled.
_LOG_IMAGE_DIR = "logs/image"
# Slope assigned to segments that are too close to vertical to have a stable slope.
_NEAR_VERTICAL_SLOPE = 100
# Horizontal extent (pixels) used to decide whether a segment is near-vertical.
_MIN_HORIZONTAL_SPAN = 5


def _show_debug_image(title, image, delay):
    """Display *image* in a window named *title* and wait *delay* milliseconds."""
    cv2.imshow(title, image)
    cv2.waitKey(delay)


def _save_log_image(prefix, image):
    """Write *image* into the log folder; return ``(path, timestamp)``."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    os.makedirs(_LOG_IMAGE_DIR, exist_ok=True)
    path = os.path.join(_LOG_IMAGE_DIR, f"{prefix}_{timestamp}.jpg")
    cv2.imwrite(path, image)
    return path, timestamp


def _hue_to_bgr(hue):
    """Return the fully saturated BGR colour for an OpenCV hue value."""
    bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0][0]
    return tuple(int(channel) for channel in bgr)


def _extract_yellow_mask(img):
    """Return ``(contrast_enhanced_bgr, yellow_mask)`` for a BGR image."""
    # Equalise only the lightness channel (CLAHE) so colours are preserved.
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    equalised = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    enhanced = cv2.cvtColor(equalised, cv2.COLOR_LAB2BGR)

    # Threshold a deliberately wide yellow band in HSV.
    hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, np.array([15, 70, 70]), np.array([35, 255, 255]))

    # Dilate to close gaps in the line, then erode lightly to drop speckles.
    mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=2)
    mask = cv2.erode(mask, np.ones((3, 3), np.uint8), iterations=1)
    return enhanced, mask


def _collect_top_edge_points(mask, left, top, right, bottom):
    """Return the topmost masked pixel ``(x, y)`` of every non-empty column in the region."""
    region = mask[top:bottom, left:right]
    occupied = np.flatnonzero(region.any(axis=0))
    if occupied.size == 0:
        return []
    # argmax over a boolean column yields the index of its first True row.
    first_rows = np.argmax(region[:, occupied] > 0, axis=0)
    # Plain ints keep the points usable as OpenCV drawing coordinates.
    return [(left + int(col), top + int(row)) for col, row in zip(occupied, first_rows)]


def _ransac_fit_line(points, residual_threshold):
    """Fit ``y = slope * x + intercept`` to *points* with RANSAC; return ``(slope, intercept)``.

    Any exception raised by scikit-learn while fitting propagates to the caller.
    """
    xs = np.array([p[0] for p in points]).reshape(-1, 1)
    ys = np.array([p[1] for p in points])
    ransac = linear_model.RANSACRegressor(residual_threshold=residual_threshold)
    ransac.fit(xs, ys)
    return float(ransac.estimator_.coef_[0]), float(ransac.estimator_.intercept_)


def _segment_slope(segment):
    """Return the slope of a segment, or the near-vertical sentinel when it is too steep."""
    x1, y1, x2, y2 = segment
    if abs(x2 - x1) > _MIN_HORIZONTAL_SPAN:
        return (y2 - y1) / (x2 - x1)
    return _NEAR_VERTICAL_SLOPE


def _segments_are_similar(seg_a, slope_a, seg_b, height):
    """Decide whether two segments are close enough in slope and position to be merged."""
    x1, y1, x2, y2 = seg_a
    x3, y3, x4, y4 = seg_b
    slope_gap = abs(slope_a - _segment_slope(seg_b))

    mid_dist = float(np.hypot((x3 + x4) / 2 - (x1 + x2) / 2,
                              (y3 + y4) / 2 - (y1 + y2) / 2))
    min_end_dist = min(
        float(np.hypot(x1 - x3, y1 - y3)),
        float(np.hypot(x1 - x4, y1 - y4)),
        float(np.hypot(x2 - x3, y2 - y3)),
        float(np.hypot(x2 - x4, y2 - y4)),
    )

    # Any one of three rules is sufficient:
    #   1. similar slope and midpoints not far apart,
    #   2. very similar slope and endpoints close (probably one broken line),
    #   3. endpoints almost touching and slopes only loosely similar.
    return (
        (slope_gap < 0.15 and mid_dist < height * 0.15)
        or (slope_gap < 0.1 and min_end_dist < height * 0.05)
        or (min_end_dist < height * 0.03 and slope_gap < 0.25)
    )


def _merge_similar_segments(segments, height):
    """Group segments similar to an unconsumed seed and replace each group by one fitted segment."""
    merged = []
    consumed = set()
    for i, seed in enumerate(segments):
        if i in consumed:
            continue
        consumed.add(i)
        seed_slope = _segment_slope(seed)  # invariant for the inner loop

        group = [seed]
        for j, candidate in enumerate(segments):
            if j in consumed:
                continue
            if _segments_are_similar(seed, seed_slope, candidate, height):
                group.append(candidate)
                consumed.add(j)

        if len(group) == 1:
            merged.append(seed)
            continue

        endpoints = []
        for x1, y1, x2, y2 in group:
            endpoints.append((x1, y1))
            endpoints.append((x2, y2))
        try:
            slope, intercept = _ransac_fit_line(endpoints, 5.0)
        except Exception as exc:
            # RANSAC can fail on degenerate groups (e.g. every endpoint on
            # the same x). Keep the raw segments instead of aborting.
            warning(f"合并相似线段时拟合失败，保留原始线段: {exc}", "警告")
            merged.extend(group)
            continue

        min_x = min(p[0] for p in endpoints)
        max_x = max(p[0] for p in endpoints)
        merged.append((min_x, int(slope * min_x + intercept),
                       max_x, int(slope * max_x + intercept)))
    return merged


def _rank_horizontal_segments(segments, width, height, bounds, max_slope):
    """Return ``(segment, mid_y, slope, length, score)`` for every near-horizontal segment in bounds."""
    left, top, right, bottom = bounds
    center_x = width // 2
    candidates = []
    for x1, y1, x2, y2 in segments:
        if abs(x2 - x1) < _MIN_HORIZONTAL_SPAN:  # near-vertical; also avoids division by zero
            continue
        slope = (y2 - y1) / (x2 - x1)
        if abs(slope) >= max_slope:
            continue
        start_inside = left <= x1 <= right and top <= y1 <= bottom
        end_inside = left <= x2 <= right and top <= y2 <= bottom
        if not (start_inside or end_inside):
            continue

        mid_y = (y1 + y2) / 2
        length = float(np.hypot(x2 - x1, y2 - y1))
        # Smaller y scores higher, and this term dominates the weighting.
        position_score = 1.0 - (mid_y / height)
        length_score = min(1.0, length / (width * 0.5))
        slope_score = max(0.0, 1.0 - abs(slope) / max_slope)
        center_score = max(0.0, 1.0 - abs((x1 + x2) / 2 - center_x) / (width * 0.4))
        score = (position_score * 0.6 + length_score * 0.15
                 + slope_score * 0.2 + center_score * 0.05)
        candidates.append(((x1, y1, x2, y2), mid_y, slope, length, score))
    return candidates


def _direct_intersection_from_points(points, left, right, center_x, height, max_slope):
    """Fit one line through the top edge points; return ``(point, info, segment)`` or ``None``."""
    try:
        slope, intercept = _ransac_fit_line(points, 5.0)
    except Exception as exc:
        # Best effort: the caller falls back to the Hough-based path.
        warning(f"从边缘点直接拟合线失败: {exc}", "警告")
        return None
    if abs(slope) >= max_slope:
        return None

    y_left = int(slope * left + intercept)
    y_right = int(slope * right + intercept)
    cross_y = slope * (center_x - left) + y_left
    # Only accept crossings inside the upper 70% of the image.
    if not 0 <= cross_y < height * 0.7:
        return None

    point = (int(center_x), int(cross_y))
    details = {
        "x": point[0],
        "y": point[1],
        "distance_to_bottom": height - cross_y,
        "slope": slope,
        "is_horizontal": abs(slope) < 0.05,
        "score": 0.9,  # fixed score given to a successful direct fit
        "valid": True,
        "reason": "",  # same key as the Hough-based result
        "fitted_from_edge_points": True,
    }
    return point, details, (left, y_left, right, y_right)


def detect_furthest_horizontal_intersection(image, observe=False, delay=1000, save_log=True):
    """Find where the vertical image centre line crosses the furthest horizontal yellow line.

    Pipeline: enhance contrast, threshold yellow in HSV, then
    (1) try a RANSAC fit through the topmost yellow pixel of each column;
    (2) otherwise detect Hough segments, merge similar ones, score the
        near-horizontal ones (higher in the image scores better) and use
        the best;
    (3) if no Hough segment qualifies, fit the top edge points that lie in
        the upper half of the image.

    Args:
        image: File path of an image, or an already loaded image array
            (BGR colour order is assumed by the colour conversions).
        observe: When true, emit debug logs and show every intermediate
            result with ``cv2.imshow``.
        delay: Milliseconds to wait after each window shown when
            ``observe`` is true.
        save_log: When true, write the input and result images to
            ``logs/image`` and log the result.

    Returns:
        ``(intersection_point, intersection_info)``. ``intersection_point``
        is an ``(x, y)`` tuple of ints and ``intersection_info`` is a dict
        with keys ``x``, ``y``, ``distance_to_bottom``, ``slope``,
        ``is_horizontal``, ``score``, ``valid`` and ``reason`` (plus
        ``fitted_from_edge_points`` when path (1) succeeded). A result may
        be returned with ``valid`` set to False; the caller decides whether
        to use it. ``(None, None)`` is returned when the image cannot be
        loaded or no suitable line is found.
    """
    # Load the image; a missing input is reported the same way as an unreadable file.
    if isinstance(image, str):
        img = cv2.imread(image)
    elif image is None:
        img = None
    else:
        img = image.copy()
    if img is None:
        error("无法加载图像", "失败")
        return None, None

    if save_log:
        origin_image_path, _ = _save_log_image("origin_intersection", img)
        info(f"保存原始图像到: {origin_image_path}", "日志")

    height, width = img.shape[:2]

    # Search window: the central 80% of the width over the full height.
    center_x = width // 2
    search_width = int(width * 0.8)
    left_bound = center_x - search_width // 2
    right_bound = center_x + search_width // 2
    top_bound = 0
    bottom_bound = height
    max_slope = 0.3  # largest |slope| still treated as horizontal

    if observe:
        debug("步骤1: 原始图像已加载", "加载")
        search_region_img = img.copy()
        cv2.rectangle(search_region_img, (left_bound, top_bound),
                      (right_bound, bottom_bound), (255, 0, 0), 2)
        cv2.line(search_region_img, (center_x, 0), (center_x, height), (0, 0, 255), 2)
        _show_debug_image("搜索区域", search_region_img, delay)

    img_enhanced, mask = _extract_yellow_mask(img)
    if observe:
        debug("步骤1.5: 增强对比度", "处理")
        _show_debug_image("增强对比度", img_enhanced, delay)
        debug("步骤2: 创建黄色掩码", "处理")
        _show_debug_image("黄色掩码", mask, delay)
        debug("步骤3: 提取黄色部分", "处理")
        yellow_only = cv2.bitwise_and(img_enhanced, img_enhanced, mask=mask)
        _show_debug_image("只保留黄色", yellow_only, delay)

    # The topmost yellow pixel of each column approximates the far edge.
    top_points = _collect_top_edge_points(mask, left_bound, top_bound, right_bound, bottom_bound)
    if observe and top_points:
        debug("检测顶部边缘点", "处理")
        edge_points_img = img.copy()
        for point in top_points:
            cv2.circle(edge_points_img, point, 3, (255, 0, 255), -1)
        _show_debug_image("顶部边缘点", edge_points_img, delay)

    # Path 1: fit a single line straight through the top edge points.
    if len(top_points) >= 10:
        direct = _direct_intersection_from_points(
            top_points, left_bound, right_bound, center_x, height, max_slope)
        if direct is not None:
            direct_point, direct_info, (dx1, dy1, dx2, dy2) = direct
            if observe or save_log:
                direct_fit_img = img.copy()
                cv2.line(direct_fit_img, (dx1, dy1), (dx2, dy2), (0, 255, 0), 2)
                cv2.circle(direct_fit_img, direct_point, 10, (255, 0, 255), -1)
                if observe:
                    debug("从顶部边缘点直接拟合出横线", "处理")
                    _show_debug_image("从边缘点直接拟合的线", direct_fit_img, delay)
                if save_log:
                    direct_fit_path, _ = _save_log_image("direct_fit", direct_fit_img)
                    info(f"保存从边缘点直接拟合的线到: {direct_fit_path}", "日志")
            return direct_point, direct_info

    # Path 2: Hough segments on the mask edges, with permissive thresholds.
    edges = cv2.Canny(mask, 30, 120, apertureSize=3)
    if observe:
        debug("步骤4: 边缘检测", "处理")
        _show_debug_image("边缘检测", edges, delay)

    lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=20,
                            minLineLength=width * 0.08, maxLineGap=40)
    if lines is None or len(lines) == 0:
        # Second attempt with even looser parameters.
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=15,
                                minLineLength=width * 0.05, maxLineGap=50)
        if lines is None or len(lines) == 0:
            if observe:
                error("未检测到直线", "失败")
            return None, None

    # Normalise to plain-int tuples ordered so that x1 <= x2.
    segments = []
    for raw in lines:
        x1, y1, x2, y2 = (int(v) for v in raw[0])
        if x1 > x2:
            x1, y1, x2, y2 = x2, y2, x1, y1
        segments.append((x1, y1, x2, y2))

    if observe:
        debug(f"步骤5: 检测到 {len(lines)} 条直线", "处理")
        lines_img = img.copy()
        for i, (x1, y1, x2, y2) in enumerate(segments):
            cv2.line(lines_img, (x1, y1), (x2, y2), _hue_to_bgr((i * 30) % 180), 2)
        _show_debug_image("检测到的直线", lines_img, delay)

    merged_segments = _merge_similar_segments(segments, height)
    if observe:
        debug(f"步骤5.1: 合并后剩余 {len(merged_segments)} 条线", "处理")
        merged_img = img.copy()
        for i, (x1, y1, x2, y2) in enumerate(merged_segments):
            cv2.line(merged_img, (x1, y1), (x2, y2), _hue_to_bgr((i * 50) % 180), 3)
        _show_debug_image("合并后的线段", merged_img, delay)

    horizontal_lines = _rank_horizontal_segments(
        merged_segments, width, height,
        (left_bound, top_bound, right_bound, bottom_bound), max_slope)

    # Path 3: no usable Hough segment, so fit the edge points in the upper half.
    if not horizontal_lines and len(top_points) >= 5:
        if observe:
            debug("未检测到水平线，尝试使用顶部边缘点拟合", "处理")
        upper_points = [p for p in top_points if p[1] < height * 0.5]
        if len(upper_points) >= 5:
            try:
                # Larger residual threshold than the direct fit: more tolerant.
                fitted_slope, intercept = _ransac_fit_line(upper_points, 8.0)
            except Exception as exc:
                error(f"拟合水平线失败: {exc}", "失败")
            else:
                if abs(fitted_slope) < max_slope:
                    x1 = left_bound
                    y1 = int(fitted_slope * x1 + intercept)
                    x2 = right_bound
                    y2 = int(fitted_slope * x2 + intercept)
                    mid_y = (y1 + y2) / 2
                    line_length = float(np.hypot(x2 - x1, y2 - y1))
                    # Position-weighted score plus a fixed bonus for edge-point fits.
                    quality_score = (1.0 - (mid_y / height)) * 0.7 + 0.3
                    horizontal_lines.append(
                        ((x1, y1, x2, y2), mid_y, fitted_slope, line_length, quality_score))
                    if observe:
                        debug(f"从边缘点成功拟合出水平线，斜率: {fitted_slope:.4f}", "处理")
                        fitted_line_img = img.copy()
                        cv2.line(fitted_line_img, (x1, y1), (x2, y2), (0, 255, 255), 2)
                        for point in upper_points:
                            cv2.circle(fitted_line_img, point, 3, (0, 255, 0), -1)
                        _show_debug_image("拟合的水平线", fitted_line_img, delay)

    if not horizontal_lines:
        if observe:
            error("未检测到合格的水平线", "失败")
        return None, None

    # Best score first.
    horizontal_lines.sort(key=lambda entry: entry[4], reverse=True)

    if observe:
        debug(f"步骤6: 找到 {len(horizontal_lines)} 条水平线", "处理")
        h_lines_img = img.copy()
        for i, (segment, _mid_y, _slope, _length, score) in enumerate(horizontal_lines):
            x1, y1, x2, y2 = segment
            # The higher the score, the greener the line.
            color = (int(255 * (1 - score)), int(255 * score), 0)
            cv2.line(h_lines_img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            cv2.putText(h_lines_img, f"{i}:{score:.2f}",
                        ((int(x1) + int(x2)) // 2, (int(y1) + int(y2)) // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        _show_debug_image("水平线", h_lines_img, delay)

    selected_line, _, selected_slope, _, selected_score = horizontal_lines[0]
    x1, y1, x2, y2 = selected_line
    if x1 > x2:
        x1, y1, x2, y2 = x2, y2, x1, y1

    # Intersect y = slope * (x - x1) + y1 with the centre line x = center_x.
    intersection_y = selected_slope * (center_x - x1) + y1
    intersection_point = (int(center_x), int(intersection_y))
    distance_to_bottom = height - intersection_y

    # Flag crossings above the image or within the bottom 5% as invalid.
    reason = ""
    if intersection_y < 0:
        reason = "交点y坐标超出图像上边界; "
    elif intersection_y > height * 0.95:
        reason = "交点y坐标过于接近图像底部; "
    valid_result = not reason

    result_img = None
    if observe or save_log:
        result_img = img.copy()
        line_color = (0, 255, 0) if valid_result else (0, 0, 255)
        cv2.line(result_img, (int(x1), int(y1)), (int(x2), int(y2)), line_color, 2)
        cv2.line(result_img, (center_x, 0), (center_x, height), (0, 0, 255), 2)
        cv2.circle(result_img, intersection_point, 12, (255, 0, 255), -1)
        cv2.circle(result_img, intersection_point, 5, (255, 255, 255), -1)
        captions = (
            f"Slope: {selected_slope:.4f}",
            f"Intersection: ({intersection_point[0]}, {intersection_point[1]})",
            f"Distance to bottom: {distance_to_bottom:.1f}px",
            f"Score: {selected_score:.2f}",
        )
        for row, caption in enumerate(captions):
            cv2.putText(result_img, caption, (10, 30 + 40 * row),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, line_color, 2)
        if not valid_result:
            cv2.putText(result_img, f"Warning: {reason}", (10, 190),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    if observe:
        debug("显示交点结果", "显示")
        _show_debug_image("交点结果", result_img, delay)

    if save_log:
        img_path, timestamp = _save_log_image("target_line", result_img)
        info(f"保存目标横线检测结果图像到: {img_path}", "日志")
        log_info = {
            "timestamp": timestamp,
            "intersection_point": intersection_point,
            "distance_to_bottom": distance_to_bottom,
            "slope": selected_slope,
            "score": selected_score,
            "valid": valid_result,
            "reason": reason if not valid_result else "",
        }
        info(f"最远交点检测结果: {log_info}", "日志")
    else:
        info("未保存日志图像", "日志")

    # Returned even when invalid so the caller can inspect or debug it.
    intersection_info = {
        "x": intersection_point[0],
        "y": intersection_point[1],
        "distance_to_bottom": distance_to_bottom,
        "slope": selected_slope,
        "is_horizontal": abs(selected_slope) < 0.05,
        "score": selected_score,
        "valid": valid_result,
        "reason": reason if not valid_result else "",
    }
    return intersection_point, intersection_info
# Manual check: this section runs only when the file is executed as a script.
if __name__ == "__main__":
    import sys

    # The first command-line argument overrides the default sample image.
    default_image_path = "res/path/task-5/origin_horizontal_edge_20250528_100447_858352.jpg"
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image_path

    intersection_point, intersection_info = detect_furthest_horizontal_intersection(
        image_path, observe=True, delay=800)

    if intersection_point is None:
        print("未检测到有效的最远交点")
    else:
        print(f"检测到的最远交点: {intersection_point}")
        print(f"交点信息: {intersection_info}")