mi-task/utils/bar-sky/advanced_sky_analyzer.py

195 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import os
import argparse
import matplotlib.pyplot as plt
def analyze_sky_ratio(image_path, method="threshold", debug=False, save_result=False):
    """Estimate the fraction of an image that is covered by sky.

    Args:
        image_path: Path of the image file to analyse.
        method: Detection strategy, one of "threshold", "kmeans" or
            "gradient".
        debug: When true, show a matplotlib figure with the original image,
            the sky mask, the annotated result and a grayscale histogram.
        save_result: When true, write the annotated image to
            ``results/<name>_<method>_result.jpg`` (relative to the current
            working directory).

    Returns:
        The sky ratio computed by the selected detector (sky pixels divided
        by total pixels).

    Raises:
        ValueError: If the image cannot be read or ``method`` is unknown.
        OSError: If ``save_result`` is set and the annotated image could not
            be written.
    """
    img = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising.
    if img is None:
        raise ValueError(f"无法读取图片: {image_path}")

    if method == "threshold":
        sky_mask, sky_ratio = threshold_method(img)
    elif method == "kmeans":
        sky_mask, sky_ratio = kmeans_method(img)
    elif method == "gradient":
        sky_mask, sky_ratio = gradient_method(img)
    else:
        raise ValueError(f"不支持的检测方法: {method}")

    # The annotated image is only consumed by the debug figure and the saved
    # file, so skip rendering it when neither was requested.
    if not (debug or save_result):
        return sky_ratio

    result = _render_sky_overlay(img, sky_mask, sky_ratio)
    if debug:
        _show_debug_figure(img, sky_mask, result)
    if save_result:
        _save_result_image(result, image_path, method)
    return sky_ratio


def _render_sky_overlay(img, sky_mask, sky_ratio):
    """Return a copy of ``img`` with sky pixels tinted and the ratio printed on it."""
    overlay = img.copy()
    overlay[sky_mask > 0] = [0, 255, 255]  # yellow in BGR channel order
    # Blend 40% tint with 60% original so the scene stays visible underneath.
    result = cv2.addWeighted(overlay, 0.4, img, 0.6, 0)
    cv2.putText(result, f"Sky Ratio: {sky_ratio:.2%}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
    return result


def _show_debug_figure(img, sky_mask, result):
    """Show a 2x2 figure: original, mask, annotated result, grayscale histogram."""
    plt.figure(figsize=(12, 8))

    plt.subplot(221)
    plt.title("Original Image")
    # matplotlib expects RGB while the image is held in BGR order.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    plt.subplot(222)
    plt.title("Sky Mask")
    plt.imshow(sky_mask, cmap='gray')

    plt.subplot(223)
    plt.title("Sky Detection Result")
    plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))

    plt.subplot(224)
    plt.title("Grayscale Histogram")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    plt.hist(gray.ravel(), 256, [0, 256])
    plt.xlim([0, 256])

    plt.tight_layout()
    plt.show()


def _save_result_image(result, image_path, method):
    """Write the annotated image under ``results/`` and report where it went."""
    # Output name is derived from the input file name without directory or extension.
    filename = os.path.splitext(os.path.basename(image_path))[0]
    result_dir = "results"
    os.makedirs(result_dir, exist_ok=True)
    output_path = os.path.join(result_dir, f"{filename}_{method}_result.jpg")
    # cv2.imwrite returns False on failure instead of raising; do not claim
    # success unless the file was actually written.
    if not cv2.imwrite(output_path, result):
        raise OSError(f"无法保存结果图像: {output_path}")
    print(f"结果已保存至: {output_path}")
def threshold_method(img, threshold=180):
    """Detect sky with a global brightness threshold.

    The image is converted to grayscale and every pixel brighter than
    ``threshold`` is marked as sky, on the premise that sky is usually one of
    the brighter parts of a picture.

    Args:
        img: BGR image array.
        threshold: Grayscale cut-off; pixels above it are marked as sky.

    Returns:
        A ``(sky_mask, sky_ratio)`` tuple: the mask holds 255 for sky pixels
        and 0 elsewhere, and the ratio is the number of 255-valued mask
        pixels divided by the total pixel count.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(grayscale, threshold, 255, cv2.THRESH_BINARY)[1]

    # Opening removes small bright specks, closing then fills small holes.
    structuring_element = np.ones((5, 5), np.uint8)
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        binary = cv2.morphologyEx(binary, operation, structuring_element)

    rows, cols = grayscale.shape[:2]
    ratio = (binary == 255).sum() / (rows * cols)
    return binary, ratio
def kmeans_method(img, k=3):
    """Detect sky by colour clustering, taking the brightest cluster as sky.

    All pixels are clustered into ``k`` colour groups with k-means; the
    cluster whose centre has the highest mean channel value is assumed to be
    the sky.

    Args:
        img: BGR image array.
        k: Number of colour clusters.

    Returns:
        A ``(sky_mask, sky_ratio)`` tuple: the mask holds 255 for sky pixels
        and 0 elsewhere, and the ratio is the number of 255-valued mask
        pixels divided by the total pixel count.
    """
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height, width = img_rgb.shape[:2]

    # One row per pixel, one column per channel; cv2.kmeans takes float32 samples.
    pixels = img_rgb.reshape((-1, 3)).astype(np.float32)

    # Stop after 100 iterations or once the centres move by less than 0.2.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    _, labels, centers = cv2.kmeans(
        pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Rank clusters on the float centres as returned by k-means. Truncating
    # them to uint8 first would throw away the fractional part and could
    # pick the wrong cluster when two centres are almost equally bright.
    brightness = centers.mean(axis=1)
    sky_cluster = int(np.argmax(brightness))

    # Keep only the pixels assigned to the brightest cluster.
    sky_mask = np.zeros((height, width), dtype=np.uint8)
    sky_mask[labels.reshape((height, width)) == sky_cluster] = 255

    # Opening removes small specks, closing then fills small holes.
    kernel = np.ones((5, 5), np.uint8)
    sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_OPEN, kernel)
    sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_CLOSE, kernel)

    sky_ratio = np.sum(sky_mask == 255) / (height * width)
    return sky_mask, sky_ratio
def gradient_method(img, gradient_threshold=50, sky_region_ratio=0.6):
    """Detect sky as the low-texture area in the upper part of the image.

    The detector assumes that sky has few edges and sits near the top of the
    picture: pixels with a small Sobel gradient magnitude inside the top
    ``sky_region_ratio`` of the image are marked as sky.

    Args:
        img: BGR image array.
        gradient_threshold: Upper bound (on the 0-255 normalised gradient
            magnitude) for a pixel to count as smooth enough to be sky.
        sky_region_ratio: Fraction of the image height, measured from the
            top, in which sky is searched for. Must lie in ``[0, 1]``.

    Returns:
        A ``(sky_mask, sky_ratio)`` tuple: the mask holds 255 for sky pixels
        and 0 elsewhere, and the ratio is the number of 255-valued mask
        pixels divided by the pixel count of the whole image (not just the
        searched region).

    Raises:
        ValueError: If ``sky_region_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= sky_region_ratio <= 1:
        raise ValueError(f"sky_region_ratio 必须在 0 到 1 之间: {sky_region_ratio}")

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Horizontal and vertical derivatives in float64 so negative slopes survive.
    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)

    gradient_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
    # Min-max rescaling to 0-255 makes the threshold relative to the
    # strongest edge found in this particular image.
    gradient_magnitude = cv2.normalize(
        gradient_magnitude, None, 0, 255, cv2.NORM_MINMAX
    ).astype(np.uint8)

    # Inverted threshold: low-gradient (smooth) pixels become 255.
    _, gradient_mask = cv2.threshold(
        gradient_magnitude, gradient_threshold, 255, cv2.THRESH_BINARY_INV
    )

    # Only the upper part of the image is allowed to contribute sky pixels.
    height = gradient_mask.shape[0]
    sky_region_height = int(height * sky_region_ratio)
    sky_mask = np.zeros_like(gradient_mask)
    sky_mask[:sky_region_height, :] = gradient_mask[:sky_region_height, :]

    # Opening removes small specks, closing then fills small holes.
    kernel = np.ones((5, 5), np.uint8)
    sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_OPEN, kernel)
    sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_CLOSE, kernel)

    total_pixels = gray.shape[0] * gray.shape[1]
    sky_pixels = np.sum(sky_mask == 255)
    sky_ratio = sky_pixels / total_pixels
    return sky_mask, sky_ratio
def main():
    """Command-line entry point: analyse one image and print its sky ratio.

    Parses the image path, the detection method and the ``--debug`` /
    ``--save`` flags, runs the analysis and prints the resulting ratio.
    If the analysis fails, the error message is written to stderr and the
    process exits with status 1 so that calling scripts can detect the
    failure.
    """
    parser = argparse.ArgumentParser(description='分析图片中天空区域占比')
    parser.add_argument('image_path', help='图片路径')
    parser.add_argument('--method', default='threshold', choices=['threshold', 'kmeans', 'gradient'],
                        help='天空检测方法: threshold(阈值法), kmeans(K均值聚类), gradient(梯度法)')
    parser.add_argument('--debug', action='store_true', help='显示处理过程图像')
    parser.add_argument('--save', action='store_true', help='保存处理结果图像')
    args = parser.parse_args()

    try:
        sky_ratio = analyze_sky_ratio(args.image_path, args.method, args.debug, args.save)
    except Exception as e:
        # Top-level boundary: report any analysis failure to the user and
        # signal it through a non-zero exit status (parser.exit writes the
        # message to stderr and raises SystemExit).
        parser.exit(status=1, message=f"错误: {e}\n")
    else:
        print(f"使用{args.method}方法检测的天空区域占比: {sky_ratio:.2%}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()