// index.ts (13.3 kB)
import { ToolDefinition, ToolResponse } from '../types.js';
import { Config } from '../../config/config.js';
import mysql from 'mysql2/promise';
import pkg from 'pg';
const { Client: PgClient } = pkg;
import { createClient } from 'redis';
import * as path from 'path';
/** Database back-ends this module can open connections to. */
type DatabaseType =
  | 'mysql'
  | 'postgres'
  | 'redis';

/** A live connection registered under a caller-chosen connection id. */
type DatabaseConnection = {
  /** Which driver produced `connection`; selects the code path for every operation. */
  type: DatabaseType;
  /** Raw driver handle: a mysql2/promise connection, a pg Client, or a redis client. */
  connection: any;
  /** Set to true while a transaction is open on this connection. */
  transaction?: boolean;
};
/**
 * Database manager: owns the connections opened through the db_* tools,
 * keyed by a caller-chosen connection id, and converts every operation's
 * outcome (or failure) into a ToolResponse instead of throwing.
 */
class DatabaseManager {
  /** Open connections keyed by connection id. */
  private readonly connections: Map<string, DatabaseConnection>;
  // NOTE(review): dbDir is assigned but never read anywhere in this class.
  private readonly dbDir: string;

  constructor(
    private readonly config: Config
  ) {
    this.connections = new Map();
    // Use the "db" directory under the current working directory for database files.
    this.dbDir = path.join(process.cwd(), 'db');
  }

  /** Build a single text-item response; `isError` is only present on failures. */
  private textResponse(text: string, isError = false): ToolResponse {
    return isError
      ? { content: [{ type: 'text', text }], isError: true }
      : { content: [{ type: 'text', text }] };
  }

  /** Turn a caught value into an error response carrying its message. */
  private errorResponse(error: unknown): ToolResponse {
    return this.textResponse(error instanceof Error ? error.message : String(error), true);
  }

  /** Open a MySQL connection from `config.database.mysql` and register it under `id`. */
  private async createMySQLConnection(id: string): Promise<void> {
    const { host, port, user, password, database } = this.config.database.mysql;
    const connection = await mysql.createConnection({ host, port, user, password, database });
    // Fatal errors on an idle connection are emitted as 'error' events; an
    // EventEmitter 'error' with no listener is thrown and kills the process.
    connection.on('error', (error: unknown) => {
      console.error(`MySQL 连接错误 (${id}):`, error);
    });
    this.connections.set(id, { type: 'mysql', connection });
  }

  /** Open a PostgreSQL client from `config.database.postgres` and register it under `id`. */
  private async createPostgresConnection(id: string): Promise<void> {
    const { host, port, user, password, database } = this.config.database.postgres;
    const client = new PgClient({ host, port, user, password, database });
    // pg emits 'error' on an idle client (e.g. server restart / network drop);
    // without a listener that event would crash the process.
    client.on('error', (error: unknown) => {
      console.error(`PostgreSQL 连接错误 (${id}):`, error);
    });
    await client.connect();
    this.connections.set(id, { type: 'postgres', connection: client });
  }

  /** Open a Redis client from `config.database.redis` and register it under `id`. */
  private async createRedisConnection(id: string): Promise<void> {
    const { host, port, password } = this.config.database.redis;
    const client = createClient({
      socket: { host, port },
      password
    });
    // node-redis requires an 'error' listener: unhandled client errors are
    // thrown and terminate the process.
    // NOTE(review): with the default reconnect strategy connect() may keep
    // retrying an unreachable server instead of rejecting — consider
    // configuring socket.reconnectStrategy.
    client.on('error', (error: unknown) => {
      console.error(`Redis 连接错误 (${id}):`, error);
    });
    await client.connect();
    this.connections.set(id, { type: 'redis', connection: client });
  }

  /**
   * Look up a registered connection.
   * @throws Error when no connection is registered under `id`.
   */
  private async getConnection(id: string): Promise<DatabaseConnection> {
    const conn = this.connections.get(id);
    if (!conn) {
      throw new Error(`Connection ${id} not found`);
    }
    return conn;
  }

  /**
   * Roll back any open transaction, then close the underlying handle.
   * The handle is closed even when the rollback fails; the rollback error
   * (if any) is still propagated to the caller.
   */
  private async closeConnection(conn: DatabaseConnection): Promise<void> {
    try {
      if (conn.transaction) {
        switch (conn.type) {
          case 'mysql':
            await conn.connection.rollback();
            break;
          case 'postgres':
            await conn.connection.query('ROLLBACK');
            break;
        }
      }
    } finally {
      conn.transaction = false;
      switch (conn.type) {
        case 'mysql':
        case 'postgres':
          await conn.connection.end();
          break;
        case 'redis':
          await conn.connection.quit();
          break;
      }
    }
  }

  /**
   * Split a Redis command line into arguments on runs of whitespace.
   * Quoting is not supported: `"a b"` yields two arguments.
   * @throws Error when the command contains no arguments at all.
   */
  private parseRedisCommand(command: string): string[] {
    const argv = command.trim().split(/\s+/).filter(part => part.length > 0);
    if (argv.length === 0) {
      throw new Error('Redis 命令不能为空');
    }
    return argv;
  }

  /**
   * Render a query result as pretty-printed JSON; an empty row list gets a
   * human-readable message. BigInt values are emitted as decimal strings
   * because JSON.stringify cannot serialize them.
   */
  private formatQueryResult(result: unknown): string {
    if (Array.isArray(result) && result.length === 0) {
      return '查询结果为空';
    }
    const text = JSON.stringify(
      result,
      (_key, value) => (typeof value === 'bigint' ? value.toString() : value),
      2
    );
    // JSON.stringify(undefined) yields undefined despite its `string` typing.
    return text ?? 'null';
  }

  // Connection handler: open a new connection of the requested type under `args.id`.
  async connect(args: { type: DatabaseType; id: string }): Promise<ToolResponse> {
    try {
      if (this.connections.has(args.id)) {
        return this.textResponse(`连接ID "${args.id}" 已存在`, true);
      }
      switch (args.type) {
        case 'mysql':
          await this.createMySQLConnection(args.id);
          break;
        case 'postgres':
          await this.createPostgresConnection(args.id);
          break;
        case 'redis':
          await this.createRedisConnection(args.id);
          break;
        default:
          // Reachable: tool arguments arrive untyped at runtime.
          return this.textResponse('不支持的数据库类型', true);
      }
      return this.textResponse(`已成功连接到${args.type}数据库 (ID: ${args.id})`);
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  // Query handler: run SQL (with optional bind params) or a Redis command line.
  async query(args: { connectionId: string; query: string; params?: any[] }): Promise<ToolResponse> {
    try {
      const conn = await this.getConnection(args.connectionId);
      let result: unknown;
      switch (conn.type) {
        case 'mysql':
          // execute() resolves to [rows, fields]; only the rows are reported.
          [result] = await conn.connection.execute(args.query, args.params);
          break;
        case 'postgres':
          result = (await conn.connection.query(args.query, args.params)).rows;
          break;
        case 'redis':
          // Bind params are not applicable to Redis and are ignored.
          result = await conn.connection.sendCommand(this.parseRedisCommand(args.query));
          break;
      }
      return this.textResponse(this.formatQueryResult(result));
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  // Transaction handlers: only MySQL and PostgreSQL connections support transactions.
  async beginTransaction(args: { connectionId: string }): Promise<ToolResponse> {
    try {
      const conn = await this.getConnection(args.connectionId);
      if (conn.type === 'redis') {
        return this.textResponse('Redis不支持事务操作', true);
      }
      if (conn.transaction) {
        return this.textResponse('已有事务在进行中', true);
      }
      switch (conn.type) {
        case 'mysql':
          await conn.connection.beginTransaction();
          break;
        case 'postgres':
          await conn.connection.query('BEGIN');
          break;
      }
      conn.transaction = true;
      return this.textResponse('事务已开始');
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  async commitTransaction(args: { connectionId: string }): Promise<ToolResponse> {
    try {
      const conn = await this.getConnection(args.connectionId);
      if (!conn.transaction) {
        return this.textResponse('没有活动的事务', true);
      }
      let rolledBack = false;
      switch (conn.type) {
        case 'mysql':
          await conn.connection.commit();
          break;
        case 'postgres': {
          // PostgreSQL executes COMMIT as ROLLBACK when an earlier statement
          // aborted the transaction; the reply's command tag is then
          // 'ROLLBACK', so reporting success would be wrong.
          const reply = await conn.connection.query('COMMIT');
          rolledBack = reply?.command === 'ROLLBACK';
          break;
        }
      }
      conn.transaction = false;
      return rolledBack
        ? this.textResponse('事务未能提交:事务中有语句执行失败,数据库已将其回滚', true)
        : this.textResponse('事务已提交');
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  async rollbackTransaction(args: { connectionId: string }): Promise<ToolResponse> {
    try {
      const conn = await this.getConnection(args.connectionId);
      if (!conn.transaction) {
        return this.textResponse('没有活动的事务', true);
      }
      switch (conn.type) {
        case 'mysql':
          await conn.connection.rollback();
          break;
        case 'postgres':
          await conn.connection.query('ROLLBACK');
          break;
      }
      conn.transaction = false;
      return this.textResponse('事务已回滚');
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  // Close handler: roll back any open transaction, close the handle and forget the id.
  async close(args: { connectionId: string }): Promise<ToolResponse> {
    try {
      const conn = await this.getConnection(args.connectionId);
      try {
        await this.closeConnection(conn);
      } finally {
        // Once closing was attempted the handle is unusable; drop it even on failure
        // so the id can be reused by a fresh db_connect.
        this.connections.delete(args.connectionId);
      }
      return this.textResponse(`数据库连接已关闭 (ID: ${args.connectionId})`);
    } catch (error) {
      return this.errorResponse(error);
    }
  }

  /** Release every connection; failures are logged per connection and do not stop the sweep. */
  async dispose(): Promise<void> {
    for (const [id, conn] of this.connections) {
      try {
        await this.closeConnection(conn);
      } catch (error) {
        console.error(`关闭连接失败 (${id}):`, error);
      }
    }
    this.connections.clear();
  }
}
/**
 * Create the database tool definitions (connect, query, transaction control,
 * close). All tools share one DatabaseManager instance, so connection ids
 * created by db_connect are visible to the other tools.
 *
 * @param config Application config; connection settings are read from `config.database`.
 * @returns The tool definitions in registration order.
 */
// NOTE(review): the manager's dispose() is not reachable from here, so open
// connections are only released via db_close — confirm shutdown handling.
export function createDatabaseTools(
  config: Config
): ToolDefinition[] {
  const manager = new DatabaseManager(config);
  return [
    {
      name: 'db_connect',
      description: '连接到数据库。在执行任何数据库操作前,我会先使用此工具建立连接。我会自动从配置文件读取连接信息,无需手动输入敏感信息。',
      inputSchema: {
        type: 'object',
        properties: {
          type: {
            type: 'string',
            description: '数据库类型:mysql(MySQL数据库)、postgres(PostgreSQL数据库)或 redis(Redis缓存)。配置信息从 config.yaml 中读取。',
            enum: ['mysql', 'postgres', 'redis'],
          },
          id: {
            type: 'string',
            description: '连接标识符,用于在后续操作中引用此连接。我会使用有意义的名称,如 "main_db" 或 "cache"',
          },
        },
        required: ['type', 'id'],
      },
      handler: args => manager.connect(args)
    },
    {
      name: 'db_query',
      description: '执行数据库查询。我会使用参数化查询来防止 SQL 注入。对于 Redis,我会自动将命令解析为正确的格式。',
      inputSchema: {
        type: 'object',
        properties: {
          connectionId: {
            type: 'string',
            description: '之前通过 db_connect 创建的连接标识符',
          },
          query: {
            type: 'string',
            description: 'SQL查询语句或Redis命令。SQL查询支持参数化(使用 ? 或 $1 占位符),Redis命令直接写命令名和参数,如 "SET key value"',
          },
          params: {
            type: 'array',
            description: '查询参数数组,用于填充 SQL 查询中的占位符。对 Redis 命令无效。使用参数化查询可以防止 SQL 注入。',
            // An empty schema accepts any value. 'any' is not a JSON Schema
            // type name, and strict schema validators reject it.
            items: {},
          },
        },
        required: ['connectionId', 'query'],
      },
      handler: args => manager.query(args)
    },
    {
      name: 'db_begin_transaction',
      description: '开始数据库事务。当我需要确保多个操作的原子性时使用。仅支持 MySQL 和 PostgreSQL。',
      inputSchema: {
        type: 'object',
        properties: {
          connectionId: {
            type: 'string',
            description: '数据库连接标识符。必须是 MySQL 或 PostgreSQL 连接。Redis 连接会返回错误。',
          },
        },
        required: ['connectionId'],
      },
      handler: args => manager.beginTransaction(args)
    },
    {
      name: 'db_commit_transaction',
      description: '提交数据库事务。确认事务中的所有更改。如果事务中有任何操作失败,我会使用 db_rollback_transaction。',
      inputSchema: {
        type: 'object',
        properties: {
          connectionId: {
            type: 'string',
            description: '具有活动事务的数据库连接标识符。必须先调用 db_begin_transaction。',
          },
        },
        required: ['connectionId'],
      },
      handler: args => manager.commitTransaction(args)
    },
    {
      name: 'db_rollback_transaction',
      description: '回滚数据库事务。当发生错误或需要撤销更改时,我会使用此工具来取消事务中的所有更改。',
      inputSchema: {
        type: 'object',
        properties: {
          connectionId: {
            type: 'string',
            description: '具有活动事务的数据库连接标识符。必须先调用 db_begin_transaction。',
          },
        },
        required: ['connectionId'],
      },
      handler: args => manager.rollbackTransaction(args)
    },
    {
      name: 'db_close',
      description: '关闭数据库连接。我会在完成操作后调用此工具释放资源。如果有活动事务会自动回滚。',
      inputSchema: {
        type: 'object',
        properties: {
          connectionId: {
            type: 'string',
            description: '要关闭的数据库连接标识符。关闭后此标识符将无效,需要重新连接才能使用。',
          },
        },
        required: ['connectionId'],
      },
      handler: args => manager.close(args)
    }
  ];
}