"""
Search engine with semantic search using embeddings
"""
from typing import List, Optional, Tuple
import openai
import anthropic
from .models import SkillSearchResult
from .database import Database


class SearchEngine:
"""Semantic search for skills using vector embeddings"""
def __init__(self, db: Database, openai_key: Optional[str], anthropic_key: Optional[str]):
self.db = db
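        # Optional API clients: embeddings are generated with OpenAI; the
        # Anthropic client is stored but not used anywhere in this module.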
self.openai_client = openai.OpenAI(api_key=openai_key) if openai_key else None
self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key) if anthropic_key else None

    def _generate_embedding(self, text: str) -> List[float]:
"""Generate embedding vector for text"""
if self.openai_client:
# Use OpenAI ada-002 embeddings
response = self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=text
)
return response.data[0].embedding
else:
            # Fallback: a 1536-dim zero vector (ada-002's dimensionality);
            # search() only performs keyword search in this case.
            return [0.0] * 1536

    async def generate_embedding(self, skill_id: str):
"""Generate and store embedding for a skill"""
        # The zero-vector fallback is not a usable embedding, so skip storing
        # anything when no OpenAI client is configured.
        if not self.openai_client:
            return
        skill = await self.db.get_skill(skill_id)
        if not skill:
            return
# Combine searchable text
        text = f"{skill.name} {skill.description or ''} {' '.join(skill.tags or [])}"
embedding = self._generate_embedding(text)
# Store embedding
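        # The Python list is adapted to a SQL array and cast to pgvector's vector
        # type via %s::vector (assumes a pgvector version that supports array casts).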
conn = self.db._get_conn()
try:
with conn.cursor() as cur:
cur.execute("""
UPDATE skills
SET embedding = %s::vector
WHERE skill_id = %s
""", (embedding, skill_id))
conn.commit()
finally:
conn.close()

    async def search(
self,
query: str,
tags: Optional[List[str]] = None,
category: Optional[str] = None,
min_rating: Optional[float] = None,
ai_generated: Optional[bool] = None,
limit: int = 10
) -> List[SkillSearchResult]:
"""
Search skills with semantic similarity and filters
"""
# Generate query embedding (used when semantic search is enabled)
query_embedding = self._generate_embedding(query)
where_clause, params = self._build_filters(tags, category, min_rating, ai_generated)
conn = self.db._get_conn()
try:
with conn.cursor() as cur:
if self.openai_client:
rows = self._execute_semantic_search(cur, where_clause, params, query_embedding, limit)
else:
rows = self._execute_keyword_search(cur, where_clause, params, query, limit)
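                # Rows are assumed to be mapping-style (e.g. a dict row factory),
                # so they can be unpacked directly into SkillSearchResult.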
return [SkillSearchResult(**dict(row)) for row in rows]
finally:
conn.close()

    def _build_filters(
self,
tags: Optional[List[str]],
category: Optional[str],
min_rating: Optional[float],
ai_generated: Optional[bool]
) -> Tuple[str, List]:
"""Build WHERE clause and parameters for search filters."""
where_clauses = ["visibility = 'public'"]
params: List = []
if tags:
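            # && is the Postgres array-overlap operator: match skills sharing at least one tag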
where_clauses.append("tags && %s")
params.append(tags)
if category:
where_clauses.append("category = %s")
params.append(category)
if min_rating is not None:
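            # Relies on the skill_stats alias "ss" joined in both search queries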
where_clauses.append("COALESCE(ss.rating_avg, 0) >= %s")
params.append(min_rating)
if ai_generated is not None:
where_clauses.append("ai_generated = %s")
params.append(ai_generated)
return " AND ".join(where_clauses), params

    def _execute_semantic_search(
self,
cur,
where_clause: str,
params: List,
query_embedding: List[float],
limit: int
) -> List:
"""Execute semantic (vector) search query and return rows."""
cur.execute(f"""
SELECT
s.*,
COALESCE(ss.rating_avg, 0) as rating_avg,
COALESCE(ss.rating_count, 0) as rating_count,
(1 - (s.embedding <=> %s::vector)) as relevance_score
FROM skills s
LEFT JOIN skill_stats ss ON s.skill_id = ss.skill_id
WHERE {where_clause} AND s.embedding IS NOT NULL
ORDER BY relevance_score DESC
LIMIT %s
""", [query_embedding] + params + [limit])
return cur.fetchall()

    def _execute_keyword_search(
self,
cur,
where_clause: str,
params: List,
query: str,
limit: int
) -> List:
"""Execute keyword-based search query and return rows."""
cur.execute(f"""
SELECT
s.*,
COALESCE(ss.rating_avg, 0) as rating_avg,
COALESCE(ss.rating_count, 0) as rating_count,
1.0 as relevance_score
FROM skills s
LEFT JOIN skill_stats ss ON s.skill_id = ss.skill_id
WHERE {where_clause}
AND (
s.name ILIKE %s OR
s.description ILIKE %s OR
%s = ANY(s.tags)
)
ORDER BY ss.rating_avg DESC NULLS LAST
LIMIT %s
""", params + [f"%{query}%", f"%{query}%", query, limit])
return cur.fetchall()
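
# Usage sketch (illustrative only): the Database constructor arguments and the
# result fields printed below are assumptions, not guaranteed by this module.
#
#   import asyncio
#   from .database import Database
#   from .search import SearchEngine
#
#   async def main():
#       db = Database()  # hypothetical setup; configure as your Database requires
#       engine = SearchEngine(db, openai_key="sk-...", anthropic_key=None)
#       await engine.generate_embedding("skill-123")
#       for result in await engine.search("summarize PDF documents", limit=5):
#           print(result.skill_id, result.relevance_score)
#
#   asyncio.run(main())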