{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wrangle WikiQA"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"import datasets\n",
"import pandas as pd\n",
"\n",
"pd.set_option(\"display.max_colwidth\", 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_dataset(\"wiki_qa\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"split = \"train\"\n",
"raw_df = dataset[split].to_pandas()\n",
"raw_df"
]
},
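{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each WikiQA row pairs a question with one candidate answer sentence, so a single question spans several rows. A quick look at the group sizes (an illustrative check, not part of the original pipeline):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# How many candidate sentences does each question have?\n",
"raw_df.groupby(\"question_id\").size().describe()"
]
},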
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rows = []\n",
"for question_id, group in raw_df.groupby([\"question_id\"], as_index=False):\n",
" questions = group[\"question\"].unique().tolist()\n",
" assert len(questions) == 1\n",
" question = questions[0]\n",
" document_titles = group[\"document_title\"].unique().tolist()\n",
" assert len(document_titles) == 1\n",
" document_title = document_titles[0]\n",
" document_text = \" \".join(group[\"answer\"].tolist())\n",
" texts_with_emphasis = []\n",
" for relevance_label, text in zip(group[\"label\"].to_list(), group[\"answer\"].to_list()):\n",
" is_relevant = relevance_label == 1\n",
" if is_relevant:\n",
" texts_with_emphasis.append(text.upper())\n",
" else:\n",
" texts_with_emphasis.append(text)\n",
" document_text_with_emphasis = \" \".join(texts_with_emphasis)\n",
" is_relevant = 1 in group[\"label\"].to_list()\n",
" rows.append(\n",
" {\n",
" \"query_id\": question_id,\n",
" \"query_text\": question,\n",
" \"document_title\": document_title,\n",
" \"document_text\": document_text,\n",
" \"document_text_with_emphasis\": document_text_with_emphasis,\n",
" \"relevant\": is_relevant,\n",
" }\n",
" )\n",
"df = pd.DataFrame(rows)\n",
"df"
]
},
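{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity-check the wrangled frame (an illustrative check, not part of the original pipeline): there should be exactly one row per question, and `relevant` should show a plausible class balance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One row per question after grouping, plus the share of relevant documents.\n",
"assert df[\"query_id\"].is_unique\n",
"df[\"relevant\"].value_counts(normalize=True)"
]
},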
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_path = f\"wiki_qa-{split}.jsonl\"\n",
"with open(data_path, \"w\") as f:\n",
" for record in df.to_dict(orient=\"records\"):\n",
" f.write(json.dumps(record) + \"\\n\")"
]
}
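,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read the file back as a round-trip check (a minimal sketch; `pd.read_json` with `lines=True` is one convenient reader for JSON Lines)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip check: reload the JSONL we just wrote (illustrative only).\n",
"reloaded_df = pd.read_json(data_path, lines=True)\n",
"assert len(reloaded_df) == len(df)\n",
"reloaded_df.head()"
]
}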
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}