Skip to main content
Glama

MCPDemo - Visual SQL Chat Platform

by Ayi456
SqlAiService.ts (20.8 kB)
import { Connection, Schema, Column } from '../types/sql.types.js';
import fetch from 'node-fetch';
import crypto from 'crypto';
import { isReadOnlySql } from '../utils/sqlUtils.js';

/** Inputs for one natural-language → SQL generation request. */
export interface SqlGenerationContext {
  /** Target connection; its `engineType` selects dialect-specific prompt hints. */
  connection: Connection;
  /** Currently selected database; echoed back as `suggestedDatabase`. */
  database?: string;
  /** Table/column metadata given to the model as schema context. */
  tableSchema?: Schema[];
  /** Earlier queries; only the last two are included in the user prompt. */
  previousQueries?: string[];
}

/** Outcome of a generation request. */
export interface SqlGenerationResult {
  /** Generated SQL; empty string when generation failed. */
  sql: string;
  /** Human-readable explanation of the query (or of the failure). */
  explanation: string;
  /** Model-reported confidence, expected in [0, 1]; 0 on failure. */
  confidence: number;
  /** Database the SQL targets (set from the request's `database`). */
  suggestedDatabase?: string;
  /** Model-reported caveats or failure reasons. */
  warnings?: string[];
}

/**
 * Overrides for the AI client. Every field is optional: unset fields fall
 * back to environment variables, then to the active environment profile.
 */
export interface AiConfig {
  /** Base URL of an OpenAI-style API; `/chat/completions` is appended. */
  apiUrl?: string;
  /** Bearer token; required (the constructor throws when it is missing). */
  apiKey?: string;
  model?: string;
  temperature?: number;
  maxTokens?: number;
  /** Per-request timeout in milliseconds. */
  timeout?: number;
  enableCache?: boolean;
  /** Base cache lifetime in seconds (scaled by confidence before use). */
  cacheTTL?: number;
}

/** Per-environment tuning presets (AI request limits, caching, schema trimming). */
interface EnvironmentProfile {
  name: string;
  ai: {
    timeout: number;
    maxTokens: number;
    temperature: number;
  };
  cache: {
    enabled: boolean;
    /** Seconds. */
    ttl: number;
  };
  schema: {
    /** Non-positive means unlimited. */
    maxTables: number;
    /** Non-positive means unlimited. */
    maxColumnsPerTable: number;
    /** Non-positive means unlimited. */
    maxSchemaLength: number;
  };
}

/** Placeholder schema text used when no table metadata is available. */
const NO_SCHEMA_TEXT = '暂无数据库架构信息';

/** Only results with confidence strictly above this value are written to the cache. */
const CACHE_MIN_CONFIDENCE = 0.7;

/** Substring → abbreviation pairs used by `shortenType`, checked in this order. */
const TYPE_ABBREVIATIONS: ReadonlyArray<readonly [string, string]> = [
  ['varchar', 'str'],
  ['integer', 'int'],
  ['bigint', 'bigint'],
  ['decimal', 'dec'],
  ['datetime', 'dt'],
  ['timestamp', 'ts'],
  ['boolean', 'bool'],
  ['text', 'txt'],
];

/**
 * Parses an integer setting (typically an env var). Missing, empty, or
 * non-numeric input yields `fallback` instead of propagating NaN.
 */
function parseIntSetting(raw: string | undefined, fallback: number): number {
  const value = raw ? parseInt(raw, 10) : NaN;
  return Number.isFinite(value) ? value : fallback;
}

/** Float counterpart of `parseIntSetting`. */
function parseFloatSetting(raw: string | undefined, fallback: number): number {
  const value = raw ? parseFloat(raw) : NaN;
  return Number.isFinite(value) ? value : fallback;
}

/** Restricts `value` to the closed range [min, max]. */
function clamp(value: number, min: number, max: number): number {
  return Math.min(max, Math.max(min, value));
}

/**
 * Turns natural-language requests into read-only SQL via a chat-completion
 * API, with retries, a confidence threshold, a read-only guard and an
 * optional Redis cache.
 */
export class SqlAiService {
  private aiConfig: AiConfig;
  private isScfEnvironment: boolean;
  private profile: EnvironmentProfile;

  /** Presets: `local` is unrestricted; `scf` (cloud function) trims schema and enables caching. */
  private readonly profiles: Record<'local' | 'scf', EnvironmentProfile> = {
    local: {
      name: 'local',
      ai: {
        timeout: 30000,
        maxTokens: 2048,
        temperature: 0.7
      },
      cache: {
        enabled: false,
        ttl: 0
      },
      schema: {
        maxTables: -1, // unlimited
        maxColumnsPerTable: -1,
        maxSchemaLength: -1
      }
    },
    scf: {
      name: 'scf',
      ai: {
        timeout: 25000,
        maxTokens: 1500,
        temperature: 0.5
      },
      cache: {
        enabled: true,
        ttl: 3600 // one hour
      },
      schema: {
        maxTables: 10,
        maxColumnsPerTable: 20,
        maxSchemaLength: 2000
      }
    }
  };

  /**
   * @param config Optional overrides; see `AiConfig` for the fallback order.
   * @throws Error when no API key is available from `config` or `AI_API_KEY`.
   */
  constructor(config?: AiConfig) {
    // Cloud-function mode is opted into explicitly through the environment.
    this.isScfEnvironment = process.env.FORCE_CLOUD_FUNCTION === 'true';
    this.profile = this.isScfEnvironment ? this.profiles.scf : this.profiles.local;

    const env = process.env;
    this.aiConfig = {
      apiUrl: config?.apiUrl || env.AI_API_URL,
      apiKey: config?.apiKey || env.AI_API_KEY,
      model: config?.model || env.AI_MODEL,
      // `??` so an explicit temperature of 0 (deterministic output) is honoured.
      temperature: config?.temperature ?? parseFloatSetting(env.AI_TEMPERATURE, this.profile.ai.temperature),
      // 0 is not a usable value for these, so a falsy override falls through to env/profile.
      maxTokens: config?.maxTokens || parseIntSetting(env.AI_MAX_TOKENS, this.profile.ai.maxTokens),
      timeout: config?.timeout || parseIntSetting(env.AI_TIMEOUT, this.profile.ai.timeout),
      enableCache: config?.enableCache ?? (env.ENABLE_AI_CACHE === 'true' || this.profile.cache.enabled),
      cacheTTL: config?.cacheTTL || parseIntSetting(env.AI_CACHE_TTL, this.profile.cache.ttl)
    };

    // Fail fast instead of sending requests with missing credentials.
    if (!this.aiConfig.apiKey) {
      throw new Error('Missing AI_API_KEY: please set env AI_API_KEY or pass SqlAiService({ apiKey })');
    }

    console.log(`SqlAiService initialized:`, {
      environment: this.profile.name,
      timeout: this.aiConfig.timeout,
      maxTokens: this.aiConfig.maxTokens,
      cacheEnabled: this.aiConfig.enableCache
    });
  }

  /**
   * Generates SQL for a natural-language request (AI only, with retries).
   *
   * Flow: optional cache lookup → up to `AI_RETRY_ATTEMPTS` model calls with
   * exponential backoff → first result that passes `isValidAiResult` wins.
   *
   * @returns The accepted result. In the cloud-function environment a failure
   *   is reported as an empty-SQL result with `confidence: 0`.
   * @throws The last error (outside the cloud-function environment) when all attempts fail.
   */
  async generateSql(
    message: string,
    context: SqlGenerationContext
  ): Promise<SqlGenerationResult> {
    const startTime = Date.now();

    // 1) Optional cache short-circuit.
    if (this.aiConfig.enableCache) {
      try {
        const cached = await this.getCachedResult(message, context);
        if (cached) {
          console.log(`缓存命中,耗时: ${Date.now() - startTime}ms`);
          return cached;
        }
      } catch (e) {
        console.warn('读取缓存失败,继续调用 AI:', e);
      }
    }

    // 2) Schema context is identical for every attempt, so build it once.
    const schemaContext = this.buildSchemaContext(context.tableSchema);
    const { attempts, baseDelay, minConfidence } = this.readRetrySettings();

    // 3) Retry loop.
    let lastError: Error | null = null;
    for (let attempt = 0; attempt < attempts; attempt++) {
      try {
        const aiResult = await this.generateSqlWithAI(message, schemaContext, context);
        if (aiResult && this.isValidAiResult(aiResult, minConfidence)) {
          await this.cacheIfConfident(message, context, aiResult);
          console.log(`AI 生成成功,耗时: ${Date.now() - startTime}ms,尝试次数: ${attempt + 1}`);
          return aiResult;
        }
        lastError = new Error(`AI 返回无效结构或置信度过低(${aiResult?.confidence ?? 'N/A'})`);
      } catch (e: unknown) {
        lastError = e instanceof Error ? e : new Error(String(e));
      }

      // Exponential backoff: baseDelay, 2×baseDelay, 4×baseDelay, …
      if (attempt < attempts - 1) {
        await this.sleep(baseDelay * 2 ** attempt);
      }
    }

    // 4) Every attempt failed; there is no rule-engine fallback.
    const failMsg = lastError?.message || 'AI 生成失败';
    console.error('SQL 生成错误:', failMsg);

    if (this.isScfEnvironment) {
      return {
        sql: '',
        explanation: '生成 SQL 失败或结果不可靠,请重试或补充更清晰的需求',
        confidence: 0,
        warnings: [failMsg]
      };
    }

    throw lastError || new Error('AI 生成失败');
  }

  /**
   * Reads retry/acceptance tuning from the environment.
   * Malformed values fall back to the defaults instead of becoming NaN.
   */
  private readRetrySettings(): { attempts: number; baseDelay: number; minConfidence: number } {
    return {
      attempts: Math.max(1, parseIntSetting(process.env.AI_RETRY_ATTEMPTS, 2)),
      baseDelay: Math.max(0, parseIntSetting(process.env.AI_RETRY_BASE_DELAY, 800)),
      minConfidence: clamp(parseFloatSetting(process.env.MIN_SQL_CONFIDENCE, 0.7), 0, 1)
    };
  }

  /** Best-effort cache write for sufficiently confident results; failures are logged, not thrown. */
  private async cacheIfConfident(
    message: string,
    context: SqlGenerationContext,
    result: SqlGenerationResult
  ): Promise<void> {
    try {
      if (this.aiConfig.enableCache && (result.confidence ?? 0) > CACHE_MIN_CONFIDENCE) {
        await this.cacheResult(message, context, result);
      }
    } catch (e) {
      console.warn('写入缓存失败(已忽略):', e);
    }
  }

  /**
   * Performs one chat-completion call with a timeout covering both the
   * request and reading the response body.
   *
   * @returns A parsed result, a low-confidence "explanation only" result when
   *   the model replied with unstructured text, or null on HTTP/network/timeout errors.
   */
  private async generateSqlWithAI(
    message: string,
    schemaContext: string,
    context: SqlGenerationContext
  ): Promise<SqlGenerationResult | null> {
    const controller = new AbortController();
    const timeoutId = setTimeout(() => {
      controller.abort();
      console.log(`AI 请求超时 (${this.aiConfig.timeout}ms)`);
    }, this.aiConfig.timeout);

    try {
      const requestBody: Record<string, unknown> = {
        model: this.aiConfig.model,
        messages: [
          { role: 'system', content: this.buildSystemPrompt(context, schemaContext) },
          { role: 'user', content: this.buildUserPrompt(message, context) }
        ],
        temperature: this.aiConfig.temperature,
        max_tokens: this.aiConfig.maxTokens
      };

      // Some providers (e.g. ModelScope) may not support `response_format`.
      const apiUrl = this.aiConfig.apiUrl || '';
      if (!/modelscope\.cn/i.test(apiUrl)) {
        requestBody.response_format = { type: 'json_object' };
      }

      const finalUrl = `${apiUrl.replace(/\/+$/, '')}/chat/completions`;

      const response = await fetch(finalUrl, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          'Accept': 'application/json',
          'Authorization': `Bearer ${this.aiConfig.apiKey}`
        },
        body: JSON.stringify(requestBody),
        // Aborted by the timer above once the timeout elapses.
        signal: controller.signal
      });

      if (!response.ok) {
        const text = await response.text().catch(() => '');
        console.error('AI API 请求失败:', response.status, response.statusText, text, 'URL:', finalUrl);
        return null;
      }

      const data: unknown = await response.json().catch(() => ({}));
      const content = this.extractContent(data);
      if (!content || typeof content !== 'string') {
        console.error('AI 响应内容为空或非字符串');
        return null;
      }

      const aiResponse = this.parseJsonObject(content);
      if (!aiResponse) {
        // Degrade: surface the free text as an explanation with no SQL.
        return {
          sql: '',
          explanation: content.slice(0, 500),
          confidence: 0.4,
          suggestedDatabase: context.database,
          warnings: ['AI 返回非结构化文本,已降级为说明文本']
        };
      }

      // Strip Markdown fences / language tags / wrapping quotes the model may add.
      const cleanedSql = this.sanitizeSql(typeof aiResponse.sql === 'string' ? aiResponse.sql : '');
      const explanation = aiResponse.explanation;

      return {
        sql: cleanedSql,
        explanation: typeof explanation === 'string' && explanation ? explanation : '基于您的需求生成的 SQL 查询',
        confidence: typeof aiResponse.confidence === 'number' ? clamp(aiResponse.confidence, 0, 1) : 0.8,
        suggestedDatabase: context.database,
        warnings: this.normalizeWarnings(aiResponse.warnings)
      };
    } catch (error: unknown) {
      const isAbort =
        typeof error === 'object' && error !== null && (error as { name?: unknown }).name === 'AbortError';
      if (isAbort) {
        console.log('AI 请求超时');
      } else {
        console.error('AI 生成 SQL 失败:', error);
      }
      return null;
    } finally {
      // Always release the timer, whichever path returned.
      clearTimeout(timeoutId);
    }
  }

  /** Extracts the completion text from the response body, tolerating several provider shapes. */
  private extractContent(data: unknown): unknown {
    const body = (data ?? {}) as {
      choices?: Array<{ message?: { content?: unknown }; text?: unknown }>;
      output_text?: unknown;
    };
    return body.choices?.[0]?.message?.content ?? body.choices?.[0]?.text ?? body.output_text ?? '';
  }

  /**
   * Parses model output as a JSON object.
   *
   * It first tries the whole text, then the span from the first `{` to the
   * last `}`, which tolerates prose or code fences around the JSON.
   *
   * @returns The parsed object, or null when neither attempt yields a (non-array) object.
   */
  private parseJsonObject(content: string): Record<string, unknown> | null {
    const asObject = (text: string): Record<string, unknown> | null => {
      try {
        const parsed: unknown = JSON.parse(text);
        return parsed !== null && typeof parsed === 'object' && !Array.isArray(parsed)
          ? (parsed as Record<string, unknown>)
          : null;
      } catch {
        return null;
      }
    };

    const whole = asObject(content);
    if (whole) return whole;

    const start = content.indexOf('{');
    const end = content.lastIndexOf('}');
    return start >= 0 && end > start ? asObject(content.slice(start, end + 1)) : null;
  }

  /** Coerces the model's `warnings` field into `string[]` (or undefined when absent/unusable). */
  private normalizeWarnings(raw: unknown): string[] | undefined {
    if (Array.isArray(raw)) {
      return raw.map((w) => (typeof w === 'string' ? w : JSON.stringify(w)));
    }
    if (typeof raw === 'string' && raw.trim()) {
      return [raw];
    }
    return undefined;
  }

  /**
   * Builds the schema description for the prompt.
   * In the cloud-function environment it is trimmed via `buildOptimizedSchema`.
   */
  private buildSchemaContext(tableSchema?: Schema[]): string {
    if (!tableSchema || tableSchema.length === 0) {
      return NO_SCHEMA_TEXT;
    }

    if (this.isScfEnvironment && this.profile.schema.maxTables > 0) {
      return this.buildOptimizedSchema(tableSchema);
    }

    // Local environment: full, verbose schema.
    return tableSchema
      .map((table) => {
        if (!table.columns || table.columns.length === 0) {
          return `表 ${table.name}: 无列信息`;
        }
        const columns = table.columns
          .map((col: Column) => `${col.name} (${col.type}${col.nullable ? ', 可为空' : ''})`)
          .join(', ');
        return `表 ${table.name}: ${columns}`;
      })
      .join('\n');
  }

  /**
   * Compact schema description bounded by the profile's table/column/length limits.
   * Format: `table(col:type!,col:type)` where `!` marks NOT NULL columns.
   */
  private buildOptimizedSchema(tableSchema: Schema[]): string {
    const { maxTables, maxColumnsPerTable, maxSchemaLength } = this.profile.schema;
    // Non-positive limits mean "unlimited" (the profiles use -1 for that).
    const tableLimit = maxTables > 0 ? maxTables : tableSchema.length;
    const columnLimit = maxColumnsPerTable > 0 ? maxColumnsPerTable : Number.POSITIVE_INFINITY;

    let result = '';
    for (const [i, table] of tableSchema.slice(0, tableLimit).entries()) {
      const allColumns = table.columns || [];
      const columns = allColumns
        .slice(0, columnLimit)
        .map((col: Column) => `${col.name}:${this.shortenType(col.type)}${col.nullable ? '' : '!'}`)
        .join(',');
      const hiddenColumns = allColumns.length > columnLimit ? allColumns.length - columnLimit : 0;

      // The budget counts everything appended: separator, table text and truncation hint.
      const entry =
        `${i > 0 ? '; ' : ''}${table.name}(${columns})` + (hiddenColumns > 0 ? `[+${hiddenColumns}列]` : '');

      if (maxSchemaLength > 0 && result.length + entry.length > maxSchemaLength) {
        result += `...(还有${tableSchema.length - i}个表)`;
        break;
      }
      result += entry;
    }

    if (tableSchema.length > tableLimit) {
      result += ` ...(共${tableSchema.length}个表)`;
    }

    return result;
  }

  /** Abbreviates a column type name to save prompt tokens. */
  private shortenType(type: string): string {
    const text = type ?? '';
    const lowerType = text.toLowerCase();

    const abbreviation = TYPE_ABBREVIATIONS.find(([full]) => lowerType.includes(full));
    if (abbreviation) {
      return abbreviation[1];
    }

    // Sized types such as char(10) become prefix + size, e.g. "cha10".
    const sized = text.match(/\((\d+)\)/);
    if (sized) {
      return text.substring(0, 3) + sized[1];
    }

    return text.substring(0, 4);
  }

  /**
   * Removes Markdown code fences and wrapping quotes from model-produced SQL.
   *
   * Quotes are stripped only as a matching pair around the whole statement,
   * so a trailing string literal (e.g. `WHERE s = 'x'`) keeps its closing quote.
   */
  private sanitizeSql(sql: string): string {
    if (!sql) return '';
    let s = String(sql).trim();

    // Fences such as ```sql ... ``` or ``` ... ```.
    s = s.replace(/^```(?:sql|postgresql|postgres|mysql)?\s*/i, '');
    s = s.replace(/\s*```$/i, '');

    // Peel matching wrapping quote pairs; SQL statements never start with a quote.
    while (s.length >= 2 && (s[0] === '"' || s[0] === "'") && s.endsWith(s[0])) {
      s = s.slice(1, -1).trim();
    }

    return s.replace(/\r\n/g, '\n').trim();
  }

  /** Looks up a cached result in Redis; any failure is logged and treated as a miss. */
  private async getCachedResult(
    message: string,
    context: SqlGenerationContext
  ): Promise<SqlGenerationResult | null> {
    if (!this.aiConfig.enableCache) {
      return null;
    }

    try {
      const { getRedisClient } = await import('../database.js');
      const redis = getRedisClient();
      const cacheKey = this.generateCacheKey(message, context);

      const cached = await redis.get(cacheKey);
      if (cached) {
        console.log(`✅ AI 缓存命中: ${cacheKey}`);
        return JSON.parse(cached) as SqlGenerationResult;
      }

      return null;
    } catch (error) {
      console.warn('Redis 缓存读取失败:', error);
      return null;
    }
  }

  /** Writes a result to Redis with a confidence-scaled TTL; failures are logged only. */
  private async cacheResult(
    message: string,
    context: SqlGenerationContext,
    result: SqlGenerationResult
  ): Promise<void> {
    if (!this.aiConfig.enableCache) {
      return;
    }

    try {
      const { getRedisClient } = await import('../database.js');
      const redis = getRedisClient();
      const cacheKey = this.generateCacheKey(message, context);
      const ttl = this.calculateCacheTTL(result);

      await redis.setEx(cacheKey, ttl, JSON.stringify(result));

      console.log(`✅ AI 缓存写入: ${cacheKey}, TTL: ${ttl}s, 置信度: ${result.confidence}`);
    } catch (error) {
      // Cache failures must not affect the main flow.
      console.warn('Redis 缓存写入失败:', error);
    }
  }

  /**
   * Cache key derived from the normalized message plus everything the
   * generated SQL depends on.
   *
   * The key covers the engine, the database, the model and a fingerprint of
   * the table/column names and types. Without the fingerprint, different
   * schemas that merely have the same table count would share cached SQL.
   */
  private generateCacheKey(
    message: string,
    context: SqlGenerationContext
  ): string {
    const normalizedMessage = message
      .toLowerCase()
      .replace(/\s+/g, ' ')
      .trim();

    const schemaSignature = (context.tableSchema ?? [])
      .map((table) => {
        const cols = (table.columns ?? []).map((col: Column) => `${col.name}:${col.type}`).join(',');
        return `${table.name}(${cols})`;
      })
      .join(';');

    const contextData = {
      message: normalizedMessage,
      engineType: context.connection.engineType,
      database: context.database,
      tableCount: context.tableSchema?.length || 0,
      model: this.aiConfig.model,
      schema: schemaSignature
    };

    const hash = crypto
      .createHash('md5')
      .update(JSON.stringify(contextData))
      .digest('hex');

    return `sql:ai:${hash}`;
  }

  /**
   * TTL in whole seconds, scaled by confidence: ×2 above 0.9, ×1 above 0.7,
   * ×0.5 otherwise.
   *
   * The result is floored and kept ≥ 1 because Redis SETEX only accepts a
   * positive integer expiry.
   */
  private calculateCacheTTL(result: SqlGenerationResult): number {
    const baseTTL = this.aiConfig.cacheTTL || 3600;

    let ttl: number;
    if (result.confidence > 0.9) {
      ttl = baseTTL * 2;
    } else if (result.confidence > 0.7) {
      ttl = baseTTL;
    } else {
      ttl = baseTTL / 2;
    }

    return Math.max(1, Math.floor(ttl));
  }

  /** System prompt: role, optional schema, strict read-only rules, output format and dialect hints. */
  private buildSystemPrompt(context: SqlGenerationContext, schemaContext: string): string {
    const engineType = context.connection.engineType.toUpperCase();
    const database = context.database || 'default';
    const idQuoteStart = engineType === 'MYSQL' ? '`' : engineType === 'POSTGRESQL' ? '"' : '';
    const idQuoteEnd = idQuoteStart;

    const opening = [
      `你是一名 ${engineType} 数据库与 SQL 专家。`,
      '当被问及你的名字时,你必须回答 "Sir"。',
      '你的回答应信息充分且简洁。'
    ].join('\n');

    let schemaLead = '';
    if (schemaContext && schemaContext.trim() && schemaContext !== NO_SCHEMA_TEXT) {
      schemaLead = `\n\n以下是我的数据库 Schema:\n\n${schemaContext}\n\n请基于上述 Schema 回答后续问题:`;
    }

    const strictRules = `严格规则:\n- 仅生成一条只读查询:SELECT / WITH ... SELECT / EXPLAIN SELECT\n- 严禁任何修改/DDL 语句:INSERT/UPDATE/DELETE/CREATE/ALTER/DROP/TRUNCATE/GRANT/REVOKE/REPLACE/MERGE/CALL/DO/USE/SET/COMMIT/ROLLBACK\n- 严禁多语句与分号;不得包含注释或提示(如 /*+ */)\n- 避免 SELECT *,显式列出字段并适当使用别名\n- 若可能返回大量数据,默认添加 LIMIT 100(或等价语法),除非用户明确要求\n- 仅使用 Schema 中存在的表与列;若缺失,请在 warnings 中说明\n- 标识符建议使用 ${engineType === 'MYSQL' ? '反引号 `identifier`' : engineType === 'POSTGRESQL' ? '双引号 "identifier"' : '合适的引用符'};字符串使用单引号\n- 不要使用 Markdown 代码块(三个反引号)或额外文本,只返回 JSON 对象\n\n输出格式(仅 JSON 对象本身,禁止任何额外文本/Markdown 代码块):\n{\n \"sql\": \"...\",\n \"explanation\": \"中文说明,简洁(≤120字)\",\n \"confidence\": 0.0-1.0,\n \"warnings\": []\n}`;

    let prompt = `${opening}${schemaLead}\n\n${strictRules}\n\n上下文:\n- 数据库系统:${engineType}\n- 当前数据库:${database}`;

    // Dialect-specific hints.
    if (engineType === 'MYSQL') {
      prompt += `\n\n引擎提示(MySQL):\n- 标识符使用 ${idQuoteStart}name${idQuoteEnd};分页用 LIMIT;空值处理用 IFNULL/COALESCE\n- 日期格式化可用 DATE_FORMAT;字符串拼接用 CONCAT`;
    } else if (engineType === 'POSTGRESQL') {
      prompt += `\n\n引擎提示(PostgreSQL):\n- 标识符使用 ${idQuoteStart}name${idQuoteEnd};分页用 LIMIT/OFFSET;不区分大小写匹配可用 ILIKE\n- JSONB 提取可用 -> / ->>;空值处理用 COALESCE;拼接用 ||`;
    }

    // Instruct the model how to signal that it cannot answer safely.
    prompt += `\n\n失败回退:若需求不清或无法保证只读与正确性,请返回空 SQL(\"sql\": \"\",\"confidence\": 0)并在 warnings 中解释原因。`;

    return prompt;
  }

  /** User prompt: the request, up to two recent queries, and output reminders. */
  private buildUserPrompt(message: string, context: SqlGenerationContext): string {
    let prompt = `用户需求:${message}\n\n`;

    if (context.previousQueries && context.previousQueries.length > 0) {
      prompt += `最近查询上下文(取最近2条):\n`;
      context.previousQueries.slice(-2).forEach((q, i) => {
        prompt += `#${i + 1}: ${q}\n`;
      });
      prompt += `\n`;
    }

    prompt += `请在不确定时默认限制返回行数(如 LIMIT 100),避免 SELECT *,并优先使用明确的 JOIN 条件与字段别名。\n`;
    prompt += `最终仅返回严格 JSON(对象本身,无任何额外文本/Markdown 代码块),字段为 sql/explanation/confidence/warnings。`;

    return prompt;
  }

  /**
   * Accepts a result only if it has the expected field types, meets
   * `minConfidence`, and its SQL passes the read-only guard.
   */
  private isValidAiResult(result: SqlGenerationResult, minConfidence: number): boolean {
    const hasStructure =
      typeof result?.sql === 'string' &&
      typeof result?.explanation === 'string' &&
      typeof result?.confidence === 'number';
    if (!hasStructure) return false;
    if ((result.confidence ?? 0) < minConfidence) return false;
    // Read-only guard (delegated to isReadOnlySql).
    if (!isReadOnlySql(result.sql || '')) return false;
    return true;
  }

  /** Resolves after `ms` milliseconds (used for retry backoff). */
  private sleep(ms: number): Promise<void> {
    return new Promise(resolve => setTimeout(resolve, ms));
  }
}

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Ayi456/visual-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.