195 lines
7.0 KiB
Python
Executable File
195 lines
7.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import os
|
||
import argparse
|
||
import matplotlib.pyplot as plt
|
||
|
||
def analyze_sky_ratio(image_path, method="threshold", debug=False, save_result=False):
|
||
"""
|
||
分析图片中天空占整个图片的比例,支持多种检测方法
|
||
|
||
参数:
|
||
image_path: 图片路径
|
||
method: 检测方法,可选 "threshold"(阈值法), "kmeans"(K均值聚类), "gradient"(梯度法)
|
||
debug: 是否显示处理过程中的图像,用于调试
|
||
save_result: 是否保存处理结果图像
|
||
|
||
返回:
|
||
sky_ratio: 天空占比(0-1之间的浮点数)
|
||
"""
|
||
# 读取图片
|
||
img = cv2.imread(image_path)
|
||
if img is None:
|
||
raise ValueError(f"无法读取图片: {image_path}")
|
||
|
||
# 获取图片文件名(不带路径和扩展名)
|
||
filename = os.path.splitext(os.path.basename(image_path))[0]
|
||
|
||
# 根据不同方法检测天空
|
||
if method == "threshold":
|
||
sky_mask, sky_ratio = threshold_method(img)
|
||
elif method == "kmeans":
|
||
sky_mask, sky_ratio = kmeans_method(img)
|
||
elif method == "gradient":
|
||
sky_mask, sky_ratio = gradient_method(img)
|
||
else:
|
||
raise ValueError(f"不支持的检测方法: {method}")
|
||
|
||
# 在原图上标记天空区域
|
||
result = img.copy()
|
||
overlay = img.copy()
|
||
overlay[sky_mask > 0] = [0, 255, 255] # 用黄色标记天空区域
|
||
cv2.addWeighted(overlay, 0.4, img, 0.6, 0, result) # 半透明效果
|
||
|
||
# 显示检测结果信息
|
||
cv2.putText(result, f"Sky Ratio: {sky_ratio:.2%}", (10, 30),
|
||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
||
|
||
# 调试模式:显示处理过程图像
|
||
if debug:
|
||
plt.figure(figsize=(12, 8))
|
||
|
||
plt.subplot(221)
|
||
plt.title("Original Image")
|
||
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||
|
||
plt.subplot(222)
|
||
plt.title("Sky Mask")
|
||
plt.imshow(sky_mask, cmap='gray')
|
||
|
||
plt.subplot(223)
|
||
plt.title("Sky Detection Result")
|
||
plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
||
|
||
# 直方图分析
|
||
plt.subplot(224)
|
||
plt.title("Grayscale Histogram")
|
||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||
plt.hist(gray.ravel(), 256, [0, 256])
|
||
plt.xlim([0, 256])
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
# 保存结果
|
||
if save_result:
|
||
result_dir = "results"
|
||
os.makedirs(result_dir, exist_ok=True)
|
||
output_path = os.path.join(result_dir, f"{filename}_{method}_result.jpg")
|
||
cv2.imwrite(output_path, result)
|
||
print(f"结果已保存至: {output_path}")
|
||
|
||
return sky_ratio
|
||
|
||
def threshold_method(img, threshold=180):
|
||
"""使用简单阈值法检测天空"""
|
||
# 转换为灰度图
|
||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||
|
||
# 使用阈值分割天空区域(天空通常是图像中较亮的部分)
|
||
_, sky_mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
|
||
|
||
# 应用形态学操作,去除噪点
|
||
kernel = np.ones((5, 5), np.uint8)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_OPEN, kernel)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_CLOSE, kernel)
|
||
|
||
# 计算天空区域占比
|
||
total_pixels = gray.shape[0] * gray.shape[1]
|
||
sky_pixels = np.sum(sky_mask == 255)
|
||
sky_ratio = sky_pixels / total_pixels
|
||
|
||
return sky_mask, sky_ratio
|
||
|
||
def kmeans_method(img, k=3):
|
||
"""使用K均值聚类检测天空"""
|
||
# 转换为RGB格式用于聚类
|
||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||
|
||
# 重塑图像为二维数组
|
||
pixels = img_rgb.reshape((-1, 3)).astype(np.float32)
|
||
|
||
# 定义K均值聚类终止条件
|
||
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
|
||
|
||
# 应用K均值聚类
|
||
_, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
|
||
|
||
# 找出最亮的聚类中心(可能是天空)
|
||
centers = centers.astype(np.uint8)
|
||
brightness = np.mean(centers, axis=1)
|
||
sky_cluster = np.argmax(brightness)
|
||
|
||
# 创建掩码:仅保留最亮聚类的像素
|
||
sky_mask = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
|
||
labels = labels.reshape(img_rgb.shape[:2])
|
||
sky_mask[labels == sky_cluster] = 255
|
||
|
||
# 应用形态学操作,去除噪点
|
||
kernel = np.ones((5, 5), np.uint8)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_OPEN, kernel)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_CLOSE, kernel)
|
||
|
||
# 计算天空区域占比
|
||
total_pixels = sky_mask.shape[0] * sky_mask.shape[1]
|
||
sky_pixels = np.sum(sky_mask == 255)
|
||
sky_ratio = sky_pixels / total_pixels
|
||
|
||
return sky_mask, sky_ratio
|
||
|
||
def gradient_method(img):
|
||
"""使用图像梯度检测天空(假设天空在图像上部且边缘较少)"""
|
||
# 转换为灰度图
|
||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||
|
||
# 使用Sobel算子计算梯度
|
||
sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
|
||
sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
|
||
|
||
# 计算梯度幅值
|
||
gradient_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
|
||
gradient_magnitude = cv2.normalize(gradient_magnitude, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
|
||
|
||
# 假设天空区域梯度较小(纹理较少)
|
||
_, gradient_mask = cv2.threshold(gradient_magnitude, 50, 255, cv2.THRESH_BINARY_INV)
|
||
|
||
# 假设天空通常位于图像上部
|
||
height, width = gradient_mask.shape
|
||
sky_region_height = int(height * 0.6) # 假设天空在上部60%的区域
|
||
|
||
# 创建天空掩码
|
||
sky_mask = np.zeros_like(gradient_mask)
|
||
sky_mask[:sky_region_height, :] = gradient_mask[:sky_region_height, :]
|
||
|
||
# 应用形态学操作
|
||
kernel = np.ones((5, 5), np.uint8)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_OPEN, kernel)
|
||
sky_mask = cv2.morphologyEx(sky_mask, cv2.MORPH_CLOSE, kernel)
|
||
|
||
# 计算天空区域占比
|
||
total_pixels = gray.shape[0] * gray.shape[1]
|
||
sky_pixels = np.sum(sky_mask == 255)
|
||
sky_ratio = sky_pixels / total_pixels
|
||
|
||
return sky_mask, sky_ratio
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description='分析图片中天空区域占比')
|
||
parser.add_argument('image_path', help='图片路径')
|
||
parser.add_argument('--method', default='threshold', choices=['threshold', 'kmeans', 'gradient'],
|
||
help='天空检测方法: threshold(阈值法), kmeans(K均值聚类), gradient(梯度法)')
|
||
parser.add_argument('--debug', action='store_true', help='显示处理过程图像')
|
||
parser.add_argument('--save', action='store_true', help='保存处理结果图像')
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
sky_ratio = analyze_sky_ratio(args.image_path, args.method, args.debug, args.save)
|
||
print(f"使用{args.method}方法检测的天空区域占比: {sky_ratio:.2%}")
|
||
except Exception as e:
|
||
print(f"错误: {e}")
|
||
|
||
if __name__ == "__main__":
|
||
main() |