Files
2025-12-11 22:22:12 +08:00

211 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
将3x3网格图片切分为9张独立图片
基于行/列像素方差检测分隔线位置
"""
import os
from pathlib import Path
import numpy as np
from PIL import Image
def find_split_line(variance: np.ndarray, start: int, end: int) -> tuple[int, int]:
"""
在指定区间内找到方差最小的连续区域(分隔线位置)
返回 (line_start, line_end)
"""
# 在区间内找最小值位置
region = variance[start:end]
min_idx = start + np.argmin(region)
min_val = variance[min_idx]
# 向两侧扩展直到方差明显上升超过最小值的3倍或绝对阈值
threshold = max(min_val * 3, 50)
line_start = min_idx
while line_start > 0 and variance[line_start - 1] < threshold:
line_start -= 1
line_end = min_idx
while line_end < len(variance) - 1 and variance[line_end + 1] < threshold:
line_end += 1
return line_start, line_end + 1 # 返回左闭右开区间
def find_split_line_by_gradient(gray: np.ndarray, axis: int, start: int, end: int, is_bright_line: bool) -> tuple[int, int]:
"""
基于梯度突变检测分隔线边界
axis: 0=检测水平线, 1=检测垂直线
is_bright_line: True=亮色分隔线(如白色), False=暗色分隔线(如黑色)
"""
if axis == 0:
means = np.mean(gray, axis=1) # 每行的平均亮度
else:
means = np.mean(gray, axis=0) # 每列的平均亮度
# 计算相邻行/列的亮度差异(梯度)
gradient = np.diff(means)
# 找分隔线中心:在区间内找亮度极值
region = means[start:end]
if is_bright_line:
center_idx = start + np.argmax(region) # 亮色线找最亮点
else:
center_idx = start + np.argmin(region) # 暗色线找最暗点
# 从中心向左找边界:寻找梯度符号变化点(亮度突变)
line_start = center_idx
for i in range(center_idx - 1, max(0, center_idx - 30), -1):
# 对于亮色线:边界处 gradient > 0从暗到亮
# 对于暗色线:边界处 gradient < 0从亮到暗
if is_bright_line and gradient[i] > 3:
line_start = i + 1
break
elif not is_bright_line and gradient[i] < -3:
line_start = i + 1
break
line_start = i
# 从中心向右找边界
line_end = center_idx
for i in range(center_idx, min(len(gradient), center_idx + 30)):
if is_bright_line and gradient[i] < -3:
line_end = i + 1
break
elif not is_bright_line and gradient[i] > 3:
line_end = i + 1
break
line_end = i + 1
return line_start, line_end
def detect_split_lines_forest(image: Image.Image) -> tuple[list, list]:
"""专门处理 forest 的白色宽分隔线"""
gray = np.array(image.convert('L'), dtype=np.float32)
h, w = gray.shape
h_line1 = find_split_line_by_gradient(gray, 0, h // 4, h // 2 + h // 8, is_bright_line=True)
h_line2 = find_split_line_by_gradient(gray, 0, h // 2 + h // 8, 3 * h // 4, is_bright_line=True)
v_line1 = find_split_line_by_gradient(gray, 1, w // 4, w // 2 + w // 8, is_bright_line=True)
v_line2 = find_split_line_by_gradient(gray, 1, w // 2 + w // 8, 3 * w // 4, is_bright_line=True)
return [h_line1, h_line2], [v_line1, v_line2]
def detect_split_lines_snow_mountain(image: Image.Image) -> tuple[list, list]:
"""专门处理 snow_mountain 的黑色细分隔线"""
gray = np.array(image.convert('L'), dtype=np.float32)
h, w = gray.shape
h_line1 = find_split_line_by_gradient(gray, 0, h // 4, h // 2 + h // 8, is_bright_line=False)
h_line2 = find_split_line_by_gradient(gray, 0, h // 2 + h // 8, 3 * h // 4, is_bright_line=False)
v_line1 = find_split_line_by_gradient(gray, 1, w // 4, w // 2 + w // 8, is_bright_line=False)
v_line2 = find_split_line_by_gradient(gray, 1, w // 2 + w // 8, 3 * w // 4, is_bright_line=False)
return [h_line1, h_line2], [v_line1, v_line2]
def detect_split_lines(image: Image.Image) -> tuple[list, list]:
"""
检测水平和垂直分隔线的位置
返回 (h_lines, v_lines),每个是 [(start1, end1), (start2, end2)]
"""
gray = np.array(image.convert('L'), dtype=np.float32)
h, w = gray.shape
# 计算每行的方差
row_variance = np.var(gray, axis=1)
# 计算每列的方差
col_variance = np.var(gray, axis=0)
# 平滑处理,减少噪点影响
kernel_size = 3
row_variance = np.convolve(row_variance, np.ones(kernel_size)/kernel_size, mode='same')
col_variance = np.convolve(col_variance, np.ones(kernel_size)/kernel_size, mode='same')
# 在约 1/4~1/2 区间找第一条分隔线1/2~3/4 区间找第二条
h_line1 = find_split_line(row_variance, h // 4, h // 2 + h // 8)
h_line2 = find_split_line(row_variance, h // 2 + h // 8, 3 * h // 4)
v_line1 = find_split_line(col_variance, w // 4, w // 2 + w // 8)
v_line2 = find_split_line(col_variance, w // 2 + w // 8, 3 * w // 4)
return [h_line1, h_line2], [v_line1, v_line2]
def split_image(image: Image.Image, h_lines: list, v_lines: list) -> list[Image.Image]:
"""
根据分隔线位置切分图片为9块
"""
w, h = image.size
# 计算切分边界:[0, line1_start, line1_end, line2_start, line2_end, total]
y_bounds = [0, h_lines[0][0], h_lines[0][1], h_lines[1][0], h_lines[1][1], h]
x_bounds = [0, v_lines[0][0], v_lines[0][1], v_lines[1][0], v_lines[1][1], w]
tiles = []
# 取索引 0, 2, 4 对应的区域(跳过分隔线区域 1, 3
for row_idx in [0, 2, 4]:
for col_idx in [0, 2, 4]:
left = x_bounds[col_idx]
right = x_bounds[col_idx + 1]
top = y_bounds[row_idx]
bottom = y_bounds[row_idx + 1]
tile = image.crop((left, top, right, bottom))
tiles.append(tile)
return tiles
def process_image(input_path: Path, output_dir: Path):
"""处理单张图片"""
image = Image.open(input_path)
name = input_path.stem
# 特殊图像使用专门的检测方法
if name == 'forest':
h_lines, v_lines = detect_split_lines_forest(image)
elif name == 'snow_mountain':
h_lines, v_lines = detect_split_lines_snow_mountain(image)
else:
h_lines, v_lines = detect_split_lines(image)
print(f"{name}: 水平分隔线 {h_lines}, 垂直分隔线 {v_lines}")
# 切分
tiles = split_image(image, h_lines, v_lines)
# 保存
for i, tile in enumerate(tiles):
output_path = output_dir / f"{name}_{i}.png"
tile.save(output_path)
print(f" 保存: {output_path.name} ({tile.size[0]}x{tile.size[1]})")
def main():
script_dir = Path(__file__).parent
input_dir = script_dir / "origin2"
output_dir = script_dir / "split2"
output_dir.mkdir(exist_ok=True)
# 支持的图片格式
extensions = {'.jpg', '.jpeg', '.png'}
for path in sorted(input_dir.iterdir()):
if path.suffix.lower() in extensions:
process_image(path, output_dir)
print(f"\n完成!输出目录: {output_dir}")
if __name__ == "__main__":
main()