mi-task/utils/detect_track.py
2025-05-15 12:08:49 +00:00

451 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn import linear_model
def _line_slope(points):
    """Return the least-squares slope dy/dx through ``points``.

    Returns None when there are fewer than two points or when every x
    coordinate is identical (a dy/dx line fit is undefined in that case).
    """
    if len(points) < 2:
        return None
    xs = np.array([p[0] for p in points])
    ys = np.array([p[1] for p in points])
    if np.std(xs) <= 0:
        return None
    slope, _intercept = np.polyfit(xs, ys, 1)
    return slope


def _group_points_by_y(points):
    """Split (x, y) points into groups lying on roughly the same image row.

    With fewer than 10 points, the points are sorted by y and a new group is
    started whenever a point's y differs by 10 px or more from the first y of
    the current group. Otherwise a 1-D K-means on y is run for 1..5 clusters
    (at most half the point count) and the labelling with the best silhouette
    score is kept.

    Args:
        points: Non-empty list of (x, y) tuples.

    Returns:
        List of non-empty lists of points.
    """
    if len(points) < 10:
        # Sort first so that "consecutive" really means "adjacent in y";
        # contour traversal order alternates between rows.
        ordered = sorted(points, key=lambda p: p[1])
        groups = []
        current_group = [ordered[0]]
        current_y = ordered[0][1]
        for point in ordered[1:]:
            if abs(point[1] - current_y) < 10:
                current_group.append(point)
            else:
                groups.append(current_group)
                current_group = [point]
                current_y = point[1]
        groups.append(current_group)
        return groups

    y_values = np.array([p[1] for p in points]).reshape(-1, 1)
    max_clusters = min(5, len(y_values) // 2)
    best_score = -1
    best_labels = np.zeros(len(points), dtype=int)
    for n_clusters in range(1, max_clusters + 1):
        labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(y_values).labels_
        # silhouette_score needs at least two distinct labels; K-means can
        # return fewer distinct clusters than requested when y values repeat.
        if n_clusters > 1 and len(np.unique(labels)) > 1:
            score = silhouette_score(y_values, labels)
        else:
            score = 0
        if score > best_score:
            best_score = score
            best_labels = labels

    # Only labels that actually occur become groups, so no group is empty.
    grouped = {}
    for label, point in zip(best_labels, points):
        grouped.setdefault(int(label), []).append(point)
    return [grouped[label] for label in sorted(grouped)]


def _refine_slope_bfs(start_point, points, fallback_slope, max_distance=20, max_points=200):
    """Re-estimate an edge slope from the points connected to ``start_point``.

    A breadth-first search starting at ``start_point`` adds every point of
    ``points`` that is closer than ``max_distance`` pixels to an already
    reached point. Expansion stops once ``max_points`` points are collected;
    the check happens between queue pops, so the final count may exceed it.

    With at least 5 connected points, a RANSAC line fit is tried first. If it
    succeeds with at least 3 inliers, its slope and the inliers are returned.
    Otherwise an ordinary least-squares fit over all connected points is used.
    When no fit is possible, ``fallback_slope`` is returned.

    Returns:
        (slope, points_used)
    """
    from collections import deque

    queue = deque([start_point])
    visited = {start_point}
    line_points = [start_point]
    max_distance_sq = max_distance * max_distance
    while queue and len(line_points) < max_points:
        cx, cy = queue.popleft()
        for point in points:
            if point in visited:
                continue
            dx = point[0] - cx
            dy = point[1] - cy
            # Squared comparison: equivalent to sqrt(dx^2 + dy^2) < max_distance.
            if dx * dx + dy * dy < max_distance_sq:
                queue.append(point)
                visited.add(point)
                line_points.append(point)

    if len(line_points) < 5:
        return fallback_slope, line_points

    x_coords = np.array([p[0] for p in line_points])
    y_coords = np.array([p[1] for p in line_points])
    try:
        ransac = linear_model.RANSACRegressor()
        ransac.fit(x_coords.reshape(-1, 1), y_coords)
        inlier_points = [p for p, is_inlier in zip(line_points, ransac.inlier_mask_) if is_inlier]
        if len(inlier_points) >= 3:
            return ransac.estimator_.coef_[0], inlier_points
    except Exception:
        # Deliberate best-effort: RANSAC can fail (e.g. no valid consensus
        # set); fall back to the plain least-squares fit below.
        pass

    slope = _line_slope(line_points)
    return (fallback_slope if slope is None else slope), line_points


def detect_horizontal_track_edge(image, observe=False, delay=1000):
    """Detect the yellow track band crossing the view and return its nearest edge point.

    Pipeline:
    1. Threshold yellow in HSV and clean the mask with close/open operations.
    2. Keep the external contours whose bounding-box centre lies in a
       horizontally centred search window (2/3 of the image width, full
       height), preferring wider-than-tall contours.
    3. Group their points by row and pick the lowest (largest-y) group whose
       fitted slope satisfies |slope| < 0.5, falling back to the lowest group.
    4. Take that group's largest-y point and refine the slope with a BFS +
       RANSAC fit.

    Args:
        image: Path to an image file, or an already-loaded image array. The
            image is converted with ``cv2.COLOR_BGR2HSV``, so BGR is expected.
        observe: Whether to print intermediate steps and show visualisation
            windows. NOTE: currently forced to False inside this function.
        delay: ``cv2.waitKey`` wait time (milliseconds) per visualisation step.

    Returns:
        ``(edge_point, edge_info)``:

        - ``edge_point`` is the (x, y) point of the selected group with the
          largest y.
        - ``edge_info`` is a dict with the keys "x", "y",
          "distance_to_center" (edge x minus image centre x) and "slope"
          (refined dy/dx).
        - It also holds "is_horizontal" (|slope| < 0.05), "points_count",
          and "intersection_point" (where the edge line crosses the vertical
          centre line).
        - Finally it holds "distance_to_bottom" (pixels from that
          intersection to the image bottom) and "points" (the selected
          group).

        Returns ``(None, None)`` if the image cannot be loaded or no suitable
        yellow edge is found.
    """
    # NOTE(review): debug override left in by the author ("TSET"); the
    # ``observe`` argument is ignored and all visualisation is disabled.
    # Delete this line to honour the caller's flag again.
    observe = False

    # Load from disk when given a path; otherwise work on a copy.
    if isinstance(image, str):
        img = cv2.imread(image)
    else:
        img = image.copy() if image is not None else None
    if img is None:
        print("无法加载图像")
        return None, None

    height, width = img.shape[:2]

    # Search window centred horizontally on the image (the area straight ahead).
    center_x = width // 2
    search_width = int(width * 2 / 3)  # 2/3 of the image width
    search_height = height  # full image height
    left_bound = center_x - search_width // 2
    right_bound = center_x + search_width // 2
    bottom_bound = height
    top_bound = height - search_height

    if observe:
        print("步骤1: 原始图像已加载")
        search_region_img = img.copy()
        cv2.rectangle(search_region_img, (left_bound, top_bound), (right_bound, bottom_bound), (255, 0, 0), 2)
        cv2.line(search_region_img, (center_x, 0), (center_x, height), (0, 0, 255), 2)  # centre line
        cv2.imshow("搜索区域", search_region_img)
        cv2.waitKey(delay)

    # Yellow mask in HSV space.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_yellow = np.array([20, 100, 100])
    upper_yellow = np.array([30, 255, 255])
    mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
    if observe:
        print("步骤2: 创建黄色掩码")
        cv2.imshow("黄色掩码", mask)
        cv2.waitKey(delay)

    # Close fills small holes, open removes small specks.
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    if observe:
        print("步骤2.1: 形态学处理后的掩码")
        cv2.imshow("处理后的掩码", mask)
        cv2.waitKey(delay)
        yellow_only = cv2.bitwise_and(img, img, mask=mask)
        print("步骤3: 提取黄色部分")
        cv2.imshow("只保留黄色", yellow_only)
        cv2.waitKey(delay)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        if observe:
            print("未找到轮廓")
        return None, None
    if observe:
        print(f"步骤4: 找到 {len(contours)} 个轮廓")
        contour_img = img.copy()
        cv2.drawContours(contour_img, contours, -1, (0, 255, 0), 2)
        cv2.imshow("所有轮廓", contour_img)
        cv2.waitKey(delay)

    # Keep contours centred in the search window; prefer wider-than-tall ones
    # (aspect ratio > 1) as likely horizontal segments.
    in_region_contours = []
    horizontal_contours = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if not (left_bound <= x + w // 2 <= right_bound and
                top_bound <= y + h // 2 <= bottom_bound):
            continue
        in_region_contours.append(contour)
        if w / max(h, 1) > 1.0:
            horizontal_contours.append(contour)
    if not horizontal_contours:
        if observe:
            print("未找到符合条件的横向轮廓")
        # Fall back to every contour inside the search window.
        horizontal_contours = in_region_contours
    if not horizontal_contours:
        if observe:
            print("在搜索区域内未找到任何轮廓")
        return None, None
    if observe:
        print(f"步骤4.1: 找到 {len(horizontal_contours)} 个可能的横向轮廓")
        horizontal_img = img.copy()
        cv2.drawContours(horizontal_img, horizontal_contours, -1, (0, 255, 0), 2)
        cv2.imshow("横向轮廓", horizontal_img)
        cv2.waitKey(delay)

    # Contour vertices inside the window, as plain-int (x, y) tuples.
    all_horizontal_points = [
        (int(px), int(py))
        for contour in horizontal_contours
        for px, py in contour[:, 0]
        if left_bound <= px <= right_bound and top_bound <= py <= bottom_bound
    ]
    if not all_horizontal_points:
        if observe:
            print("在搜索区域内未找到有效点")
        return None, None

    y_groups = _group_points_by_y(all_horizontal_points)
    if observe:
        clusters_img = img.copy()
        colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (255, 255, 0), (0, 255, 255)]
        for i, group in enumerate(y_groups):
            color = colors[i % len(colors)]
            for point in group:
                cv2.circle(clusters_img, point, 3, color, -1)
        cv2.imshow("按Y值分组的点", clusters_img)
        cv2.waitKey(delay)

    # Lowest groups first: larger y is closer to the bottom, i.e. the camera.
    groups_by_y = sorted(
        ((sum(p[1] for p in group) / len(group), group) for group in y_groups),
        key=lambda item: item[0],
        reverse=True,
    )

    # Pick the lowest group that is close to horizontal (|slope| < 0.5).
    selected_group = None
    selected_slope = 0
    for _avg_y, group in groups_by_y:
        group_slope = _line_slope(group)
        if group_slope is not None and abs(group_slope) < 0.5:
            selected_group = group
            selected_slope = group_slope
            break
    if selected_group is None and groups_by_y:
        # No near-horizontal group: fall back to the lowest one.
        selected_group = groups_by_y[0][1]
        fallback_slope = _line_slope(selected_group)
        if fallback_slope is not None:
            selected_slope = fallback_slope
    if selected_group is None:
        if observe:
            print("未能找到有效的横向赛道线")
        return None, None

    # The selected group's point with the largest y (closest to the camera).
    bottom_edge_point = max(selected_group, key=lambda p: p[1])
    if observe:
        print(f"步骤5: 找到边缘点 {bottom_edge_point}")
        edge_img = img.copy()
        for point in selected_group:
            cv2.circle(edge_img, point, 3, (255, 0, 0), -1)
        cv2.circle(edge_img, bottom_edge_point, 10, (0, 0, 255), -1)
        cv2.imshow("选定的横向线和边缘点", edge_img)
        cv2.waitKey(delay)

    distance_to_center = bottom_edge_point[0] - center_x

    improved_slope, better_line_points = _refine_slope_bfs(bottom_edge_point, selected_group, selected_slope)
    slope = improved_slope
    if observe:
        improved_slope_img = img.copy()
        cv2.circle(improved_slope_img, bottom_edge_point, 10, (0, 0, 255), -1)
        for point in better_line_points:
            cv2.circle(improved_slope_img, point, 3, (255, 255, 0), -1)
        # Draw the refined line through the edge point.
        line_length = 300
        mid_x, mid_y = bottom_edge_point
        end_x = mid_x + line_length
        end_y = int(mid_y + improved_slope * line_length)
        start_x = mid_x - line_length
        start_y = int(mid_y - improved_slope * line_length)
        cv2.line(improved_slope_img, (start_x, start_y), (end_x, end_y), (0, 255, 0), 2)
        cv2.putText(improved_slope_img, f"原始斜率: {selected_slope:.4f}", (10, 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        cv2.putText(improved_slope_img, f"改进斜率: {improved_slope:.4f}", (10, 190),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        cv2.putText(improved_slope_img, f"找到点数: {len(better_line_points)}", (10, 230),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        cv2.putText(improved_slope_img, f"原始点数: {len(selected_group)}", (10, 270),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        cv2.imshow("改进的斜率计算", improved_slope_img)
        cv2.waitKey(delay)

    # Intersection of the edge line y = slope * (x - edge_x) + edge_y with the
    # vertical centre line x = center_x.
    edge_x, edge_y = bottom_edge_point
    intersection_x = center_x
    intersection_y = slope * (center_x - edge_x) + edge_y
    intersection_point = (int(intersection_x), int(intersection_y))
    # Pixels between that intersection and the bottom of the image.
    distance_to_bottom = height - intersection_y

    if observe:
        slope_img = img.copy()
        cv2.circle(slope_img, bottom_edge_point, 10, (0, 0, 255), -1)
        for point in selected_group:
            cv2.circle(slope_img, point, 3, (255, 0, 0), -1)
        line_length = 200
        end_x = bottom_edge_point[0] + line_length
        end_y = int(bottom_edge_point[1] + slope * line_length)
        start_x = bottom_edge_point[0] - line_length
        start_y = int(bottom_edge_point[1] - slope * line_length)
        cv2.line(slope_img, (start_x, start_y), (end_x, end_y), (0, 255, 0), 2)
        cv2.line(slope_img, (center_x, 0), (center_x, height), (0, 0, 255), 2)
        cv2.circle(slope_img, intersection_point, 12, (255, 0, 255), -1)
        cv2.circle(slope_img, intersection_point, 5, (255, 255, 255), -1)
        cv2.line(slope_img, intersection_point, (intersection_x, height), (255, 255, 0), 2)
        cv2.putText(slope_img, f"Slope: {slope:.4f}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(slope_img, f"Distance to center: {distance_to_center}px", (10, 70),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(slope_img, f"Distance to bottom: {distance_to_bottom:.1f}px", (10, 110),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(slope_img, f"中线交点: ({intersection_point[0]}, {intersection_point[1]})", (10, 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow("边缘斜率和中线交点", slope_img)
        cv2.imwrite("res/path/test/edge_img.png", slope_img)
        cv2.waitKey(delay)

    edge_info = {
        "x": bottom_edge_point[0],
        "y": bottom_edge_point[1],
        "distance_to_center": distance_to_center,
        "slope": slope,
        "is_horizontal": abs(slope) < 0.05,  # edge is (nearly) horizontal
        "points_count": len(selected_group),
        "intersection_point": intersection_point,  # edge line x centre line
        "distance_to_bottom": distance_to_bottom,  # intersection to image bottom
        "points": selected_group,
    }
    return bottom_edge_point, edge_info
# Usage example: python detect_track.py path/to/image.png
# Prints the (edge_point, edge_info) tuple; does nothing without an argument.
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        print(detect_horizontal_track_edge(sys.argv[1]))