"""
Unitary Check - Verifies if circuit operations are properly unitary.
PhD-level validation for checking quantum coherence.
"""
from typing import Any
import numpy as np
from ..creation.gate_library import GateLibrary
def check_unitarity(circuit_data: dict[str, Any]) -> dict[str, Any]:
"""
Verify that the circuit represents a valid unitary operation.
Checks:
1. All gates are unitary (U†U = I)
2. No measurements break coherence in the middle
3. Circuit matrix is unitary overall
Args:
circuit_data: Circuit dictionary
Returns:
Validation result with unitarity analysis
"""
errors: list[str] = []
warnings: list[str] = []
num_qubits = circuit_data.get("num_qubits", 0)
gates = circuit_data.get("gates", [])
if num_qubits > 10:
warnings.append(
f"Circuit has {num_qubits} qubits - full unitary calculation would require "
f"{2**num_qubits}x{2**num_qubits} matrix (skipping full simulation)"
)
# Just check individual gates
return _check_gate_unitarity(gates)
# Track measurement positions
measurement_positions: list[int] = []
gate_after_measure: dict[int, list[int]] = {} # qubit -> [gate indices]
measured_qubits: set[int] = set()
for idx, gate in enumerate(gates):
name = gate.get("name", "").lower()
qubits = gate.get("qubits", [])
if name == "measure":
measurement_positions.append(idx)
for q in qubits:
measured_qubits.add(q)
elif name not in ("barrier", "reset"):
# Check if operating on measured qubit
for q in qubits:
if q in measured_qubits:
if q not in gate_after_measure:
gate_after_measure[q] = []
gate_after_measure[q].append(idx)
if gate_after_measure:
for q, indices in gate_after_measure.items():
warnings.append(
f"Qubit {q} has {len(indices)} gate(s) after measurement - "
f"circuit is not purely unitary"
)
# Check if all gates have valid unitary representations
gate_check_result = _check_gate_unitarity(gates)
errors.extend(gate_check_result.get("errors", []))
# Try to compute full unitary (for small circuits)
unitary_result = None
if num_qubits <= 6 and not measurement_positions:
try:
unitary_result = _compute_circuit_unitary(gates, num_qubits)
if not _is_unitary(unitary_result["matrix"]):
errors.append("Computed circuit matrix is not unitary")
except Exception as e:
warnings.append(f"Could not compute full unitary: {str(e)}")
is_unitary = len(errors) == 0 and not gate_after_measure
result: dict[str, Any] = {
"is_unitary": is_unitary,
"errors": errors,
"warnings": warnings,
"measurements_found": len(measurement_positions),
"gates_after_measurement": bool(gate_after_measure),
}
if unitary_result:
result["unitary_dimensions"] = unitary_result["dimensions"]
result["unitary_norm"] = unitary_result.get("norm", 1.0)
return result
def _check_gate_unitarity(gates: list[dict[str, Any]]) -> dict[str, Any]:
"""Check that each gate is a valid unitary."""
errors: list[str] = []
for idx, gate in enumerate(gates):
name = gate.get("name", "").lower()
params = gate.get("params", [])
if name in ("measure", "reset", "barrier"):
continue
try:
# Handle parameterized gates with symbolic params
if any(isinstance(p, str) and p.startswith("param:") for p in params):
# Can't check unitarity of symbolic parameters
continue
matrix = GateLibrary.get_gate(name, params if params else None)
if not _is_unitary(matrix):
errors.append(f"Gate {idx} ({name}): matrix is not unitary")
except ValueError:
# Unknown gate - already caught by syntax checker
pass
except Exception as e:
errors.append(f"Gate {idx} ({name}): could not verify unitarity - {str(e)}")
return {"errors": errors}
def _is_unitary(matrix: np.ndarray, tolerance: float = 1e-10) -> bool:
"""Check if a matrix is unitary (U†U = I)."""
n = matrix.shape[0]
product = matrix.conj().T @ matrix
identity = np.eye(n, dtype=complex)
return np.allclose(product, identity, atol=tolerance)
def compute_circuit_unitary(
circuit_data: dict[str, Any]
) -> dict[str, Any]:
"""
Compute the full unitary matrix for a circuit.
Public API wrapper for internal _compute_circuit_unitary function.
Args:
circuit_data: Circuit dictionary with gates and num_qubits
Returns:
Dictionary with unitary matrix and dimensions
"""
gates = circuit_data.get("gates", [])
num_qubits = circuit_data.get("num_qubits", 0)
return _compute_circuit_unitary(gates, num_qubits)
def _compute_circuit_unitary(
gates: list[dict[str, Any]],
num_qubits: int
) -> dict[str, Any]:
"""Compute the full unitary matrix for the circuit."""
dim = 2 ** num_qubits
unitary = np.eye(dim, dtype=complex)
for gate in gates:
name = gate.get("name", "").lower()
qubits = gate.get("qubits", [])
params = gate.get("params", [])
if name in ("measure", "reset", "barrier"):
continue
# Skip symbolic parameters
if any(isinstance(p, str) and p.startswith("param:") for p in params):
continue
try:
gate_matrix = GateLibrary.get_gate(name, params if params else None)
full_gate = _expand_gate(gate_matrix, qubits, num_qubits)
unitary = full_gate @ unitary
except Exception:
pass
return {
"matrix": unitary,
"dimensions": (dim, dim),
"norm": np.linalg.norm(unitary),
}
def _expand_gate(
gate_matrix: np.ndarray,
target_qubits: list[int],
num_qubits: int
) -> np.ndarray:
"""Expand a gate matrix to the full Hilbert space."""
dim = 2 ** num_qubits
if len(target_qubits) == 1:
q = target_qubits[0]
# Build tensor product: I ⊗ ... ⊗ G ⊗ ... ⊗ I
result = np.array([[1]], dtype=complex)
for i in range(num_qubits):
if i == q:
result = np.kron(result, gate_matrix)
else:
result = np.kron(result, np.eye(2, dtype=complex))
return result
elif len(target_qubits) == 2:
q0, q1 = target_qubits
# For 2-qubit gates, need to handle qubit ordering
# This is simplified - full implementation would handle arbitrary orderings
if q1 == q0 + 1:
# Adjacent qubits - simple case
result = np.array([[1]], dtype=complex)
i = 0
while i < num_qubits:
if i == q0:
result = np.kron(result, gate_matrix)
i += 2
else:
result = np.kron(result, np.eye(2, dtype=complex))
i += 1
return result
else:
# Non-adjacent - need SWAP network (simplified)
# For now, return identity and log warning
return np.eye(dim, dtype=complex)
else:
# Multi-qubit gate - complex expansion
return np.eye(dim, dtype=complex)
def analyze_entanglement_structure(circuit_data: dict[str, Any]) -> dict[str, Any]:
"""
Analyze the entanglement structure of the circuit.
Returns information about which qubits are entangled and where.
"""
gates = circuit_data.get("gates", [])
num_qubits = circuit_data.get("num_qubits", 0)
# Track entangling gates
entangling_gates = []
qubit_pairs: set[tuple[int, int]] = set()
entangling_gate_names = {"cx", "cnot", "cy", "cz", "swap", "iswap", "cp", "crx", "cry", "crz"}
three_qubit_entangling = {"ccx", "toffoli", "cswap", "fredkin"}
for idx, gate in enumerate(gates):
name = gate.get("name", "").lower()
qubits = gate.get("qubits", [])
if name in entangling_gate_names and len(qubits) == 2:
q0, q1 = qubits
entangling_gates.append({
"gate_idx": idx,
"gate": name,
"qubits": [q0, q1]
})
qubit_pairs.add((min(q0, q1), max(q0, q1)))
elif name in three_qubit_entangling and len(qubits) == 3:
entangling_gates.append({
"gate_idx": idx,
"gate": name,
"qubits": qubits
})
for i in range(len(qubits)):
for j in range(i + 1, len(qubits)):
qubit_pairs.add((min(qubits[i], qubits[j]), max(qubits[i], qubits[j])))
# Build entanglement graph
entanglement_graph: dict[int, list[int]] = {i: [] for i in range(num_qubits)}
for q0, q1 in qubit_pairs:
entanglement_graph[q0].append(q1)
entanglement_graph[q1].append(q0)
# Find entanglement clusters (connected components)
visited: set[int] = set()
clusters: list[list[int]] = []
def dfs(node: int, cluster: list[int]) -> None:
if node in visited:
return
visited.add(node)
cluster.append(node)
for neighbor in entanglement_graph[node]:
dfs(neighbor, cluster)
for q in range(num_qubits):
if q not in visited and entanglement_graph[q]:
cluster: list[int] = []
dfs(q, cluster)
if len(cluster) > 1:
clusters.append(sorted(cluster))
# Add isolated qubits
isolated = [q for q in range(num_qubits) if q not in visited]
return {
"entangling_gate_count": len(entangling_gates),
"entangling_gates": entangling_gates,
"entangled_pairs": [list(p) for p in sorted(qubit_pairs)],
"entanglement_clusters": clusters,
"isolated_qubits": isolated,
"max_entanglement_depth": _calculate_entanglement_depth(entangling_gates),
}
def _calculate_entanglement_depth(entangling_gates: list[dict[str, Any]]) -> int:
"""Calculate the depth of entangling gates only."""
if not entangling_gates:
return 0
# Track depth per qubit
qubit_depth: dict[int, int] = {}
for gate in entangling_gates:
qubits = gate.get("qubits", [])
max_depth = max((qubit_depth.get(q, 0) for q in qubits), default=0)
for q in qubits:
qubit_depth[q] = max_depth + 1
return max(qubit_depth.values()) if qubit_depth else 0