import asyncio
import logging
import time
from typing import TYPE_CHECKING
from cdp_use.cdp.accessibility.commands import GetFullAXTreeReturns
from cdp_use.cdp.accessibility.types import AXNode
from cdp_use.cdp.dom.types import Node
from cdp_use.cdp.target import TargetID
from browser_use.dom.enhanced_snapshot import (
REQUIRED_COMPUTED_STYLES,
build_snapshot_lookup,
)
from browser_use.dom.serializer.serializer import DOMTreeSerializer
from browser_use.dom.views import (
DOMRect,
EnhancedAXNode,
EnhancedAXProperty,
EnhancedDOMTreeNode,
NodeType,
SerializedDOMState,
TargetAllTrees,
)
from browser_use.observability import observe_debug
from browser_use.utils import create_task_with_error_handling
if TYPE_CHECKING:
from browser_use.browser.session import BrowserSession
# Note: iframe limits are now configurable via BrowserProfile.max_iframes and BrowserProfile.max_iframe_depth
class DomService:
"""
Service for getting the DOM tree and other DOM-related information.
Either browser or page must be provided.
TODO: currently we start a new websocket connection PER STEP, we should definitely keep this persistent
"""
logger: logging.Logger
def __init__(
self,
browser_session: 'BrowserSession',
logger: logging.Logger | None = None,
cross_origin_iframes: bool = False,
paint_order_filtering: bool = True,
max_iframes: int = 100,
max_iframe_depth: int = 5,
):
self.browser_session = browser_session
self.logger = logger or browser_session.logger
self.cross_origin_iframes = cross_origin_iframes
self.paint_order_filtering = paint_order_filtering
self.max_iframes = max_iframes
self.max_iframe_depth = max_iframe_depth
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass # no need to cleanup anything, browser_session auto handles cleaning up session cache
def _build_enhanced_ax_node(self, ax_node: AXNode) -> EnhancedAXNode:
properties: list[EnhancedAXProperty] | None = None
if 'properties' in ax_node and ax_node['properties']:
properties = []
for property in ax_node['properties']:
try:
# test whether property name can go into the enum (sometimes Chrome returns some random properties)
properties.append(
EnhancedAXProperty(
name=property['name'],
value=property.get('value', {}).get('value', None),
# related_nodes=[], # TODO: add related nodes
)
)
except ValueError:
pass
enhanced_ax_node = EnhancedAXNode(
ax_node_id=ax_node['nodeId'],
ignored=ax_node['ignored'],
role=ax_node.get('role', {}).get('value', None),
name=ax_node.get('name', {}).get('value', None),
description=ax_node.get('description', {}).get('value', None),
properties=properties,
child_ids=ax_node.get('childIds', []) if ax_node.get('childIds') else None,
)
return enhanced_ax_node
async def _get_viewport_ratio(self, target_id: TargetID) -> float:
"""Get viewport dimensions, device pixel ratio, and scroll position using CDP."""
cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)
try:
# Get the layout metrics which includes the visual viewport
metrics = await cdp_session.cdp_client.send.Page.getLayoutMetrics(session_id=cdp_session.session_id)
visual_viewport = metrics.get('visualViewport', {})
# IMPORTANT: Use CSS viewport instead of device pixel viewport
# This fixes the coordinate mismatch on high-DPI displays
css_visual_viewport = metrics.get('cssVisualViewport', {})
css_layout_viewport = metrics.get('cssLayoutViewport', {})
# Use CSS pixels (what JavaScript sees) instead of device pixels
width = css_visual_viewport.get('clientWidth', css_layout_viewport.get('clientWidth', 1920.0))
# Calculate device pixel ratio
device_width = visual_viewport.get('clientWidth', width)
css_width = css_visual_viewport.get('clientWidth', width)
device_pixel_ratio = device_width / css_width if css_width > 0 else 1.0
return float(device_pixel_ratio)
except Exception as e:
self.logger.debug(f'Viewport size detection failed: {e}')
# Fallback to default viewport size
return 1.0
@classmethod
def is_element_visible_according_to_all_parents(
cls, node: EnhancedDOMTreeNode, html_frames: list[EnhancedDOMTreeNode]
) -> bool:
"""Check if the element is visible according to all its parent HTML frames."""
if not node.snapshot_node:
return False
computed_styles = node.snapshot_node.computed_styles or {}
display = computed_styles.get('display', '').lower()
visibility = computed_styles.get('visibility', '').lower()
opacity = computed_styles.get('opacity', '1')
if display == 'none' or visibility == 'hidden':
return False
try:
if float(opacity) <= 0:
return False
except (ValueError, TypeError):
pass
# Start with the element's local bounds (in its own frame's coordinate system)
current_bounds = node.snapshot_node.bounds
if not current_bounds:
return False # If there are no bounds, the element is not visible
"""
Reverse iterate through the html frames (that can be either iframe or document -> if it's a document frame compare if the current bounds interest with it (taking scroll into account) otherwise move the current bounds by the iframe offset)
"""
for frame in reversed(html_frames):
if (
frame.node_type == NodeType.ELEMENT_NODE
and (frame.node_name.upper() == 'IFRAME' or frame.node_name.upper() == 'FRAME')
and frame.snapshot_node
and frame.snapshot_node.bounds
):
iframe_bounds = frame.snapshot_node.bounds
# negate the values added in `_construct_enhanced_node`
current_bounds.x += iframe_bounds.x
current_bounds.y += iframe_bounds.y
if (
frame.node_type == NodeType.ELEMENT_NODE
and frame.node_name == 'HTML'
and frame.snapshot_node
and frame.snapshot_node.scrollRects
and frame.snapshot_node.clientRects
):
# For iframe content, we need to check visibility within the iframe's viewport
# The scrollRects represent the current scroll position
# The clientRects represent the viewport size
# Elements are visible if they fall within the viewport after accounting for scroll
# The viewport of the frame (what's actually visible)
viewport_left = 0 # Viewport always starts at 0 in frame coordinates
viewport_top = 0
viewport_right = frame.snapshot_node.clientRects.width
viewport_bottom = frame.snapshot_node.clientRects.height
# Adjust element bounds by the scroll offset to get position relative to viewport
# When scrolled down, scrollRects.y is positive, so we subtract it from element's y
adjusted_x = current_bounds.x - frame.snapshot_node.scrollRects.x
adjusted_y = current_bounds.y - frame.snapshot_node.scrollRects.y
frame_intersects = (
adjusted_x < viewport_right
and adjusted_x + current_bounds.width > viewport_left
and adjusted_y < viewport_bottom + 1000
and adjusted_y + current_bounds.height > viewport_top - 1000
)
if not frame_intersects:
return False
# Keep the original coordinate adjustment to maintain consistency
# This adjustment is needed for proper coordinate transformation
current_bounds.x -= frame.snapshot_node.scrollRects.x
current_bounds.y -= frame.snapshot_node.scrollRects.y
# If we reach here, element is visible in main viewport and all containing iframes
return True
async def _get_ax_tree_for_all_frames(self, target_id: TargetID) -> GetFullAXTreeReturns:
"""Recursively collect all frames and merge their accessibility trees into a single array."""
cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)
frame_tree = await cdp_session.cdp_client.send.Page.getFrameTree(session_id=cdp_session.session_id)
def collect_all_frame_ids(frame_tree_node) -> list[str]:
"""Recursively collect all frame IDs from the frame tree."""
frame_ids = [frame_tree_node['frame']['id']]
if 'childFrames' in frame_tree_node and frame_tree_node['childFrames']:
for child_frame in frame_tree_node['childFrames']:
frame_ids.extend(collect_all_frame_ids(child_frame))
return frame_ids
# Collect all frame IDs recursively
all_frame_ids = collect_all_frame_ids(frame_tree['frameTree'])
# Get accessibility tree for each frame
ax_tree_requests = []
for frame_id in all_frame_ids:
ax_tree_request = cdp_session.cdp_client.send.Accessibility.getFullAXTree(
params={'frameId': frame_id}, session_id=cdp_session.session_id
)
ax_tree_requests.append(ax_tree_request)
# Wait for all requests to complete
ax_trees = await asyncio.gather(*ax_tree_requests)
# Merge all AX nodes into a single array
merged_nodes: list[AXNode] = []
for ax_tree in ax_trees:
merged_nodes.extend(ax_tree['nodes'])
return {'nodes': merged_nodes}
async def _get_all_trees(self, target_id: TargetID) -> TargetAllTrees:
cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)
# Wait for the page to be ready first
try:
ready_state = await cdp_session.cdp_client.send.Runtime.evaluate(
params={'expression': 'document.readyState'}, session_id=cdp_session.session_id
)
except Exception as e:
pass # Page might not be ready yet
# DEBUG: Log before capturing snapshot
self.logger.debug(f'🔍 DEBUG: Capturing DOM snapshot for target {target_id}')
# Get actual scroll positions for all iframes before capturing snapshot
start_iframe_scroll = time.time()
iframe_scroll_positions = {}
try:
scroll_result = await cdp_session.cdp_client.send.Runtime.evaluate(
params={
'expression': """
(() => {
const scrollData = {};
const iframes = document.querySelectorAll('iframe');
iframes.forEach((iframe, index) => {
try {
const doc = iframe.contentDocument || iframe.contentWindow.document;
if (doc) {
scrollData[index] = {
scrollTop: doc.documentElement.scrollTop || doc.body.scrollTop || 0,
scrollLeft: doc.documentElement.scrollLeft || doc.body.scrollLeft || 0
};
}
} catch (e) {
// Cross-origin iframe, can't access
}
});
return scrollData;
})()
""",
'returnByValue': True,
},
session_id=cdp_session.session_id,
)
if scroll_result and 'result' in scroll_result and 'value' in scroll_result['result']:
iframe_scroll_positions = scroll_result['result']['value']
for idx, scroll_data in iframe_scroll_positions.items():
self.logger.debug(
f'🔍 DEBUG: Iframe {idx} actual scroll position - scrollTop={scroll_data.get("scrollTop", 0)}, scrollLeft={scroll_data.get("scrollLeft", 0)}'
)
except Exception as e:
self.logger.debug(f'Failed to get iframe scroll positions: {e}')
iframe_scroll_ms = (time.time() - start_iframe_scroll) * 1000
# Define CDP request factories to avoid duplication
def create_snapshot_request():
return cdp_session.cdp_client.send.DOMSnapshot.captureSnapshot(
params={
'computedStyles': REQUIRED_COMPUTED_STYLES,
'includePaintOrder': True,
'includeDOMRects': True,
'includeBlendedBackgroundColors': False,
'includeTextColorOpacities': False,
},
session_id=cdp_session.session_id,
)
def create_dom_tree_request():
return cdp_session.cdp_client.send.DOM.getDocument(
params={'depth': -1, 'pierce': True}, session_id=cdp_session.session_id
)
start_cdp_calls = time.time()
# Create initial tasks
tasks = {
'snapshot': create_task_with_error_handling(create_snapshot_request(), name='get_snapshot'),
'dom_tree': create_task_with_error_handling(create_dom_tree_request(), name='get_dom_tree'),
'ax_tree': create_task_with_error_handling(self._get_ax_tree_for_all_frames(target_id), name='get_ax_tree'),
'device_pixel_ratio': create_task_with_error_handling(self._get_viewport_ratio(target_id), name='get_viewport_ratio'),
}
# Wait for all tasks with timeout
done, pending = await asyncio.wait(tasks.values(), timeout=10.0)
# Retry any failed or timed out tasks
if pending:
for task in pending:
task.cancel()
# Retry mapping for pending tasks
retry_map = {
tasks['snapshot']: lambda: create_task_with_error_handling(create_snapshot_request(), name='get_snapshot_retry'),
tasks['dom_tree']: lambda: create_task_with_error_handling(create_dom_tree_request(), name='get_dom_tree_retry'),
tasks['ax_tree']: lambda: create_task_with_error_handling(
self._get_ax_tree_for_all_frames(target_id), name='get_ax_tree_retry'
),
tasks['device_pixel_ratio']: lambda: create_task_with_error_handling(
self._get_viewport_ratio(target_id), name='get_viewport_ratio_retry'
),
}
# Create new tasks only for the ones that didn't complete
for key, task in tasks.items():
if task in pending and task in retry_map:
tasks[key] = retry_map[task]()
# Wait again with shorter timeout
done2, pending2 = await asyncio.wait([t for t in tasks.values() if not t.done()], timeout=2.0)
if pending2:
for task in pending2:
task.cancel()
# Extract results, tracking which ones failed
results = {}
failed = []
for key, task in tasks.items():
if task.done() and not task.cancelled():
try:
results[key] = task.result()
except Exception as e:
self.logger.warning(f'CDP request {key} failed with exception: {e}')
failed.append(key)
else:
self.logger.warning(f'CDP request {key} timed out')
failed.append(key)
# If any required tasks failed, raise an exception
if failed:
raise TimeoutError(f'CDP requests failed or timed out: {", ".join(failed)}')
snapshot = results['snapshot']
dom_tree = results['dom_tree']
ax_tree = results['ax_tree']
device_pixel_ratio = results['device_pixel_ratio']
end_cdp_calls = time.time()
cdp_calls_ms = (end_cdp_calls - start_cdp_calls) * 1000
# Calculate total time for _get_all_trees and overhead
start_snapshot_processing = time.time()
# DEBUG: Log snapshot info and limit documents to prevent explosion
if snapshot and 'documents' in snapshot:
original_doc_count = len(snapshot['documents'])
# Limit to max_iframes documents to prevent iframe explosion
if original_doc_count > self.max_iframes:
self.logger.warning(
f'⚠️ Limiting processing of {original_doc_count} iframes on page to only first {self.max_iframes} to prevent crashes!'
)
snapshot['documents'] = snapshot['documents'][: self.max_iframes]
total_nodes = sum(len(doc.get('nodes', [])) for doc in snapshot['documents'])
self.logger.debug(f'🔍 DEBUG: Snapshot contains {len(snapshot["documents"])} frames with {total_nodes} total nodes')
# Log iframe-specific info
for doc_idx, doc in enumerate(snapshot['documents']):
if doc_idx > 0: # Not the main document
self.logger.debug(
f'🔍 DEBUG: Iframe #{doc_idx} {doc.get("frameId", "no-frame-id")} {doc.get("url", "no-url")} has {len(doc.get("nodes", []))} nodes'
)
snapshot_processing_ms = (time.time() - start_snapshot_processing) * 1000
# Return with detailed timing breakdown
return TargetAllTrees(
snapshot=snapshot,
dom_tree=dom_tree,
ax_tree=ax_tree,
device_pixel_ratio=device_pixel_ratio,
cdp_timing={
'iframe_scroll_detection_ms': iframe_scroll_ms,
'cdp_parallel_calls_ms': cdp_calls_ms,
'snapshot_processing_ms': snapshot_processing_ms,
},
)
@observe_debug(ignore_input=True, ignore_output=True, name='get_dom_tree')
async def get_dom_tree(
self,
target_id: TargetID,
all_frames: dict | None = None,
initial_html_frames: list[EnhancedDOMTreeNode] | None = None,
initial_total_frame_offset: DOMRect | None = None,
iframe_depth: int = 0,
) -> tuple[EnhancedDOMTreeNode, dict[str, float]]:
"""Get the DOM tree for a specific target.
Args:
target_id: Target ID of the page to get the DOM tree for.
all_frames: Pre-fetched frame hierarchy to avoid redundant CDP calls (optional, lazy fetch if None)
initial_html_frames: List of HTML frame nodes encountered so far
initial_total_frame_offset: Accumulated coordinate offset
iframe_depth: Current depth of iframe nesting to prevent infinite recursion
Returns:
Tuple of (enhanced_dom_tree_node, timing_info)
"""
timing_info: dict[str, float] = {}
timing_start_total = time.time()
# Get all trees from CDP (snapshot, DOM, AX, viewport ratio)
start_get_trees = time.time()
trees = await self._get_all_trees(target_id)
get_trees_ms = (time.time() - start_get_trees) * 1000
timing_info.update(trees.cdp_timing)
timing_info['get_all_trees_total_ms'] = get_trees_ms
dom_tree = trees.dom_tree
ax_tree = trees.ax_tree
snapshot = trees.snapshot
device_pixel_ratio = trees.device_pixel_ratio
# Build AX tree lookup
start_ax = time.time()
ax_tree_lookup: dict[int, AXNode] = {
ax_node['backendDOMNodeId']: ax_node for ax_node in ax_tree['nodes'] if 'backendDOMNodeId' in ax_node
}
timing_info['build_ax_lookup_ms'] = (time.time() - start_ax) * 1000
enhanced_dom_tree_node_lookup: dict[int, EnhancedDOMTreeNode] = {}
""" NodeId (NOT backend node id) -> enhanced dom tree node""" # way to get the parent/content node
# Parse snapshot data with everything calculated upfront
start_snapshot = time.time()
snapshot_lookup = build_snapshot_lookup(snapshot, device_pixel_ratio)
timing_info['build_snapshot_lookup_ms'] = (time.time() - start_snapshot) * 1000
async def _construct_enhanced_node(
node: Node,
html_frames: list[EnhancedDOMTreeNode] | None,
total_frame_offset: DOMRect | None,
all_frames: dict | None,
) -> EnhancedDOMTreeNode:
"""
Recursively construct enhanced DOM tree nodes.
Args:
node: The DOM node to construct
html_frames: List of HTML frame nodes encountered so far
total_frame_offset: Accumulated coordinate translation from parent iframes (includes scroll corrections)
all_frames: Pre-fetched frame hierarchy to avoid redundant CDP calls
"""
# Initialize lists if not provided
if html_frames is None:
html_frames = []
# to get rid of the pointer references
if total_frame_offset is None:
total_frame_offset = DOMRect(x=0.0, y=0.0, width=0.0, height=0.0)
else:
total_frame_offset = DOMRect(
total_frame_offset.x, total_frame_offset.y, total_frame_offset.width, total_frame_offset.height
)
# memoize the mf (I don't know if some nodes are duplicated)
if node['nodeId'] in enhanced_dom_tree_node_lookup:
return enhanced_dom_tree_node_lookup[node['nodeId']]
ax_node = ax_tree_lookup.get(node['backendNodeId'])
if ax_node:
enhanced_ax_node = self._build_enhanced_ax_node(ax_node)
else:
enhanced_ax_node = None
# To make attributes more readable
attributes: dict[str, str] | None = None
if 'attributes' in node and node['attributes']:
attributes = {}
for i in range(0, len(node['attributes']), 2):
attributes[node['attributes'][i]] = node['attributes'][i + 1]
shadow_root_type = None
if 'shadowRootType' in node and node['shadowRootType']:
try:
shadow_root_type = node['shadowRootType']
except ValueError:
pass
# Get snapshot data and calculate absolute position
snapshot_data = snapshot_lookup.get(node['backendNodeId'], None)
# DIAGNOSTIC: Log when interactive elements don't have snapshot data
if not snapshot_data and node['nodeName'].upper() in ['INPUT', 'BUTTON', 'SELECT', 'TEXTAREA', 'A']:
parent_has_shadow = False
parent_info = ''
if 'parentId' in node and node['parentId'] in enhanced_dom_tree_node_lookup:
parent = enhanced_dom_tree_node_lookup[node['parentId']]
if parent.shadow_root_type:
parent_has_shadow = True
parent_info = f'parent={parent.tag_name}(shadow={parent.shadow_root_type})'
attr_str = ''
if 'attributes' in node and node['attributes']:
attrs_dict = {node['attributes'][i]: node['attributes'][i + 1] for i in range(0, len(node['attributes']), 2)}
attr_str = f'name={attrs_dict.get("name", "N/A")} id={attrs_dict.get("id", "N/A")}'
self.logger.debug(
f'🔍 NO SNAPSHOT DATA for <{node["nodeName"]}> backendNodeId={node["backendNodeId"]} '
f'{attr_str} {parent_info} (snapshot_lookup has {len(snapshot_lookup)} entries)'
)
absolute_position = None
if snapshot_data and snapshot_data.bounds:
absolute_position = DOMRect(
x=snapshot_data.bounds.x + total_frame_offset.x,
y=snapshot_data.bounds.y + total_frame_offset.y,
width=snapshot_data.bounds.width,
height=snapshot_data.bounds.height,
)
try:
session = await self.browser_session.get_or_create_cdp_session(target_id, focus=False)
session_id = session.session_id
except ValueError:
# Target may have detached during DOM construction
session_id = None
dom_tree_node = EnhancedDOMTreeNode(
node_id=node['nodeId'],
backend_node_id=node['backendNodeId'],
node_type=NodeType(node['nodeType']),
node_name=node['nodeName'],
node_value=node['nodeValue'],
attributes=attributes or {},
is_scrollable=node.get('isScrollable', None),
frame_id=node.get('frameId', None),
session_id=session_id,
target_id=target_id,
content_document=None,
shadow_root_type=shadow_root_type,
shadow_roots=None,
parent_node=None,
children_nodes=None,
ax_node=enhanced_ax_node,
snapshot_node=snapshot_data,
is_visible=None,
absolute_position=absolute_position,
)
enhanced_dom_tree_node_lookup[node['nodeId']] = dom_tree_node
if 'parentId' in node and node['parentId']:
dom_tree_node.parent_node = enhanced_dom_tree_node_lookup[
node['parentId']
] # parents should always be in the lookup
# Check if this is an HTML frame node and add it to the list
updated_html_frames = html_frames.copy()
if node['nodeType'] == NodeType.ELEMENT_NODE.value and node['nodeName'] == 'HTML' and node.get('frameId') is not None:
updated_html_frames.append(dom_tree_node)
# and adjust the total frame offset by scroll
if snapshot_data and snapshot_data.scrollRects:
total_frame_offset.x -= snapshot_data.scrollRects.x
total_frame_offset.y -= snapshot_data.scrollRects.y
# DEBUG: Log iframe scroll information
self.logger.debug(
f'🔍 DEBUG: HTML frame scroll - scrollY={snapshot_data.scrollRects.y}, scrollX={snapshot_data.scrollRects.x}, frameId={node.get("frameId")}, nodeId={node["nodeId"]}'
)
# Calculate new iframe offset for content documents, accounting for iframe scroll
if (
(node['nodeName'].upper() == 'IFRAME' or node['nodeName'].upper() == 'FRAME')
and snapshot_data
and snapshot_data.bounds
):
if snapshot_data.bounds:
updated_html_frames.append(dom_tree_node)
total_frame_offset.x += snapshot_data.bounds.x
total_frame_offset.y += snapshot_data.bounds.y
if 'contentDocument' in node and node['contentDocument']:
dom_tree_node.content_document = await _construct_enhanced_node(
node['contentDocument'], updated_html_frames, total_frame_offset, all_frames
)
dom_tree_node.content_document.parent_node = dom_tree_node
# forcefully set the parent node to the content document node (helps traverse the tree)
if 'shadowRoots' in node and node['shadowRoots']:
dom_tree_node.shadow_roots = []
for shadow_root in node['shadowRoots']:
shadow_root_node = await _construct_enhanced_node(
shadow_root, updated_html_frames, total_frame_offset, all_frames
)
# forcefully set the parent node to the shadow root node (helps traverse the tree)
shadow_root_node.parent_node = dom_tree_node
dom_tree_node.shadow_roots.append(shadow_root_node)
if 'children' in node and node['children']:
dom_tree_node.children_nodes = []
# Build set of shadow root node IDs to filter them out from children
shadow_root_node_ids = set()
if 'shadowRoots' in node and node['shadowRoots']:
for shadow_root in node['shadowRoots']:
shadow_root_node_ids.add(shadow_root['nodeId'])
for child in node['children']:
# Skip shadow roots - they should only be in shadow_roots list
if child['nodeId'] in shadow_root_node_ids:
continue
dom_tree_node.children_nodes.append(
await _construct_enhanced_node(child, updated_html_frames, total_frame_offset, all_frames)
)
# Set visibility using the collected HTML frames
dom_tree_node.is_visible = self.is_element_visible_according_to_all_parents(dom_tree_node, updated_html_frames)
# DEBUG: Log visibility info for form elements in iframes
if dom_tree_node.tag_name and dom_tree_node.tag_name.upper() in ['INPUT', 'SELECT', 'TEXTAREA', 'LABEL']:
attrs = dom_tree_node.attributes or {}
elem_id = attrs.get('id', '')
elem_name = attrs.get('name', '')
if (
'city' in elem_id.lower()
or 'city' in elem_name.lower()
or 'state' in elem_id.lower()
or 'state' in elem_name.lower()
or 'zip' in elem_id.lower()
or 'zip' in elem_name.lower()
):
self.logger.debug(
f"🔍 DEBUG: Form element {dom_tree_node.tag_name} id='{elem_id}' name='{elem_name}' - visible={dom_tree_node.is_visible}, bounds={dom_tree_node.snapshot_node.bounds if dom_tree_node.snapshot_node else 'NO_SNAPSHOT'}"
)
# handle cross origin iframe (just recursively call the main function with the proper target if it exists in iframes)
# only do this if the iframe is visible (otherwise it's not worth it)
if (
# TODO: hacky way to disable cross origin iframes for now
self.cross_origin_iframes and node['nodeName'].upper() == 'IFRAME' and node.get('contentDocument', None) is None
): # None meaning there is no content
# Check iframe depth to prevent infinite recursion
if iframe_depth >= self.max_iframe_depth:
self.logger.debug(
f'Skipping iframe at depth {iframe_depth} to prevent infinite recursion (max depth: {self.max_iframe_depth})'
)
else:
# Check if iframe is visible and large enough (>= 50px in both dimensions)
should_process_iframe = False
# First check if the iframe element itself is visible
if dom_tree_node.is_visible:
# Check iframe dimensions
if dom_tree_node.snapshot_node and dom_tree_node.snapshot_node.bounds:
bounds = dom_tree_node.snapshot_node.bounds
width = bounds.width
height = bounds.height
# Only process if iframe is at least 50px in both dimensions
if width >= 50 and height >= 50:
should_process_iframe = True
self.logger.debug(f'Processing cross-origin iframe: visible=True, width={width}, height={height}')
else:
self.logger.debug(
f'Skipping small cross-origin iframe: width={width}, height={height} (needs >= 50px)'
)
else:
self.logger.debug('Skipping cross-origin iframe: no bounds available')
else:
self.logger.debug('Skipping invisible cross-origin iframe')
if should_process_iframe:
# Lazy fetch all_frames only when actually needed (for cross-origin iframes)
if all_frames is None:
all_frames, _ = await self.browser_session.get_all_frames()
# Use pre-fetched all_frames to find the iframe's target (no redundant CDP call)
frame_id = node.get('frameId', None)
if frame_id:
frame_info = all_frames.get(frame_id)
iframe_document_target = None
if frame_info and frame_info.get('frameTargetId'):
iframe_target_id = frame_info['frameTargetId']
iframe_target = self.browser_session.session_manager.get_target(iframe_target_id)
if iframe_target:
iframe_document_target = {
'targetId': iframe_target.target_id,
'url': iframe_target.url,
'title': iframe_target.title,
'type': iframe_target.target_type,
}
else:
iframe_document_target = None
# if target actually exists in one of the frames, just recursively build the dom tree for it
if iframe_document_target:
self.logger.debug(
f'Getting content document for iframe {node.get("frameId", None)} at depth {iframe_depth + 1}'
)
content_document, _ = await self.get_dom_tree(
target_id=iframe_document_target['targetId'],
all_frames=all_frames,
# TODO: experiment with this values -> not sure whether the whole cross origin iframe should be ALWAYS included as soon as some part of it is visible or not.
# Current config: if the cross origin iframe is AT ALL visible, then just include everything inside of it!
# initial_html_frames=updated_html_frames,
initial_total_frame_offset=total_frame_offset,
iframe_depth=iframe_depth + 1,
)
dom_tree_node.content_document = content_document
dom_tree_node.content_document.parent_node = dom_tree_node
return dom_tree_node
# Build enhanced DOM tree recursively
# Note: all_frames stays None and will be lazily fetched inside _construct_enhanced_node
# only if/when a cross-origin iframe is encountered
start_construct = time.time()
enhanced_dom_tree_node = await _construct_enhanced_node(
dom_tree['root'], initial_html_frames, initial_total_frame_offset, all_frames
)
timing_info['construct_enhanced_tree_ms'] = (time.time() - start_construct) * 1000
# Calculate total time for get_dom_tree
total_get_dom_tree_ms = (time.time() - timing_start_total) * 1000
timing_info['get_dom_tree_total_ms'] = total_get_dom_tree_ms
# Calculate overhead in get_dom_tree (time not accounted for by sub-operations)
tracked_sub_operations_ms = (
timing_info.get('get_all_trees_total_ms', 0)
+ timing_info.get('build_ax_lookup_ms', 0)
+ timing_info.get('build_snapshot_lookup_ms', 0)
+ timing_info.get('construct_enhanced_tree_ms', 0)
)
get_dom_tree_overhead_ms = total_get_dom_tree_ms - tracked_sub_operations_ms
if get_dom_tree_overhead_ms > 0.1:
timing_info['get_dom_tree_overhead_ms'] = get_dom_tree_overhead_ms
return enhanced_dom_tree_node, timing_info
@observe_debug(ignore_input=True, ignore_output=True, name='get_serialized_dom_tree')
async def get_serialized_dom_tree(
self, previous_cached_state: SerializedDOMState | None = None
) -> tuple[SerializedDOMState, EnhancedDOMTreeNode, dict[str, float]]:
"""Get the serialized DOM tree representation for LLM consumption.
Returns:
Tuple of (serialized_dom_state, enhanced_dom_tree_root, timing_info)
"""
timing_info: dict[str, float] = {}
start_total = time.time()
# Use current target (None means use current)
assert self.browser_session.agent_focus_target_id is not None
session_id = self.browser_session.id
# Build DOM tree (includes CDP calls for snapshot, DOM, AX tree)
# Note: all_frames is fetched lazily inside get_dom_tree only if cross-origin iframes need it
enhanced_dom_tree, dom_tree_timing = await self.get_dom_tree(
target_id=self.browser_session.agent_focus_target_id,
all_frames=None, # Lazy - will fetch if needed
)
# Add sub-timings from DOM tree construction
timing_info.update(dom_tree_timing)
# Serialize DOM tree for LLM
start_serialize = time.time()
serialized_dom_state, serializer_timing = DOMTreeSerializer(
enhanced_dom_tree, previous_cached_state, paint_order_filtering=self.paint_order_filtering, session_id=session_id
).serialize_accessible_elements()
total_serialization_ms = (time.time() - start_serialize) * 1000
# Add serializer sub-timings (convert to ms)
for key, value in serializer_timing.items():
timing_info[f'{key}_ms'] = value * 1000
# Calculate untracked time in serialization
tracked_serialization_ms = sum(value * 1000 for value in serializer_timing.values())
serialization_overhead_ms = total_serialization_ms - tracked_serialization_ms
if serialization_overhead_ms > 0.1: # Only log if significant
timing_info['serialization_overhead_ms'] = serialization_overhead_ms
# Calculate total time for get_serialized_dom_tree
total_get_serialized_dom_tree_ms = (time.time() - start_total) * 1000
timing_info['get_serialized_dom_tree_total_ms'] = total_get_serialized_dom_tree_ms
# Calculate overhead in get_serialized_dom_tree (time not accounted for)
tracked_major_operations_ms = timing_info.get('get_dom_tree_total_ms', 0) + total_serialization_ms
get_serialized_overhead_ms = total_get_serialized_dom_tree_ms - tracked_major_operations_ms
if get_serialized_overhead_ms > 0.1:
timing_info['get_serialized_dom_tree_overhead_ms'] = get_serialized_overhead_ms
return serialized_dom_state, enhanced_dom_tree, timing_info
@staticmethod
def detect_pagination_buttons(selector_map: dict[int, EnhancedDOMTreeNode]) -> list[dict[str, str | int | bool]]:
"""Detect pagination buttons from the selector map.
Args:
selector_map: Map of element indices to EnhancedDOMTreeNode
Returns:
List of pagination button information dicts with:
- button_type: 'next', 'prev', 'first', 'last', 'page_number'
- backend_node_id: Backend node ID for clicking
- text: Button text/label
- selector: XPath selector
- is_disabled: Whether the button appears disabled
"""
pagination_buttons: list[dict[str, str | int | bool]] = []
# Common pagination patterns to look for
next_patterns = ['next', '>', '»', '→', 'siguiente', 'suivant', 'weiter', 'volgende']
prev_patterns = ['prev', 'previous', '<', '«', '←', 'anterior', 'précédent', 'zurück', 'vorige']
first_patterns = ['first', '⇤', '«', 'primera', 'première', 'erste', 'eerste']
last_patterns = ['last', '⇥', '»', 'última', 'dernier', 'letzte', 'laatste']
for index, node in selector_map.items():
# Skip non-clickable elements
if not node.snapshot_node or not node.snapshot_node.is_clickable:
continue
# Get element text and attributes
text = node.get_all_children_text().lower().strip()
aria_label = node.attributes.get('aria-label', '').lower()
title = node.attributes.get('title', '').lower()
class_name = node.attributes.get('class', '').lower()
role = node.attributes.get('role', '').lower()
# Combine all text sources for pattern matching
all_text = f'{text} {aria_label} {title} {class_name}'.strip()
# Check if it's disabled
is_disabled = (
node.attributes.get('disabled') == 'true'
or node.attributes.get('aria-disabled') == 'true'
or 'disabled' in class_name
)
button_type: str | None = None
# Check for next button
if any(pattern in all_text for pattern in next_patterns):
button_type = 'next'
# Check for previous button
elif any(pattern in all_text for pattern in prev_patterns):
button_type = 'prev'
# Check for first button
elif any(pattern in all_text for pattern in first_patterns):
button_type = 'first'
# Check for last button
elif any(pattern in all_text for pattern in last_patterns):
button_type = 'last'
# Check for numeric page buttons (single or double digit)
elif text.isdigit() and len(text) <= 2 and role in ['button', 'link', '']:
button_type = 'page_number'
if button_type:
pagination_buttons.append(
{
'button_type': button_type,
'backend_node_id': index,
'text': node.get_all_children_text().strip() or aria_label or title,
'selector': node.xpath,
'is_disabled': is_disabled,
}
)
return pagination_buttons