import logging
import os
import re
import time
from typing import Optional

try:
    import google.generativeai as genai
    HAS_GENAI = True
except ImportError:
    HAS_GENAI = False

class LLMClient:
    """Wrapper around a Gemini ``GenerativeModel`` that produces Spark tuning advice.

    The client degrades gracefully. If the ``google-generativeai`` package is
    missing, no API key is available, or model setup fails, then
    ``generate_recommendation`` returns an explanatory message instead of
    calling the API.
    """

    # Model used when the caller does not choose one explicitly.
    DEFAULT_MODEL = "gemini-2.0-flash-exp"
    # Total number of generate_content attempts while the API reports rate limiting.
    DEFAULT_MAX_RETRIES = 5
    # Fallback wait for the first retry; doubled on each later attempt when the
    # error text carries no explicit retry hint.
    _BASE_BACKOFF_SECONDS = 5.0
    # Extra slack added on top of the computed wait before retrying.
    _RETRY_BUFFER_SECONDS = 1.0
    # Retry hints that may appear in a quota error message:
    #   format 1: "Please retry in 26.6s"
    #   format 2: "retry_delay { seconds: 58 }"
    _RETRY_IN_RE = re.compile(r"retry in (\d+(\.\d+)?)s")
    _RETRY_DELAY_RE = re.compile(r"retry_delay\s*{\s*seconds:\s*(\d+)")

    # Instructions prepended to every job context sent to the model.
    _SYSTEM_PROMPT = (
        "You are a Senior Spark Optimization Engineer. "
        "Analyze the provided Spark Application Context (Metrics + Code) and provide actionable tuning recommendations. "
        "Focus on: Skew, Spill, Serialization, GC, and Code Efficiency. "
        "Be concise, cite specific stages, and provide config snippets."
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        *,
        model_name: str = DEFAULT_MODEL,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ):
        """Create the client and try to configure the Gemini model.

        Args:
            api_key: Gemini API key. Falls back to the ``GEMINI_API_KEY``
                environment variable when falsy.
            model_name: Name passed to ``genai.GenerativeModel``.
            max_retries: Total attempts made while the API reports rate
                limiting. Values below 1 are treated as 1.
        """
        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        self.model_name = model_name
        self.max_retries = max_retries
        self.model = None  # stays None unless _setup succeeds
        self._setup()

    def _setup(self):
        """Configure the SDK and build the model; on any failure leave ``self.model`` as None."""
        if not HAS_GENAI:
            logging.warning("google-generativeai library not found. LLM features disabled.")
            return
        if not self.api_key:
            logging.warning("GEMINI_API_KEY not found. LLM features will return placeholder text.")
            return
        try:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
        except Exception as e:  # best-effort: the caller gets a status message instead
            logging.error("Failed to configure Gemini: %s", e)

    def generate_recommendation(self, context_str: str) -> str:
        """Send the job context to Gemini and return its advice.

        The context is prefixed with a Spark-tuning system prompt. Failures
        (missing library, missing key or failed setup, API errors, exhausted
        rate-limit retries) are reported as returned message strings rather
        than exceptions.

        Args:
            context_str: Serialized Spark application context (metrics + code).

        Returns:
            The model's response text, or a user-facing status/error message.
        """
        if not HAS_GENAI:
            return "❌ Error: `google-generativeai` library is not installed."
        if not self.model:
            if not self.api_key:
                return ("⚠️ **Gemini API Key Missing**\n\n"
                        "To enable AI recommendations, please set the `GEMINI_API_KEY` environment variable.\n"
                        "Returning placeholder text for now.")
            # A key was supplied, so _setup must have failed while configuring the model.
            return ("⚠️ **Gemini Unavailable**\n\n"
                    "`GEMINI_API_KEY` is set, but the Gemini model could not be configured. "
                    "Check the logs for details.")
        full_prompt = f"{self._SYSTEM_PROMPT}\n\n{context_str}"
        return self._generate_with_retries(full_prompt)

    @staticmethod
    def _is_rate_limit_error(error_msg: str) -> bool:
        """Return True if the error text looks like an HTTP 429 / quota error."""
        return "429" in error_msg or "quota" in error_msg.lower()

    @classmethod
    def _retry_delay(cls, error_msg: str, attempt: int) -> float:
        """Return the seconds to wait before the next attempt (buffer not included).

        If the error text carries a retry hint, that value is used. When both
        hint formats are present, ``retry_delay { seconds: N }`` wins.
        Otherwise the wait is exponential backoff:
        ``_BASE_BACKOFF_SECONDS * 2**attempt``.
        """
        wait_time = cls._BASE_BACKOFF_SECONDS * (2 ** attempt)
        match = cls._RETRY_IN_RE.search(error_msg)
        if match:
            wait_time = float(match.group(1))
        match = cls._RETRY_DELAY_RE.search(error_msg)
        if match:
            wait_time = float(match.group(1))
        return wait_time

    def _generate_with_retries(self, prompt: str) -> str:
        """Call the model, retrying only on rate-limit errors, up to ``self.max_retries`` times.

        Returns:
            The response text, a generic error message on a non-rate-limit
            failure, or a dedicated message once every attempt was rate-limited.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                response = self.model.generate_content(prompt)
                return response.text
            except Exception as e:  # SDK error types vary; classify by message text
                error_msg = str(e)
                if not self._is_rate_limit_error(error_msg):
                    return f"❌ Error communicating with Gemini API: {e}"
                if attempt == attempts - 1:
                    # Out of attempts: fall through to the rate-limit message below.
                    break
                wait_time = self._retry_delay(error_msg, attempt)
                logging.warning(
                    "Quota exceeded (Attempt %d/%d). Retrying in %.1fs...",
                    attempt + 1, attempts, wait_time,
                )
                time.sleep(wait_time + self._RETRY_BUFFER_SECONDS)
        return "❌ Failed to get response after multiple retries due to rate limits."