# test_google_vertex.py (7.2 kB)
# pyright: reportDeprecated=false
from __future__ import annotations as _annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import MagicMock, patch

import httpx
import pytest
from inline_snapshot import snapshot
from pytest_mock import MockerFixture

from pydantic_ai.agent import Agent
from pydantic_ai.models.gemini import GeminiModel

from ..conftest import try_import
with try_import() as imports_successful:
    from google.auth.transport.requests import Request
    from pydantic_ai.providers.google_vertex import GoogleVertexProvider
# Module-wide marks: skip when the optional Vertex dependencies cannot be imported, run async tests
# under anyio, and silence the `GoogleModel` / `GoogleVertexProvider` deprecation warnings.
pytestmark = [
    # The guarded imports above need `google-auth` (`google.auth.transport.requests`), not `google-genai`.
    pytest.mark.skipif(not imports_successful(), reason='google-auth not installed'),
    pytest.mark.anyio(),
    pytest.mark.filterwarnings('ignore:Use `GoogleModel` instead.:DeprecationWarning'),
    pytest.mark.filterwarnings('ignore:`GoogleVertexProvider` is deprecated.:DeprecationWarning'),
]
@pytest.fixture()
def http_client():
    """Async client backed by ``httpx.MockTransport`` that only serves one Vertex endpoint.

    Requests to the ``gemini-1.0-pro:generateContent`` path for project ``my-project-id`` in
    ``us-central1`` get a 200 response with ``{'content': 'success'}``; any other path raises
    ``NotImplementedError`` so unexpected traffic fails loudly.
    """
    expected_path = (
        '/v1/projects/my-project-id/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent'
    )

    async def respond(request: httpx.Request):
        if request.url.path != expected_path:
            raise NotImplementedError(f'Unexpected request: {request.url!r}')  # pragma: no cover
        return httpx.Response(200, json={'content': 'success'})

    return httpx.AsyncClient(transport=httpx.MockTransport(handler=respond))
def test_google_vertex_provider(allow_model_requests: None) -> None:
    """Default construction: the provider's name, its us-central1 base URL and its HTTP client type.

    No project id is passed, so the project segment of the base URL renders as ``None``.
    """
    vertex = GoogleVertexProvider()
    assert vertex.name == 'google-vertex'
    assert vertex.base_url == snapshot(
        'https://us-central1-aiplatform.googleapis.com/v1/projects/None/locations/us-central1/publishers/google/models/'
    )
    client = vertex.client
    assert isinstance(client, httpx.AsyncClient)
@dataclass
class NoOpCredentials:
    """Minimal stand-in for Google credentials, returned by the patched ``google.auth.default``.

    Exposes a fixed bearer ``token`` and a ``refresh`` method that does nothing.
    """

    # Annotated so ``@dataclass`` treats it as a real field (overridable through the constructor);
    # an un-annotated assignment would be a plain class attribute that ``@dataclass`` ignores.
    token: str = 'my-token'

    def refresh(self, request: Request): ...  # pragma: no branch
@patch('pydantic_ai.providers.google_vertex.google.auth.default', return_value=(NoOpCredentials(), 'my-project-id'))
async def test_google_vertex_provider_auth(
    mock_google_auth_default: MagicMock, allow_model_requests: None, http_client: httpx.AsyncClient
):
    """Credentials and project id from ``google.auth.default`` are used when the client sends a request.

    ``patch`` (with no ``new``) passes the mock it creates as the first positional argument, and pytest
    skips that leading parameter when resolving fixtures. The mock therefore gets its own parameter
    ahead of the fixtures, so ``allow_model_requests`` and ``http_client`` are real fixtures.
    """
    provider = GoogleVertexProvider(http_client=http_client)
    await provider.client.post('/gemini-1.0-pro:generateContent')
    mock_google_auth_default.assert_called()
    assert provider.region == 'us-central1'
    assert getattr(provider.client.auth, 'project_id') == 'my-project-id'
async def mock_refresh_token():
    """Async replacement patched over the auth object's ``_refresh_token`` in the service-account tests.

    Always resolves to the fixed string ``'my-token'``.
    """
    fixed_token = 'my-token'
    return fixed_token
async def test_google_vertex_provider_service_account_file(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path, allow_model_requests: None
):
    """A service-account JSON file on disk supplies the project id recorded on the client's auth."""
    key_file = tmp_path / 'service_account.json'
    save_service_account(key_file, 'my-project-id')
    provider = GoogleVertexProvider(service_account_file=key_file)
    auth = provider.client.auth
    # Stub the token refresh so the test does not depend on the real refresh logic.
    monkeypatch.setattr(auth, '_refresh_token', mock_refresh_token)
    # NOTE(review): no `http_client` is injected here, so this post goes through the provider's own
    # client (presumably a real request to Vertex) — confirm that is intended.
    await provider.client.post('/gemini-1.0-pro:generateContent')
    assert provider.region == 'us-central1'
    assert getattr(auth, 'project_id') == 'my-project-id'
async def test_google_vertex_provider_service_account_file_info(
    monkeypatch: pytest.MonkeyPatch, allow_model_requests: None
):
    """An in-memory service-account mapping supplies the project id recorded on the client's auth."""
    provider = GoogleVertexProvider(service_account_info=prepare_service_account_contents('my-project-id'))
    auth = provider.client.auth
    # Stub the token refresh so the test does not depend on the real refresh logic.
    monkeypatch.setattr(auth, '_refresh_token', mock_refresh_token)
    # NOTE(review): no `http_client` is injected here, so this post goes through the provider's own
    # client (presumably a real request to Vertex) — confirm that is intended.
    await provider.client.post('/gemini-1.0-pro:generateContent')
    assert provider.region == 'us-central1'
    assert getattr(auth, 'project_id') == 'my-project-id'
async def test_google_vertex_provider_service_account_xor(allow_model_requests: None):
    """Supplying both ``service_account_file`` and ``service_account_info`` is rejected with ``ValueError``."""
    account_info = prepare_service_account_contents('my-project-id')
    expected_message = 'Only one of `service_account_file` or `service_account_info` can be provided'
    with pytest.raises(ValueError, match=expected_message):
        GoogleVertexProvider(  # type: ignore[reportCallIssue]
            service_account_file='path/to/service-account.json',
            service_account_info=account_info,
        )
def prepare_service_account_contents(project_id: str) -> dict[str, str]:
    """Build a fake service-account key mapping shaped like a Google credentials JSON file.

    Args:
        project_id: Value stored under the ``'project_id'`` key.

    Returns:
        A fresh ``dict`` of string fields (key type, ids, PEM private key, client e-mail and OAuth URLs).
    """
    # Throwaway key generated with `openssl genpkey`; it is not associated with any real account.
    pem_lines = [
        '-----BEGIN PRIVATE KEY-----',
        'MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMFrZYX4gZ20qv88',
        'jD0QCswXgcxgP7Ta06G47QEFprDVcv4WMUBDJVAKofzVcYyhsasWsOSxcpA8LIi9',
        '/VS2Otf8CmIK6nPBCD17Qgt8/IQYXOS4U2EBh0yjo0HQ4vFpkqium4lLWxrAZohA',
        '8r82clV08iLRUW3J+xvN23iPHyVDAgMBAAECgYBScRJe3iNxMvbHv+kOhe30O/jJ',
        'QiUlUzhtcEMk8mGwceqHvrHTcEtRKJcPC3NQvALcp9lSQQhRzjQ1PLXkC6BcfKFd',
        '03q5tVPmJiqsHbSyUyHWzdlHP42xWpl/RmX/DfRKGhPOvufZpSTzkmKWtN+7osHu',
        '7eiMpg2EDswCvOgf0QJBAPXLYwHbZLaM2KEMDgJSse5ZTE/0VMf+5vSTGUmHkr9c',
        'Wx2G1i258kc/JgsXInPbq4BnK9hd0Xj2T5cmEmQtm4UCQQDJc02DFnPnjPnnDUwg',
        'BPhrCyW+rnBGUVjehveu4XgbGx7l3wsbORTaKdCX3HIKUupgfFwFcDlMUzUy6fPO',
        'IuQnAkA8FhVE/fIX4kSO0hiWnsqafr/2B7+2CG1DOraC0B6ioxwvEqhHE17T5e8R',
        '5PzqH7hEMnR4dy7fCC+avpbeYHvVAkA5W58iR+5Qa49r/hlCtKeWsuHYXQqSuu62',
        'zW8QWBo+fYZapRsgcSxCwc0msBm4XstlFYON+NoXpUlsabiFZOHZAkEA8Ffq3xoU',
        'y0eYGy3MEzxx96F+tkl59lfkwHKWchWZJ95vAKWJaHx9WFxSWiJofbRna8Iim6pY',
        'BootYWyTCfjjwA==',
        '-----END PRIVATE KEY-----',
    ]
    # Every PEM line, including the footer, is newline-terminated.
    private_key = '\n'.join(pem_lines) + '\n'
    return {
        'type': 'service_account',
        'project_id': project_id,
        'private_key_id': 'abc',
        'private_key': private_key,
        'client_email': 'testing-pydantic-ai@pydantic-ai.iam.gserviceaccount.com',
        'client_id': '123',
        'auth_uri': 'https://accounts.google.com/o/oauth2/auth',
        'token_uri': 'https://oauth2.googleapis.com/token',
        'auth_provider_x509_cert_url': 'https://www.googleapis.com/oauth2/v1/certs',
        'client_x509_cert_url': 'https://www.googleapis.com/...',
        'universe_domain': 'googleapis.com',
    }
def save_service_account(service_account_path: Path, project_id: str) -> None:
    """Write the fake service-account mapping for ``project_id`` to ``service_account_path`` as indented JSON."""
    serialized = json.dumps(prepare_service_account_contents(project_id), indent=2)
    service_account_path.write_text(serialized)
@pytest.fixture(autouse=True)
def vertex_provider_auth(mocker: MockerFixture) -> None:  # pragma: lax no cover
    """In CI, make ``google.auth.default`` return no-op credentials for the ``pydantic-ai`` project.

    Locally, we authenticate via the ``gcloud`` CLI, so nothing is patched when ``CI`` is unset.
    """
    if not os.getenv('CI'):
        return
    # Reuse the module-level `NoOpCredentials` stub instead of redefining an identical local class.
    mocker.patch(
        'pydantic_ai.providers.google_vertex.google.auth.default',
        return_value=(NoOpCredentials(), 'pydantic-ai'),
    )
@pytest.mark.skipif(not os.getenv('CI'), reason='Requires properly configured local google vertex config to pass')
@pytest.mark.vcr()
async def test_vertexai_provider(allow_model_requests: None):  # pragma: lax no cover
    """End-to-end ``GeminiModel`` run over the ``google-vertex`` provider, only when ``CI`` is set.

    HTTP traffic goes through the ``vcr`` mark (recorded cassette).
    """
    agent = Agent(GeminiModel('gemini-2.0-flash', provider='google-vertex'))
    result = await agent.run('What is the capital of France?')
    assert result.output == snapshot('The capital of France is **Paris**.\n')