# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Testing sample for multi server."""
from __future__ import annotations
import asyncio
import time
from typing import Any, cast
import structlog
from litestar import Controller, Litestar, get, post # type: ignore
from litestar.datastructures import State # type: ignore
from litestar.logging.config import LoggingConfig # type: ignore
from litestar.middleware.base import AbstractMiddleware # type: ignore
from litestar.plugins.structlog import StructlogPlugin # type: ignore
from litestar.types import Message # type: ignore
from starlette.applications import Starlette
from genkit.aio.loop import run_loop
from genkit.core.environment import is_dev_environment
from genkit.core.reflection import create_reflection_asgi_app
from genkit.core.registry import Registry
from genkit.web.manager import (
AbstractBaseServer,
Server,
ServerConfig,
ServerManager,
UvicornAdapter,
get_health_info,
get_server_info,
)
from genkit.web.manager.signals import terminate_all_servers
from genkit.web.typing import Application, Receive, Scope, Send
# TODO: Logging middleware > log ALL access requests and fix dups
# TODO: Logging middleware > access requests different color for each server.
# TODO: Logging middleware > show the METHOD and path first and then the structure.
# TODO: Logging middleware > if the response is an error code, highlight in red
# when logging to the console.
# TODO: Logger > default configuration and console output and json output
# TODO: Add opentelemetry integration
# TODO: replace 'requests' with 'aiohttp' or 'httpx' in genkit
logging_config = LoggingConfig(
loggers={
'genkit_example': {
'level': 'DEBUG',
'handlers': ['console'],
},
}
)
logger = structlog.get_logger(__name__)
class LitestarLoggingMiddleware(AbstractMiddleware):
"""Logging middleware for Litestar that logs requests and responses."""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process the ASGI request/response cycle with logging."""
if str(scope['type']) != 'http':
await self.app(scope, receive, send)
return
start_time = time.time()
path = scope.get('path', '')
method = scope.get('method', '')
# Log the request
request_id = str(id(scope))
try:
# Extract request headers
headers = dict(scope.get('headers', []))
formatted_headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in headers.items()}
await logger.ainfo(
f'HTTP Request {method} {path}',
request_id=request_id,
method=method,
path=path,
headers=formatted_headers,
)
except Exception as e:
await logger.aerror(
'Error logging request',
error=str(e),
)
# Capture the response
async def wrapped_send(message: Message) -> None:
if message['type'] == 'http.response.start':
status_code = message.get('status', 0)
response_time = time.time() - start_time
try:
# Get response headers
resp_headers = message.get('headers', [])
formatted_resp_headers = (
{k.decode('utf-8'): v.decode('utf-8') for k, v in resp_headers} if resp_headers else {}
)
await logger.ainfo(
f'HTTP Response {method} {path}',
request_id=request_id,
method=method,
path=path,
status_code=status_code,
response_time_ms=round(response_time * 1000, 2),
headers=formatted_resp_headers,
)
except Exception as e:
await logger.aerror(
'Error logging response',
error=str(e),
)
await send(message)
# Call the next middleware or handler
await self.app(scope, receive, wrapped_send)
class BaseControllerMixin:
"""Base controller mixin for all litestar controllers."""
@post('/__quitquitquitz')
async def quit(self) -> dict[str, Any]:
"""Handle the quit endpoint."""
await logger.ainfo('Shutting down all servers...')
terminate_all_servers()
return {'status': 'OK'}
@get('/__healthz')
async def health(self, state: State) -> dict[str, Any]:
"""Handle the health check endpoint."""
config = state.config
info = get_health_info(config)
return info
@get('/__serverz')
async def server_info(self, state: State) -> dict[str, Any]:
"""Handle the system information check endpoint."""
config = state.config
info = get_server_info(config)
return info if isinstance(info, dict) else {'info': info}
class FlowsEndpoints(Controller, BaseControllerMixin):
"""Controller for the Flows API endpoints."""
path = '/flow'
@get('/run')
async def root(self) -> dict[str, str]:
"""Handle the root endpoint."""
msg = 'Running flow endpoint!'
return {'flow': msg}
class GreetingEndpoints(Controller, BaseControllerMixin):
"""Controller for the Greetings API endpoints.
An example demonstrating multiple controllers bound to the same application
server.
"""
path = '/'
@get('/greet')
async def root(self) -> dict[str, str]:
"""Handle the root endpoint."""
msg = 'Hello from greeting endpoints app!'
return {'greeting': msg}
class FlowsServerLifecycle(AbstractBaseServer):
"""Flows server implementing the ServerLifecycleProtocol."""
def __init__(self, route_handlers: list[type[Controller]]) -> None:
"""Initialize the flows server.
Args:
route_handlers: The controller classes to use for routes.
"""
self.route_handlers = route_handlers
def create(self, config: ServerConfig) -> Application:
"""Create a Litestar application instance."""
async def on_app_startup() -> None:
"""Handle application startup."""
await logger.ainfo('[LIFESPAN] Starting API server...')
# Any initialization could go here
async def on_app_shutdown() -> None:
"""Handle application shutdown."""
await logger.ainfo('[LIFESPAN] Shutting down API server...')
# Create and return the Litestar application
return Litestar(
route_handlers=self.route_handlers,
on_startup=[on_app_startup],
on_shutdown=[on_app_shutdown],
logging_config=logging_config,
middleware=[LitestarLoggingMiddleware],
plugins=[StructlogPlugin()],
state=State({'config': config}), # Set the config in the application state
)
class ReflectionServerStarletteLifecycle(AbstractBaseServer):
"""Reflection server implemented using Starlette."""
def __init__(self) -> None:
"""Initialize the Starlette reflection server."""
pass
def create(self, config: ServerConfig) -> Starlette:
"""Create a Starlette application instance."""
async def on_app_startup() -> None:
"""Handle application startup."""
await logger.ainfo('[LIFESPAN] Starting Starlette Reflection API server...')
async def on_app_shutdown() -> None:
"""Handle application shutdown."""
await logger.ainfo('[LIFESPAN] Shutting down Starlette Reflection API server...')
return cast(
Starlette,
create_reflection_asgi_app(
registry=Registry(),
on_app_startup=on_app_startup,
on_app_shutdown=on_app_shutdown,
),
)
async def add_server_after(mgr: ServerManager, server: Server, delay: float) -> None:
"""Add a server to the servers manager after a delay.
Args:
mgr: The servers manager.
server: The server to add.
delay: The delay in seconds before adding the server.
Returns:
None
"""
await asyncio.sleep(delay)
await mgr.queue_server(server)
async def main() -> None:
"""Entry point function."""
servers = [
Server(
config=ServerConfig(
name='flows',
host='localhost',
port=3400,
ports=range(3400, 3410),
),
lifecycle=FlowsServerLifecycle([FlowsEndpoints, GreetingEndpoints]),
adapter=UvicornAdapter(),
),
]
mgr = ServerManager()
if is_dev_environment():
reflection_server = Server(
config=ServerConfig(
name='reflection-starlette',
host='localhost',
port=3100,
ports=range(3100, 3110),
),
lifecycle=ReflectionServerStarletteLifecycle(),
adapter=UvicornAdapter(),
)
asyncio.create_task(add_server_after(mgr, reflection_server, 2.0))
await logger.ainfo('Starting servers...')
await mgr.run_all(servers)
if __name__ == '__main__':
run_loop(main())