gbi.ipynb•17.2 kB
{
"cells": [
{
"cell_type": "markdown",
"id": "f802e64d-4eaa-445d-a48a-1042a91bc394",
"metadata": {
"tags": []
},
"source": [
"# GBI\n",
"\n",
"## 目标\n",
"通过 GBI sdk 接口完成选表和问表的能力。 \n",
"\n",
"## 准备工作\n",
"### 平台注册\n",
"1.先在appbuilder平台注册,获取token\n",
"2.安装appbuilder-sdk"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2939356f-61c2-42e9-9e0c-fc6729c193f6",
"metadata": {},
"outputs": [],
"source": [
"# !pip install appbuilder-sdk"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4ccff03b-1567-4e8b-8e1f-9a5032690406",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import os\n",
"\n",
"# 设置环境变量\n",
"os.environ[\"APPBUILDER_TOKEN\"] = \"***\"\n"
]
},
{
"cell_type": "markdown",
"id": "aeb2fa55-075f-48df-a9fb-8b40d9900684",
"metadata": {},
"source": [
"## 开发过程"
]
},
{
"cell_type": "markdown",
"id": "1c3c5cee",
"metadata": {},
"source": [
"### 设置表的 schema"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d7d6440c",
"metadata": {},
"outputs": [],
"source": [
"SUPER_MARKET_SCHEMA = \"\"\"\n",
"```\n",
"CREATE TABLE `supper_market_info` (\n",
" `订单编号` varchar(32) DEFAULT NULL,\n",
" `订单日期` date DEFAULT NULL,\n",
" `邮寄方式` varchar(32) DEFAULT NULL,\n",
" `地区` varchar(32) DEFAULT NULL,\n",
" `省份` varchar(32) DEFAULT NULL,\n",
" `客户类型` varchar(32) DEFAULT NULL,\n",
" `客户名称` varchar(32) DEFAULT NULL,\n",
" `商品类别` varchar(32) DEFAULT NULL,\n",
" `制造商` varchar(32) DEFAULT NULL,\n",
" `商品名称` varchar(32) DEFAULT NULL,\n",
" `数量` int(11) DEFAULT NULL,\n",
" `销售额` int(11) DEFAULT NULL,\n",
" `利润` int(11) DEFAULT NULL\n",
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n",
"```\n",
"\"\"\"\n",
"\n",
"PRODUCT_SALES_INFO = \"\"\"\n",
"现有 mysql 表 product_sales_info, \n",
"该表的用途是: 产品收入表\n",
"```\n",
"CREATE TABLE `product_sales_info` (\n",
" `年` int,\n",
" `月` int,\n",
" `产品名称` varchar,\n",
" `收入` decimal,\n",
" `非交付成本` decimal,\n",
" `含交付毛利` decimal\n",
")\n",
"```\n",
"\"\"\"\n",
"\n",
"# schema 和表名的映射\n",
"SCHEMA_MAPPING = {\n",
" \"supper_market_info\": SUPER_MARKET_SCHEMA,\n",
" \"PRODUCT_SALES_INFO\": PRODUCT_SALES_INFO\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "463254a1",
"metadata": {},
"source": [
"设置表的描述用于选表"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7fefcae1",
"metadata": {},
"outputs": [],
"source": [
"table_descriptions = {\n",
" \"supper_market_info\": \"超市营收明细表,包含超市各种信息等\",\n",
" \"product_sales_info\": \"产品销售表\"\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "a0aff843",
"metadata": {},
"source": [
"### 选表"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "41559341-fd7a-478c-a08b-1477d79e9d41",
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-18T06:24:26.982459Z",
"start_time": "2023-12-18T06:23:53.771345Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"选的表是: ['supper_market_info']\n"
]
}
],
"source": [
"import appbuilder\n",
"from appbuilder.core.message import Message\n",
"from appbuilder.core.components.gbi.basic import SessionRecord\n",
"\n",
"# 生成问表对象\n",
"select_table = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n",
"query = \"列出超市中的所有数据\"\n",
"msg = Message({\"query\": query})\n",
"select_table_result_message = select_table(msg)\n",
"print(f\"选的表是: {select_table_result_message.content}\")"
]
},
{
"cell_type": "markdown",
"id": "16a8aa38-7a33-4e27-bca4-00900cfe1641",
"metadata": {},
"source": [
"### 问表\n",
"基于上面选出的表,通过获取表的 schema 进行问表"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9f45ef5f-6206-4b31-83c4-3c8eb2c86925",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sql: SELECT * FROM supper_market_info;\n",
"-----------------\n",
"llm result: ```sql\n",
"SELECT * FROM supper_market_info;\n",
"```\n"
]
}
],
"source": [
"table_schemas = [SCHEMA_MAPPING[table_name] for table_name in select_table_result_message.content]\n",
"gbi_nl2sql = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n",
"nl2sql_result_message = gbi_nl2sql(Message({\"query\": \"列出超市中的所有数据\"}))\n",
"print(f\"sql: {nl2sql_result_message.content.sql}\")\n",
"print(\"-----------------\")\n",
"print(f\"llm result: {nl2sql_result_message.content.llm_result}\")"
]
},
{
"cell_type": "markdown",
"id": "b0409c46-e8c7-403a-a827-fcdc8e717be6",
"metadata": {},
"source": [
"设置 session"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a23b8cad-f426-4074-9311-c2c33aaea07b",
"metadata": {},
"outputs": [],
"source": [
"session = list()\n",
"session.append(SessionRecord(query=query, answer=nl2sql_result_message.content))"
]
},
{
"cell_type": "markdown",
"id": "22b3d877-f61f-4958-a084-7507a3017e17",
"metadata": {},
"source": [
"再次问表"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2adcb091-fb53-4364-b4d8-20564439ff51",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sql: SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
"-----------------\n",
"llm result: ```sql\n",
"SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
"```\n"
]
}
],
"source": [
"nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\", \n",
" \"session\": session}))\n",
"print(f\"sql: {nl2sql_result_message2.content.sql}\")\n",
"print(\"-----------------\")\n",
"print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")"
]
},
{
"cell_type": "markdown",
"id": "9e0609ae-f2bc-43d3-9023-14e9f8618158",
"metadata": {},
"source": [
"### 增加列选优化\n",
"实际上数据中 \"商品类别\" 存储的是 \"新鲜水果\", 那么就可以通过列选的限制来优化 sql."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2a7c7923-019e-4660-9e36-4431e9d2f3a6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sql: SELECT * FROM supper_market_info WHERE 商品类别 = '新鲜水果'\n",
"-----------------\n",
"llm result: ```sql\n",
"SELECT * FROM supper_market_info WHERE 商品类别 = '新鲜水果'\n",
"```\n"
]
}
],
"source": [
"from appbuilder.core.components.gbi.basic import ColumnItem\n",
"\n",
"\n",
"column_constraint = [ColumnItem(ori_value=\"水果\", \n",
" column_name=\"商品类别\", \n",
" column_value=\"新鲜水果\", \n",
" table_name=\"supper_market_info\", \n",
" is_like=False)]\n",
"\n",
"nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\",\n",
" \"column_constraint\": column_constraint}))\n",
"\n",
"print(f\"sql: {nl2sql_result_message2.content.sql}\")\n",
"print(\"-----------------\")\n",
"print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")"
]
},
{
"cell_type": "markdown",
"id": "8385312c-aea1-42cd-b61b-a8d36f4f0665",
"metadata": {},
"source": [
"从上面我们看到,商品类别不在是 \"水果\" 而是 修订为 \"新鲜水果\""
]
},
{
"cell_type": "markdown",
"id": "6e98c414-8b2b-4187-a270-3117a4f431ff",
"metadata": {},
"source": [
"### 增加知识优化\n",
"当计算某些特殊知识的时候,大模型是不知道的,所以需要告诉大模型具体的知识,比如:\n",
"利润率的计算方式: 利润/销售额\n",
"可以将该知识注入。具体示例如下:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "cade4693-29dc-431c-bf84-c6dc09104294",
"metadata": {},
"outputs": [],
"source": [
"# 注入知识\n",
"gbi_nl2sql.knowledge[\"利润率\"] = \"计算方式: 利润/销售额\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1dc181e8-47a1-4b82-8bb5-ce3339be53f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sql: SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n",
"FROM supper_market_info\n",
"WHERE 商品类别 = '日用品'\n",
"GROUP BY 商品类别\n",
"-----------------\n",
"llm result: ```sql\n",
"SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n",
"FROM supper_market_info\n",
"WHERE 商品类别 = '日用品'\n",
"GROUP BY 商品类别\n",
"```\n"
]
}
],
"source": [
"nl2sql_result_message3 = gbi_nl2sql(Message({\"query\": \"列出商品类别是日用品的利润率\"}))\n",
"print(f\"sql: {nl2sql_result_message3.content.sql}\")\n",
"print(\"-----------------\")\n",
"print(f\"llm result: {nl2sql_result_message3.content.llm_result}\")"
]
},
{
"cell_type": "markdown",
"id": "c5570cd9-dbaf-45cd-ab03-1a7f92e7d0d4",
"metadata": {},
"source": [
"## 调整 prompt 模版\n",
"有时候,我们希望定义自己的prompt, 选表和问表两个环节都支持 prompt 模版的定制化,但是必须遵循对应的 prompt 模版的格式。"
]
},
{
"cell_type": "markdown",
"id": "6e3d4967-2b4c-437d-9d72-fb1b94bdcf59",
"metadata": {},
"source": [
"### 选表 prompt 调整\n",
"选表的 prompt template, 必须包含 \n",
"1. {num} - 表的数量, 注意 {num} 有两个地方出现\n",
"2. {table_desc} - 表的描述\n",
"3. {query} - query\n",
"\n",
"注意: {num}, {table_desc}, {query} 表示的是占位符,**用户不需要在自定义的 prompt template 中将这些值填充上**,gbi 系统会自动根据 SelectTable 构造函数中提交的参数进行填充这些占位符,从而产生最后的给大模型的 prompt。注意 prompt template 和 prompt 的区别。\n",
"\n",
"* prompt template - 是带有占位符的 prompt, gbi 会根据具体参数填充到这些占位符,形成最终的 prompt\n",
"* promt - 是将 prompt template 填充完占位符的结果。\n",
"\n",
"用户可以使用这些占位符重新设置自己的 prompt 模版,从而达到修改 prompt 的目的。\n",
"具体请参考下面的 prompt template 示例:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2ae6ffbc-4237-4fb2-8168-480b81bfd873",
"metadata": {},
"outputs": [],
"source": [
"SELECT_TABLE_PROMPT_TEMPLATE = \"\"\"\n",
"你是一个专业的业务人员,下面有{num}张表,具体表名如下:\n",
"{table_desc}\n",
"请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表,\n",
"返回多张表请用“,”隔开\n",
"返回格式请参考如下示例:\n",
"问题:有多少个审核通过的投运单?\n",
"回答: ```DWD_MAT_OPERATION```\n",
"请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出\n",
"问题:{query}\n",
"回答:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2bbbb375-6659-4ef0-82ff-a4ace9fdd4f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"选的表是: ['supper_market_info']\n"
]
}
],
"source": [
"select_table4 = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", \n",
" table_descriptions=table_descriptions,\n",
" prompt_template=SELECT_TABLE_PROMPT_TEMPLATE)\n",
"\n",
"select_table_result_message4 = select_table4(Message({\"query\":\"列出超市中的所有数据\"}))\n",
"print(f\"选的表是: {select_table_result_message4.content}\")"
]
},
{
"cell_type": "markdown",
"id": "4f3fd089-613b-4bdd-95ac-c87f89c0fc61",
"metadata": {},
"source": [
"## 问表 prompt 调整\n",
"问表的 prompt template 必须包含:\n",
"1. {schema} - 表的 schema 信息, gbi系统使用构建 NL2Sql 对象的 table_schemas 成员来填充该占位符,在构造函数中需要填充该参数。\n",
"2. {instrument} - 列选限制的信息, gbi系统会使用 `NL2Sql.__call__` 函数的 Message 中的 column_constraint 参数来填充该占位符\n",
"3. {kg} - 知识, gbi系统使用构建 NL2Sql 对象的 knowledge 成员来填充该占位符,在构造函数中需要填充该参数。\n",
"4. {date} - 时间, gbi系统会自动填充该占位符, 用户不需要提供\n",
"5. {history_prompt} - 历史, gbi系统会使用 `NL2Sql.__call__` 函数的 Message 中的 session 参数来填充该占位符\n",
"6. {query} - 当前问题, gbi系统会使用 `NL2Sql.__call__` 函数的 Message 中的 query 参数来填充该占位符\n",
"\n",
"\n",
"注意: {schema}, {instrument}, {kg}, {date}, {history_prompt}, {query} 表示的是占位符,**用户不需要在自定义的 prompt template 中将这些值填充上**,gbi 系统会自动根据 `NL2Sql.__call__` 函数中提交的参数 或者 Nl2sql 的成员变量 进行填充这些占位符,从而产生最后的给大模型的 prompt。注意 prompt template 和 prompt 的区别。\n",
"\n",
"* prompt template - 是带有占位符的 prompt, gbi 会根据具体参数填充到这些占位符,形成最终的 prompt\n",
"* promt - 是将 prompt template 填充完占位符的结果。\n",
"\n",
"用户可以使用这些占位符重新设置自己的 prompt 模版,从而达到修改 prompt 的目的。\n",
"具体请参考下面的 prompt template 示例:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "323fbe75-62ca-44ab-9ca2-9f747939a2b5",
"metadata": {},
"outputs": [],
"source": [
"NL2SQL_PROMPT_TEMPLATE = \"\"\"\n",
" MySql 表 Schema 如下:\n",
" {schema}\n",
" 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。\n",
" 请参考列选信息:\n",
" {instrument}\n",
" 请参考知识:\n",
" {kg}\n",
" 当前时间:{date}\n",
" 历史信息如下:\n",
" {history_prompt}\n",
" 当前问题:\"{query}\"\n",
" 回答:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "52436f03-e01c-456a-aaa0-5a7f1afcd9d2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sql: SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
"-----------------\n",
"llm result: ```sql\n",
"SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
"```\n"
]
}
],
"source": [
"\n",
"gbi_nl2sql5 = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n",
"nl2sql_result_message5 = gbi_nl2sql5(Message({\"query\": \"查看商品类别是水果的所有数据\"}))\n",
"print(f\"sql: {nl2sql_result_message5.content.sql}\")\n",
"print(\"-----------------\")\n",
"print(f\"llm result: {nl2sql_result_message5.content.llm_result}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}