import contextlib
from types import SimpleNamespace
import pytest
from config import Settings
from db import DatabaseError, PostgresConnector
class FakeCursor:
    """In-memory stand-in for a DB cursor that records every statement it is given.

    ``fetchall`` hands back the canned ``rows`` supplied at construction, but only
    after at least one statement has gone through ``execute`` -- a real cursor has
    no result set before a query runs.
    """

    def __init__(self, rows):
        self.rows = rows
        # Every execute() call, in order, as a (sql, params) tuple.
        self.executed = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any exception from the with-body propagate.
        return False

    def execute(self, sql, params=None):
        """Record ``sql`` and ``params`` instead of sending them anywhere."""
        self.executed += [(sql, params)]

    def fetchall(self):
        """Return the canned rows; raise RuntimeError if nothing was executed yet."""
        if self.executed:
            return self.rows
        raise RuntimeError("execute must be called before fetchall")
class FakeConnection:
    """In-memory stand-in for a DB connection that always hands out the same cursor."""

    def __init__(self, cursor):
        self.cursor_obj = cursor
        # Becomes True once the connection's with-block is exited.
        self.closed = False

    def cursor(self, row_factory=None):
        """Return the shared cursor.

        ``row_factory`` exists only so callers written against the real driver's
        signature keep working; the fake ignores it.
        """
        return self.cursor_obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.closed = True
        # Do not suppress exceptions raised inside the with-body.
        return False
@contextlib.contextmanager
def fake_connection(rows=None, cursor=None):
    """Yield a ``FakeConnection`` over ``cursor``, or over a new ``FakeCursor`` built from ``rows``.

    The connection is entered and exited like a real connection context manager,
    so ``conn.closed`` is True once the caller's ``with`` block ends -- including
    when that block raises.

    Args:
        rows: rows the new cursor's ``fetchall`` returns (defaults to no rows);
            ignored when ``cursor`` is supplied.
        cursor: a pre-built cursor to wrap instead of creating one.

    Yields:
        The ``FakeConnection`` wrapping the chosen cursor.
    """
    if cursor is None:
        cursor = FakeCursor([] if rows is None else rows)
    # Entering the connection (rather than only constructing it) is what makes
    # its __exit__ run, so the ``closed`` flag actually reflects cleanup.
    with FakeConnection(cursor) as conn:
        yield conn
def make_connector(monkeypatch, rows=None):
    """Build a ``PostgresConnector`` whose ``_connection`` is replaced with an in-memory fake.

    Each call to ``connector._connection(...)`` returns a context manager that
    creates a fresh ``FakeCursor`` seeded with a copy of ``rows`` and a
    ``FakeConnection`` around it. The connection is entered and exited properly,
    so its ``closed`` flag is set afterwards. The most recent cursor and
    connection are stored on the connector as ``_last_cursor`` and
    ``_last_connection`` so tests can inspect the executed SQL and the cleanup.

    Args:
        monkeypatch: the pytest ``monkeypatch`` fixture used to install the fake.
        rows: rows the fake cursor's ``fetchall`` returns; defaults to no rows.

    Returns:
        The patched ``PostgresConnector``.
    """
    settings = Settings(
        db_address="localhost",
        db_port=5432,
        db_name="test",
        db_user="user",
        db_password="pass",
    )
    connector = PostgresConnector(settings)

    @contextlib.contextmanager
    def _conn(*args, **kwargs):
        # Accept and ignore whatever arguments the real ``_connection`` takes,
        # so the fake does not break if that signature grows.
        # Copy ``rows`` so one query mutating its result cannot leak into the next.
        cur = FakeCursor([] if rows is None else list(rows))
        with FakeConnection(cur) as conn:
            connector._last_cursor = cur  # type: ignore[attr-defined]
            connector._last_connection = conn  # type: ignore[attr-defined]
            yield conn

    monkeypatch.setattr(connector, "_connection", _conn)
    return connector
def test_run_read_query_adds_limit_and_sets_read_only(monkeypatch):
    """A SELECT without LIMIT gets the caller's limit appended, after a read-only SET statement."""
    connector = make_connector(monkeypatch, rows=[{"x": 1}])
    result = connector.run_read_query("select * from demo", limit=5)

    # Rows and rowcount come from the fake cursor's canned data.
    assert result.rows == [{"x": 1}]
    assert result.rowcount == 1

    # Exactly two statements reach the cursor: the read-only SET, then the query.
    cursor = connector._last_cursor  # type: ignore[attr-defined]
    assert len(cursor.executed) == 2
    (set_sql, set_params), (query_sql, query_params) = cursor.executed

    # NOTE(review): in PostgreSQL, default_transaction_read_only only affects
    # transactions started *after* it is set, and SET LOCAL lasts only until the
    # current transaction ends. So this SET may not make the query itself
    # read-only. Confirm against the connector; SET TRANSACTION READ ONLY (or
    # SET LOCAL transaction_read_only) may be what is intended.
    assert set_sql.lower().startswith("set local default_transaction_read_only")
    assert set_params is None
    assert query_params is None
    # Case-insensitive, so "LIMIT 5", "limit 5" and "Limit 5" all satisfy it.
    assert query_sql.strip().lower().endswith("limit 5")
def test_run_read_query_respects_existing_limit(monkeypatch):
    """A LIMIT already present in the query takes precedence over the ``limit`` argument."""
    connector = make_connector(monkeypatch, rows=[])
    connector.run_read_query("select * from demo limit 10", limit=5)

    # Statement 0 is the read-only SET; statement 1 is the user's query.
    statements = connector._last_cursor.executed  # type: ignore[attr-defined]
    sent_sql = statements[1][0].lower()

    assert "limit 10" in sent_sql
    assert "limit 5" not in sent_sql
def test_run_read_query_passes_params(monkeypatch):
    """Bound parameters reach the driver unchanged, and the named placeholder survives any SQL rewrite."""
    connector = make_connector(monkeypatch, rows=[])
    connector.run_read_query("select * from demo where id=%(id)s", params={"id": 123}, limit=10)

    # Statement 0 is the read-only SET; statement 1 is the user's query.
    cursor = connector._last_cursor  # type: ignore[attr-defined]
    executed_sql, executed_params = cursor.executed[1]

    assert executed_params == {"id": 123}
    # Check the whole placeholder, not just its leading "%". A rewrite that
    # escaped it to "%%(id)s" or mangled the name would otherwise still pass,
    # even though the driver could no longer bind ``params``.
    assert "where id=%(id)s" in executed_sql.lower()
def test_run_read_query_blocks_write_like_statements(monkeypatch):
    """A data-modifying statement passed to run_read_query is rejected with DatabaseError."""
    write_sql = "update demo set x=1"
    connector = make_connector(monkeypatch, rows=[])

    with pytest.raises(DatabaseError):
        connector.run_read_query(write_sql, limit=10)