{
"cells": [
{
"cell_type": "markdown",
"id": "9c9cd486",
"metadata": {},
"source": [
"# Wrangle MS-MARCO for Relevancy Classifications\n",
"\n",
"The purpose of this notebook is both to transform the MS-MARCO dataset into a relevancy classification format and to document the quirks of the original data format. MS-MARCO is a benchmark dataset used to evaluate information retrieval systems."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d09abc4a-65a6-4c5e-94ab-6538f0c16a51",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"import pandas as pd\n",
"from datasets import load_dataset\n",
"\n",
"pd.set_option(\"display.max_colwidth\", 1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1c33889-c950-41fc-b2a5-fa8b4a305ff2",
"metadata": {},
"outputs": [],
"source": [
"split = \"train\"\n",
"version = \"v1.1\"\n",
"dataset = load_dataset(\"ms_marco\", version, split=split)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80688937",
"metadata": {},
"outputs": [],
"source": [
"raw_df = dataset.to_pandas()\n",
"raw_df.head()"
]
},
{
"cell_type": "markdown",
"id": "1ffbf239",
"metadata": {},
"source": [
"The \"wellFormedAnswers\" column contains only empty lists. This column is omitted from the final dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62ce9e45",
"metadata": {},
"outputs": [],
"source": [
"raw_df[\"wellFormedAnswers\"].apply(len).value_counts()"
]
},
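{
"cell_type": "markdown",
"id": "86cbba10",
"metadata": {},
"source": [
"Explode each row of the original dataset so that the new dataset contains one row per query-document pair. Each entry under \"passages\" becomes one document, and its \"is_selected\" flag becomes the boolean \"relevant\" label."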
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbe246e5",
"metadata": {},
"outputs": [],
"source": [
"query_texts = []\n",
"query_ids = []\n",
"query_types = []\n",
"reference_responses = []\n",
"selections = []\n",
"document_texts = []\n",
"document_urls = []\n",
"# Flatten each query's passages into parallel lists, one entry per query-document pair.\n",
"for data in dataset:\n",
"    document_data = data[\"passages\"]\n",
"    selections_for_query = list(map(bool, document_data[\"is_selected\"]))\n",
"    document_texts_for_query = document_data[\"passage_text\"]\n",
"    document_urls_for_query = document_data[\"url\"]\n",
"    # The per-passage fields must align one-to-one.\n",
"    assert (\n",
"        len(selections_for_query) == len(document_texts_for_query) == len(document_urls_for_query)\n",
"    )\n",
"    num_documents_for_query = len(selections_for_query)\n",
"    selections.extend(selections_for_query)\n",
"    document_texts.extend(document_texts_for_query)\n",
"    document_urls.extend(document_urls_for_query)\n",
"    # Repeat the query-level fields once per document.\n",
"    query_ids.extend([data[\"query_id\"]] * num_documents_for_query)\n",
"    query_texts.extend([data[\"query\"]] * num_documents_for_query)\n",
"    query_types.extend([data[\"query_type\"]] * num_documents_for_query)\n",
"    reference_responses.extend([data[\"answers\"]] * num_documents_for_query)\n",
"df = pd.DataFrame(\n",
"    {\n",
"        \"query_id\": query_ids,\n",
"        \"query_text\": query_texts,\n",
"        \"query_type\": query_types,\n",
"        \"relevant\": selections,\n",
"        \"document_text\": document_texts,\n",
"        \"document_url\": document_urls,\n",
"        \"reference_responses\": reference_responses,\n",
"    }\n",
")\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "be648e50",
"metadata": {},
"source": [
"Compare the column names of the original dataset with the columns of the wrangled dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8e62c1",
"metadata": {},
"outputs": [],
"source": [
"set(raw_df.columns).difference(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "314b3411",
"metadata": {},
"outputs": [],
"source": [
"set(df.columns).difference(raw_df.columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc73dc6e",
"metadata": {},
"outputs": [],
"source": [
"binary_relevance_classification_df = df[\n",
"    [\"query_id\", \"query_text\", \"document_text\", \"document_url\", \"relevant\"]\n",
"]\n",
"binary_relevance_classification_df.head()"
]
},
{
"cell_type": "markdown",
"id": "068c1a59",
"metadata": {},
"source": [
"Write the data to a JSONL file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a791a93",
"metadata": {},
"outputs": [],
"source": [
"data_path = f\"ms_marco-{version}-{split}.jsonl\"\n",
"with open(data_path, \"w\") as f:\n",
"    for record in binary_relevance_classification_df.to_dict(orient=\"records\"):\n",
"        f.write(json.dumps(record) + \"\\n\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}