"""Template matching using OpenCV to find assets in screenshots."""
import os
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from .types import AssetMatch, BoundingBox
def load_image(path: str) -> np.ndarray:
    """Load an image from disk as an OpenCV-style array.

    GIF files are decoded with PIL instead of OpenCV and are always
    returned as 4-channel BGRA. No frame is selected explicitly, so PIL's
    default frame is used (presumably the first one for animated GIFs).
    Every other file is read with ``cv2.imread`` using
    ``cv2.IMREAD_UNCHANGED``, so the channel count is whatever OpenCV
    reports for that file.

    Args:
        path: Path to the image file.

    Returns:
        The decoded image as a numpy array.

    Raises:
        ValueError: If OpenCV cannot read a non-GIF image.

    Note:
        Errors raised by PIL while opening a GIF are not wrapped and
        propagate to the caller unchanged.
    """
    image_path = Path(path)

    # Handle GIFs with PIL first (OpenCV doesn't handle animated GIFs well)
    if image_path.suffix.lower() == '.gif':
        with Image.open(image_path) as gif:
            # GIFs might be palette-based; normalising to RGBA guarantees a
            # 4-channel array, so no separate 3-channel branch is needed.
            rgba = np.array(gif.convert('RGBA'))
        # Reorder channels to OpenCV's BGRA layout
        return cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA)

    # Regular images
    img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising
        raise ValueError(f"Could not load image: {image_path}")
    return img
def _to_grayscale(image: np.ndarray) -> np.ndarray:
    """Return a single-channel version of ``image`` for template matching."""
    if image.ndim == 3 and image.shape[2] == 4:
        # Alpha channel present: it is dropped by the conversion
        return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
    if image.ndim == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Already two-dimensional: use as-is
    return image


def find_asset_in_screenshot(
    screenshot_path: str,
    asset_path: str,
    threshold: float = 0.8,
) -> AssetMatch | None:
    """
    Find a single asset within a screenshot using template matching.

    Both images are loaded, reduced to grayscale, and compared with
    ``cv2.matchTemplate`` using ``cv2.TM_CCOEFF_NORMED``. Only the single
    best-scoring location is considered.

    Args:
        screenshot_path: Path to the screenshot image
        asset_path: Path to the asset/template image to find
        threshold: Minimum confidence score (0-1) for a match

    Returns:
        AssetMatch if found above threshold, None otherwise. The bounding
        box uses the location of the highest score (the top-left corner of
        the matched region in OpenCV's matchTemplate convention) and the
        template's own width and height. None is also returned when the
        template is larger than the screenshot in either dimension, since
        it cannot be contained in it.

    Raises:
        ValueError: Propagated from ``load_image`` when an image cannot
            be read.
    """
    # Load images and convert both to the same single-channel format
    screenshot_gray = _to_grayscale(load_image(screenshot_path))
    template_gray = _to_grayscale(load_image(asset_path))

    # Get dimensions
    template_h, template_w = template_gray.shape[:2]
    screenshot_h, screenshot_w = screenshot_gray.shape[:2]

    # cv2.matchTemplate raises when the template does not fit inside the
    # image; report "not found" instead of crashing the caller's scan.
    if template_h > screenshot_h or template_w > screenshot_w:
        return None

    # Template matching
    result = cv2.matchTemplate(
        screenshot_gray, template_gray, cv2.TM_CCOEFF_NORMED
    )

    # Find best match; only the maximum matters for TM_CCOEFF_NORMED
    _, best_score, _, best_loc = cv2.minMaxLoc(result)

    if best_score >= threshold:
        return AssetMatch(
            asset_path=asset_path,
            asset_name=os.path.basename(asset_path),
            bbox=BoundingBox(
                x=best_loc[0],
                y=best_loc[1],
                width=template_w,
                height=template_h,
            ),
            confidence=float(best_score),
        )

    return None
def find_all_assets(
    screenshot_path: str,
    asset_paths: list[str],
    threshold: float = 0.8,
) -> list[AssetMatch]:
    """
    Search a screenshot for each of the given assets.

    Every asset path is passed to ``find_asset_in_screenshot`` in the
    order given; truthy results are collected and the rest are dropped.

    Args:
        screenshot_path: Path to the screenshot image
        asset_paths: List of paths to asset images to find
        threshold: Minimum confidence score (0-1) for matches

    Returns:
        List of AssetMatch for all found assets, in the same order as
        ``asset_paths``. Empty when nothing matched or no paths were given.
    """
    return [
        found
        for candidate in asset_paths
        if (found := find_asset_in_screenshot(screenshot_path, candidate, threshold))
    ]
def get_screenshot_dimensions(screenshot_path: str) -> tuple[int, int]:
    """Return the size of a screenshot as ``(width, height)``.

    The image is loaded with ``load_image`` and the first two entries of
    its array shape (rows, columns) are swapped into width-first order.
    """
    height, width = load_image(screenshot_path).shape[:2]
    return (width, height)