"""Layout analysis - detect spatial relationships between elements."""
import math
from dataclasses import dataclass
from .types import (
AssetMatch,
ElementPosition,
GridInfo,
LayoutAnalysis,
LayoutPattern,
RadialInfo,
SidebarInfo,
StackedInfo,
)
from .template_match import find_all_assets, get_screenshot_dimensions
def calculate_angle(center: tuple[float, float], point: tuple[float, float]) -> float:
    """
    Calculate angle from center to point in degrees, in the range [0, 360).

    Angles follow screen coordinates (y grows downward):
    0 = right, 90 = down, 180 = left, 270 = up.
    """
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    angle = math.degrees(math.atan2(dy, dx))
    if angle < 0:
        angle += 360
    # A tiny negative angle plus 360 rounds to exactly 360.0; that direction
    # is "just below 0", so fold it back to keep the result inside [0, 360).
    return 0.0 if angle >= 360 else angle
def calculate_distance(p1: tuple[float, float], p2: tuple[float, float]) -> float:
    """Return the straight-line (Euclidean) distance between points *p1* and *p2*."""
    delta_x = p2[0] - p1[0]
    delta_y = p2[1] - p1[1]
    squared_length = delta_x ** 2 + delta_y ** 2
    return math.sqrt(squared_length)
def cluster_values(values: list[float], threshold: float) -> list[list[float]]:
    """
    Split *values* into ascending groups of near neighbours.

    Values are visited in sorted order; each one joins the current group when
    it is at most *threshold* above the previous value, otherwise it starts a
    new group. Groups therefore chain: a group can span more than *threshold*
    end to end. An empty input yields an empty list.
    """
    groups: list[list[float]] = []
    for value in sorted(values):
        if groups and value - groups[-1][-1] <= threshold:
            groups[-1].append(value)
        else:
            groups.append([value])
    return groups
def cluster_centers(clusters: list[list[float]]) -> list[float]:
    """Return the arithmetic mean of each cluster, preserving cluster order."""
    centers: list[float] = []
    for group in clusters:
        total = sum(group)
        centers.append(total / len(group))
    return centers
# =============================================================================
# Pattern Detection Functions
# =============================================================================
@dataclass
class PatternScore:
    """
    Outcome of a single layout-pattern detector.

    Detectors signal early rejection with a confidence of 0.0 and an empty
    ``info`` dict; ``analyze_layout`` compares confidences across detectors
    to choose the best-fitting pattern.
    """

    # The layout pattern this score was computed for.
    pattern: LayoutPattern
    # How well the matches fit the pattern; higher is a better fit.
    confidence: float
    # Detector-specific details (e.g. center point, grid clusters, sections).
    info: dict
def _angular_coverage(angles: list[float]) -> float:
    """
    Return the arc, in degrees, that *angles* occupy on the circle.

    The arc is 360 minus the widest empty gap between neighbouring angles,
    including the gap that wraps from the largest angle back past 0/360.
    Unlike ``max(angles) - min(angles)``, this treats 350 and 10 degrees as
    20 degrees apart instead of 340. Returns 0.0 for fewer than two angles.
    """
    if len(angles) < 2:
        return 0.0
    ordered = sorted(angles)
    gaps = [b - a for a, b in zip(ordered, ordered[1:])]
    gaps.append(ordered[0] + 360.0 - ordered[-1])  # wrap-around gap
    return 360.0 - max(gaps)


def detect_radial_pattern(
    matches: list[AssetMatch],
    viewport_center: tuple[float, float],
) -> PatternScore:
    """
    Detect if elements are arranged radially around a center point.

    The center candidate is the match maximizing
    ``0.4 * closeness_to_viewport_center + 0.6 * area / largest_area``, where
    closeness is normalised by the distance from the origin to
    *viewport_center*. The candidate must lie within 35% of that distance from
    *viewport_center*. Confidence falls as the other matches' distances to the
    candidate become less uniform (coefficient of variation), with a 10% bonus,
    capped at 1, when they cover more than 180 degrees around it.

    Args:
        matches: Detected assets to analyse.
        viewport_center: Reference center point, e.g. ``(width / 2, height / 2)``.

    Returns:
        A RADIAL ``PatternScore``. ``info`` is empty on early rejection;
        otherwise it holds ``center_element``, ``center_point``,
        ``average_radius`` and ``other_matches``.
    """
    if len(matches) < 3:
        return PatternScore(LayoutPattern.RADIAL, 0.0, {})

    # Normaliser for distances from the viewport center.
    max_distance = math.sqrt(viewport_center[0] ** 2 + viewport_center[1] ** 2)
    if max_distance == 0:
        # Degenerate viewport at the origin; center_score would divide by zero.
        return PatternScore(LayoutPattern.RADIAL, 0.0, {})

    max_area = max(m.bbox.area for m in matches)
    if max_area <= 0:
        # All boxes are empty; avoid dividing by zero (every size_score is 0).
        max_area = 1

    def center_score(match: AssetMatch) -> float:
        dist = calculate_distance(match.bbox.center, viewport_center)
        dist_score = 1 - (dist / max_distance)
        size_score = match.bbox.area / max_area
        return dist_score * 0.4 + size_score * 0.6

    # Best center candidate: large and close to the viewport center.
    center_element = max(matches, key=center_score)
    center_point = center_element.bbox.center

    center_dist = calculate_distance(center_point, viewport_center)
    if center_dist > max_distance * 0.35:
        # Center element too far from viewport center - not radial.
        return PatternScore(LayoutPattern.RADIAL, 0.0, {})

    other_matches = [m for m in matches if m != center_element]
    if len(other_matches) < 2:
        return PatternScore(LayoutPattern.RADIAL, 0.0, {})

    distances = [calculate_distance(m.bbox.center, center_point) for m in other_matches]
    mean_dist = sum(distances) / len(distances)
    if mean_dist == 0:
        return PatternScore(LayoutPattern.RADIAL, 0.0, {})

    # Distance consistency via the coefficient of variation:
    # CV < 0.3 = consistent distances, CV > 0.6 = very inconsistent.
    variance = sum((d - mean_dist) ** 2 for d in distances) / len(distances)
    cv = math.sqrt(variance) / mean_dist
    confidence = max(0.0, min(1.0, 1 - cv * 1.5))

    # Bonus when the surrounding elements wrap around the center instead of
    # clustering on one side. Coverage is measured on the circle so that
    # angles on either side of 0/360 count as close together.
    angles = [calculate_angle(center_point, m.bbox.center) for m in other_matches]
    if _angular_coverage(angles) > 180:
        confidence = min(1.0, confidence * 1.1)

    return PatternScore(
        LayoutPattern.RADIAL,
        confidence,
        {
            "center_element": center_element,
            "center_point": center_point,
            "average_radius": mean_dist,
            "other_matches": other_matches,
        },
    )
def detect_grid_pattern(matches: list[AssetMatch]) -> PatternScore:
    """
    Detect if elements are arranged in a grid.

    Match centers are clustered along X into columns and along Y into rows,
    using half the average match width/height as the gap threshold. At least
    four matches and two columns are required; a single row is accepted.

    Confidence is ``0.4 * fill + 0.3 * column_evenness + 0.3 * row_evenness``,
    halved when fewer than half of the cells are occupied or when neither axis
    is evenly spaced. ``fill`` is the fraction of the ``columns * rows`` cells
    that hold at least one match, so confidence stays within [0, 1].

    Returns:
        A GRID ``PatternScore``. ``info`` is empty on early rejection;
        otherwise it holds ``columns``, ``rows``, ``x_centers``,
        ``y_centers``, ``gap_x``/``gap_y`` (mean center-to-center distance
        minus the average match width/height; 0.0 along an axis with a single
        column/row) and the raw ``x_clusters``/``y_clusters``.
    """
    if len(matches) < 4:
        return PatternScore(LayoutPattern.GRID, 0.0, {})

    x_coords = [m.bbox.center_x for m in matches]
    y_coords = [m.bbox.center_y for m in matches]

    # Use element sizes to determine the clustering threshold.
    avg_width = sum(m.bbox.width for m in matches) / len(matches)
    avg_height = sum(m.bbox.height for m in matches) / len(matches)
    x_clusters = cluster_values(x_coords, avg_width * 0.5)
    y_clusters = cluster_values(y_coords, avg_height * 0.5)

    cols = len(x_clusters)
    rows = len(y_clusters)
    if cols < 2 or rows < 1:
        return PatternScore(LayoutPattern.GRID, 0.0, {})

    x_centers = cluster_centers(x_clusters)
    y_centers = cluster_centers(y_clusters)

    # Fraction of cells holding at least one match. Dividing the raw match
    # count by the cell count would exceed 1 whenever several matches share a
    # cell and push the confidence above 1.
    col_of = {x: i for i, cluster in enumerate(x_clusters) for x in cluster}
    row_of = {y: i for i, cluster in enumerate(y_clusters) for y in cluster}
    occupied_cells = {(col_of[x], row_of[y]) for x, y in zip(x_coords, y_coords)}
    fill_ratio = len(occupied_cells) / (cols * rows)

    def gaps_between(centers: list[float]) -> list[float]:
        return [b - a for a, b in zip(centers, centers[1:])]

    def spacing_consistency(centers: list[float]) -> float:
        # 1.0 for evenly spaced centers, decreasing with the coefficient of
        # variation of consecutive gaps, floored at 0.
        gaps = gaps_between(centers)
        if not gaps:
            return 1.0
        mean_gap = sum(gaps) / len(gaps)
        if mean_gap == 0:
            return 0.0
        variance = sum((g - mean_gap) ** 2 for g in gaps) / len(gaps)
        return max(0, 1 - math.sqrt(variance) / mean_gap)

    def mean_edge_gap(centers: list[float], avg_size: float) -> float:
        # Mean center-to-center distance minus the average element size;
        # 0.0 when there is only one column/row and hence no gap at all.
        gaps = gaps_between(centers)
        return sum(gaps) / len(gaps) - avg_size if gaps else 0.0

    x_consistency = spacing_consistency(x_centers)
    y_consistency = spacing_consistency(y_centers)

    confidence = fill_ratio * 0.4 + x_consistency * 0.3 + y_consistency * 0.3
    # Must have decent fill and consistency.
    if fill_ratio < 0.5 or (x_consistency < 0.5 and y_consistency < 0.5):
        confidence *= 0.5

    return PatternScore(
        LayoutPattern.GRID,
        confidence,
        {
            "columns": cols,
            "rows": rows,
            "x_centers": x_centers,
            "y_centers": y_centers,
            "gap_x": mean_edge_gap(x_centers, avg_width),
            "gap_y": mean_edge_gap(y_centers, avg_height),
            "x_clusters": x_clusters,
            "y_clusters": y_clusters,
        },
    )
def detect_stacked_pattern(
    matches: list[AssetMatch],
    viewport_height: int,
) -> PatternScore:
    """
    Detect if elements are arranged in vertical sections (header, main, footer).

    Match centers are clustered along Y into horizontal bands, merging
    neighbours at most 1.5x the average match height apart; at least two bands
    are required. Confidence is ``0.6 * row_score + 0.4 * coverage_score``,
    where ``row_score`` saturates at three bands and ``coverage_score`` rewards
    bands whose matches span a wide horizontal extent.

    Returns:
        A STACKED ``PatternScore``. ``info`` is empty on early rejection;
        otherwise it holds ``sections`` (one dict per band, top to bottom,
        with ``name``, ``y_center`` and ``y_range`` = min/max match center Y)
        and the raw ``y_clusters``. With at most three bands they are named
        header/main/footer in order (two bands become header and main);
        otherwise ``section_<i>``.
    """
    if len(matches) < 2:
        return PatternScore(LayoutPattern.STACKED, 0.0, {})

    # Look for clear horizontal bands.
    y_coords = [m.bbox.center_y for m in matches]
    avg_height = sum(m.bbox.height for m in matches) / len(matches)
    y_clusters = cluster_values(y_coords, avg_height * 1.5)
    if len(y_clusters) < 2:
        return PatternScore(LayoutPattern.STACKED, 0.0, {})

    # Assign each match to the band its own center Y was clustered into.
    # A proximity test against band members (distance < avg_height) would
    # drop every match when avg_height is 0.
    band_of = {y: i for i, cluster in enumerate(y_clusters) for y in cluster}
    bands = [[] for _ in y_clusters]
    for match, y in zip(matches, y_coords):
        bands[band_of[y]].append(match)

    # Horizontal extent of each band relative to the viewport; stacked
    # layouts typically have bands spanning the width.
    # NOTE(review): only the viewport height is passed in, so the extent is
    # divided by the height as a rough stand-in for the width. On landscape
    # screens this overestimates coverage.
    viewport_coverage = []
    if viewport_height > 0:
        for band in bands:
            if not band:
                continue
            min_x = min(m.bbox.x for m in band)
            max_x = max(m.bbox.right for m in band)
            viewport_coverage.append((max_x - min_x) / viewport_height)
    avg_coverage = sum(viewport_coverage) / len(viewport_coverage) if viewport_coverage else 0

    row_score = min(1, len(y_clusters) / 3)  # 3+ rows is good
    coverage_score = min(1, avg_coverage * 2)  # Higher coverage = more stacked-like
    confidence = row_score * 0.6 + coverage_score * 0.4

    y_centers = cluster_centers(y_clusters)
    if len(y_centers) <= 3:
        section_names = ["header", "main", "footer"]
    else:
        section_names = [f"section_{i}" for i in range(len(y_centers))]
    sections = [
        {
            "name": section_names[i],
            "y_center": y_center,
            "y_range": (min(cluster), max(cluster)),
        }
        for i, (y_center, cluster) in enumerate(zip(y_centers, y_clusters))
    ]

    return PatternScore(
        LayoutPattern.STACKED,
        confidence,
        {
            "sections": sections,
            "y_clusters": y_clusters,
        },
    )
def detect_sidebar_pattern(
    matches: list[AssetMatch],
    viewport_width: int,
) -> PatternScore:
    """
    Detect if elements are arranged in a sidebar + main content layout.

    Matches whose center X lies in the left 30% or the right 30% of the
    viewport form the two candidate columns; matches in the middle 40% are
    ignored. Both columns must be non-empty, and their horizontal extents
    must differ enough (narrower / wider <= 0.7). The narrower column is
    taken as the sidebar.

    Confidence is ``0.6 * (1 - width_ratio) + 0.4 * vertical_score``, where
    ``vertical_score`` is the sidebar's vertical extent over half the
    viewport width, capped at 1.

    Returns:
        A SIDEBAR ``PatternScore``. ``info`` is empty on rejection; otherwise
        it holds ``sidebar_side`` ("left"/"right"), ``sidebar_width``,
        ``sidebar_width_percent``, ``main_width``, ``main_width_percent``,
        ``sidebar_elements`` and ``main_elements``.
    """
    if len(matches) < 3 or viewport_width <= 0:
        return PatternScore(LayoutPattern.SIDEBAR, 0.0, {})

    midpoint = viewport_width / 2
    left_limit = midpoint * 0.6  # 30% of the viewport width
    right_limit = midpoint * 1.4  # 70% of the viewport width
    left_elements = [m for m in matches if m.bbox.center_x < left_limit]
    right_elements = [m for m in matches if m.bbox.center_x > right_limit]
    if not left_elements or not right_elements:
        return PatternScore(LayoutPattern.SIDEBAR, 0.0, {})

    def horizontal_extent(group: list[AssetMatch]) -> float:
        return max(m.bbox.right for m in group) - min(m.bbox.x for m in group)

    left_spread = horizontal_extent(left_elements)
    right_spread = horizontal_extent(right_elements)

    # Sidebar typically takes 20-35% of width, main content 65-80%.
    widest = max(left_spread, right_spread)
    width_ratio = min(left_spread, right_spread) / widest if widest > 0 else 1
    if width_ratio > 0.7:  # Too symmetric
        return PatternScore(LayoutPattern.SIDEBAR, 0.0, {})

    if left_spread < right_spread:
        sidebar_side = "left"
        sidebar_width, main_width = left_spread, right_spread
        sidebar_elements, main_elements = left_elements, right_elements
    else:
        sidebar_side = "right"
        sidebar_width, main_width = right_spread, left_spread
        sidebar_elements, main_elements = right_elements, left_elements

    # A sidebar should be stacked vertically.
    # NOTE(review): normalised by viewport width because the height is not
    # passed in.
    sidebar_y_spread = (
        max(m.bbox.bottom for m in sidebar_elements)
        - min(m.bbox.y for m in sidebar_elements)
    )
    asymmetry_score = 1 - width_ratio
    vertical_score = min(1, sidebar_y_spread / (viewport_width * 0.5))
    confidence = asymmetry_score * 0.6 + vertical_score * 0.4

    return PatternScore(
        LayoutPattern.SIDEBAR,
        confidence,
        {
            "sidebar_side": sidebar_side,
            "sidebar_width": sidebar_width,
            "sidebar_width_percent": (sidebar_width / viewport_width) * 100,
            "main_width": main_width,
            "main_width_percent": (main_width / viewport_width) * 100,
            "sidebar_elements": sidebar_elements,
            "main_elements": main_elements,
        },
    )
# =============================================================================
# Main Analysis Function
# =============================================================================
def _freeform_layout(
    matches: list[AssetMatch],
    viewport_w: int,
    viewport_h: int,
    confidence: float,
) -> LayoutAnalysis:
    """Build a FREEFORM analysis listing every match at its top-left corner."""
    return LayoutAnalysis(
        viewport_width=viewport_w,
        viewport_height=viewport_h,
        pattern=LayoutPattern.FREEFORM,
        pattern_confidence=confidence,
        elements=[ElementPosition(asset_match=m, x=m.bbox.x, y=m.bbox.y) for m in matches],
    )


def _radial_layout(
    matches: list[AssetMatch],
    score: PatternScore,
    viewport_w: int,
    viewport_h: int,
) -> LayoutAnalysis:
    """Build a RADIAL analysis: every non-center match with its angle and distance from the center."""
    info = score.info
    center_point = info["center_point"]
    center_element = info["center_element"]

    # The center element itself is reported through radial_info only.
    elements = [
        ElementPosition(
            asset_match=m,
            x=m.bbox.x,
            y=m.bbox.y,
            angle_degrees=calculate_angle(center_point, m.bbox.center),
            distance_from_center=calculate_distance(center_point, m.bbox.center),
        )
        for m in matches
        if m != center_element
    ]
    return LayoutAnalysis(
        viewport_width=viewport_w,
        viewport_height=viewport_h,
        pattern=LayoutPattern.RADIAL,
        pattern_confidence=score.confidence,
        elements=elements,
        radial_info=RadialInfo(
            center_x=center_point[0],
            center_y=center_point[1],
            center_element=center_element,
            average_radius=info["average_radius"],
        ),
    )


def _grid_layout(
    matches: list[AssetMatch],
    score: PatternScore,
    viewport_w: int,
    viewport_h: int,
) -> LayoutAnalysis:
    """Build a GRID analysis, snapping each match to its nearest column and row center."""
    info = score.info
    x_centers = list(info["x_centers"])
    y_centers = list(info["y_centers"])

    def nearest_index(centers: list[float], value: float) -> int:
        # Ties go to the lower index, as min() keeps the first minimum.
        return min(range(len(centers)), key=lambda i: abs(value - centers[i]))

    elements = [
        ElementPosition(
            asset_match=m,
            x=m.bbox.x,
            y=m.bbox.y,
            grid_row=nearest_index(y_centers, m.bbox.center_y),
            grid_column=nearest_index(x_centers, m.bbox.center_x),
        )
        for m in matches
    ]

    # Every column/row is given the average match size.
    avg_width = sum(m.bbox.width for m in matches) / len(matches)
    avg_height = sum(m.bbox.height for m in matches) / len(matches)
    return LayoutAnalysis(
        viewport_width=viewport_w,
        viewport_height=viewport_h,
        pattern=LayoutPattern.GRID,
        pattern_confidence=score.confidence,
        elements=elements,
        grid_info=GridInfo(
            columns=info["columns"],
            rows=info["rows"],
            column_positions=x_centers,
            row_positions=y_centers,
            column_widths=[avg_width] * info["columns"],
            row_heights=[avg_height] * info["rows"],
            gap_x=max(0, info["gap_x"]),
            gap_y=max(0, info["gap_y"]),
        ),
    )


def _stacked_layout(
    matches: list[AssetMatch],
    score: PatternScore,
    viewport_w: int,
    viewport_h: int,
) -> LayoutAnalysis:
    """Build a STACKED analysis, assigning each match to the section with the nearest center Y."""
    sections = score.info["sections"]
    members = [[] for _ in sections]

    elements = []
    for m in matches:
        # Nearest section center; ties go to the earlier (upper) section.
        section_idx = min(
            range(len(sections)),
            key=lambda i: abs(m.bbox.center_y - sections[i]["y_center"]),
        )
        members[section_idx].append(m)
        elements.append(ElementPosition(
            asset_match=m,
            x=m.bbox.x,
            y=m.bbox.y,
            section=sections[section_idx]["name"],
            vertical_order=section_idx,
        ))

    section_data = []
    for section, section_members in zip(sections, members):
        if section_members:
            # Span the members' bounding boxes so a section has real height,
            # rather than the range of their center points.
            y_start = min(m.bbox.y for m in section_members)
            y_end = max(m.bbox.bottom for m in section_members)
        else:
            y_start, y_end = section["y_range"]
        section_data.append({
            "name": section["name"],
            "y_start": int(y_start),
            "y_end": int(y_end),
            "elements": [m.asset_name for m in section_members],
        })

    return LayoutAnalysis(
        viewport_width=viewport_w,
        viewport_height=viewport_h,
        pattern=LayoutPattern.STACKED,
        pattern_confidence=score.confidence,
        elements=elements,
        stacked_info=StackedInfo(sections=section_data),
    )


def _sidebar_layout(
    matches: list[AssetMatch],
    score: PatternScore,
    viewport_w: int,
    viewport_h: int,
) -> LayoutAnalysis:
    """Build a SIDEBAR analysis, tagging each match as sidebar, main or other."""
    info = score.info
    # Identity sets: the detector's lists hold the very objects from `matches`.
    sidebar_ids = {id(m) for m in info["sidebar_elements"]}
    main_ids = {id(m) for m in info["main_elements"]}

    elements = []
    for m in matches:
        if id(m) in sidebar_ids:
            region = "sidebar"
        elif id(m) in main_ids:
            region = "main"
        else:
            region = "other"
        elements.append(ElementPosition(asset_match=m, x=m.bbox.x, y=m.bbox.y, region=region))

    return LayoutAnalysis(
        viewport_width=viewport_w,
        viewport_height=viewport_h,
        pattern=LayoutPattern.SIDEBAR,
        pattern_confidence=score.confidence,
        elements=elements,
        sidebar_info=SidebarInfo(
            sidebar_side=info["sidebar_side"],
            sidebar_width=info["sidebar_width"],
            sidebar_width_percent=info["sidebar_width_percent"],
            main_width=info["main_width"],
            main_width_percent=info["main_width_percent"],
        ),
    )


def analyze_layout(
    screenshot_path: str,
    asset_paths: list[str],
    threshold: float = 0.8,
    min_confidence: float = 0.3,
) -> LayoutAnalysis:
    """
    Analyze the layout of assets in a screenshot.

    Auto-detects the layout pattern and returns appropriate structure.

    Args:
        screenshot_path: Screenshot to analyze.
        asset_paths: Asset images to locate in the screenshot.
        threshold: Match threshold forwarded to ``find_all_assets``.
        min_confidence: If the best pattern scores below this, the layout is
            reported as FREEFORM with confidence ``1 - best_confidence``.

    Returns:
        A ``LayoutAnalysis`` for the highest-scoring pattern among radial,
        grid, stacked and sidebar (earlier ones win ties), carrying that
        pattern's info object. With no matches, an empty FREEFORM analysis
        with confidence 0.0.
    """
    viewport_w, viewport_h = get_screenshot_dimensions(screenshot_path)
    viewport_center = (viewport_w / 2, viewport_h / 2)

    matches = find_all_assets(screenshot_path, asset_paths, threshold)
    if not matches:
        return _freeform_layout([], viewport_w, viewport_h, 0.0)

    scores = [
        detect_radial_pattern(matches, viewport_center),
        detect_grid_pattern(matches),
        detect_stacked_pattern(matches, viewport_h),
        detect_sidebar_pattern(matches, viewport_w),
    ]
    # max() returns the first of equally scored entries, so list order breaks ties.
    best = max(scores, key=lambda s: s.confidence)

    if best.confidence < min_confidence:
        return _freeform_layout(matches, viewport_w, viewport_h, 1.0 - best.confidence)

    if best.pattern == LayoutPattern.RADIAL:
        return _radial_layout(matches, best, viewport_w, viewport_h)
    if best.pattern == LayoutPattern.GRID:
        return _grid_layout(matches, best, viewport_w, viewport_h)
    if best.pattern == LayoutPattern.STACKED:
        return _stacked_layout(matches, best, viewport_w, viewport_h)
    if best.pattern == LayoutPattern.SIDEBAR:
        return _sidebar_layout(matches, best, viewport_w, viewport_h)

    # Not reached with the four detectors above; kept as a defensive default.
    return _freeform_layout(matches, viewport_w, viewport_h, 1.0)