{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multitask Models\n",
    "\n",
    "So far, our AutoTransformer only had one task to solve, for example classifying a document. In AutoTransformers, we can however **combine multiple tasks in a single model**, saving space and computation time. This example will show how to train a model that performs both single-label classification and information extraction on a document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "from autotransformers import AutoTransformer, DatasetLoader\n",
    "from autotransformers.utils.bbox_utils import draw_bboxes"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the same document data as in the DIE example, but extend it by a second document classification task. We thus have 2 separate labels per sample, one for the document classification and one for the information extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "funsd = load_dataset(\"nielsr/funsd\", split=\"train\")\n",
    "filepaths = [sample[\"image_path\"] for sample in funsd][:6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In this dataset, we only give a single bounding box per document in order to\n",
    "# keep the example small. In practice, a document usually has 100s of bounding boxes.\n",
    "dataset = {\n",
    "    \"meta\": {\n",
    "        \"name\": \"example_multitask\",\n",
    "        \"version\": \"1.0.0\",\n",
    "        \"created_with\": \"wizard\",\n",
    "    },\n",
    "    \"config\": [\n",
    "        {\n",
    "            \"ocr\": \"google\",\n",
    "            \"domain\": \"document\",\n",
    "            \"type\": \"IDocument\",\n",
    "        },\n",
    "        [\n",
    "            {\n",
    "                \"task_id\": \"task1\",\n",
    "                \"classes\": [\"O\", \"HEADER\", \"QUESTION\", \"ANSWER\"],\n",
    "                \"none_class_id\": 0,\n",
    "                \"special_token_id\": -100,\n",
    "                \"type\": \"TInformationExtraction\",\n",
    "            },\n",
    "            #############################\n",
    "            ## New classification task ##\n",
    "            {\n",
    "                \"task_id\": \"task2\",\n",
    "                \"classes\": [\"suggestion\", \"specification\", \"report\"],\n",
    "                \"type\": \"TSingleClassification\",\n",
    "            },\n",
    "            #############################\n",
    "        ],\n",
    "    ],\n",
    "    \"train\": [\n",
    "        [\n",
    "            {\n",
    "                \"image\": filepaths[0],\n",
    "                \"bboxes\": [\n",
    "                    {\n",
    "                        \"TLx\": 0.287,\n",
    "                        \"TLy\": 0.316,\n",
    "                        \"TRx\": 0.295,\n",
    "                        \"TRy\": 0.316,\n",
    "                        \"BRx\": 0.295,\n",
    "                        \"BRy\": 0.327,\n",
    "                        \"BLx\": 0.287,\n",
    "                        \"BLy\": 0.327,\n",
    "                        \"in_pixels\": False,\n",
    "                        \"text\": \":\",\n",
    "                        \"label\": None,\n",
    "                    },\n",
    "                ],\n",
    "            },\n",
    "            [\n",
    "                {\"value\": [\"QUESTION\"]},  # Target value for the DIE task\n",
    "                {\"value\": \"suggestion\"},  # Target value for the classification task\n",
    "            ],\n",
    "        ],\n",
    "        [\n",
    "            {\n",
    "                \"image\": filepaths[1],\n",
    "                \"bboxes\": [\n",
    "                    {\n",
    "                        \"TLx\": 0.099,\n",
    "                        \"TLy\": 0.129,\n",
    "                        \"TRx\": 0.154,\n",
    "                        \"TRy\": 0.129,\n",
    "                        \"BRx\": 0.154,\n",
    "                        \"BRy\": 0.139,\n",
    "                        \"BLx\": 0.099,\n",
    "                        \"BLy\": 0.139,\n",
    "                        \"in_pixels\": False,\n",
    "                        \"text\": \"Brand:\",\n",
    "                        \"label\": None,\n",
    "                    },\n",
    "                ],\n",
    "            },\n",
    "            [\n",
    "                {\"value\": [\"QUESTION\"]},\n",
    "                {\"value\": \"specification\"},\n",
    "            ],\n",
    "        ],\n",
    "        [\n",
    "            {\n",
    "                \"image\": filepaths[2],\n",
    "                \"bboxes\": [\n",
    "                    {\n",
    "                        \"TLx\": 0.423,\n",
    "                        \"TLy\": 0.497,\n",
    "                        \"TRx\": 0.501,\n",
    "                        \"TRy\": 0.497,\n",
    "                        \"BRx\": 0.501,\n",
    "                        \"BRy\": 0.521,\n",
    "                        \"BLx\": 0.423,\n",
    "                        \"BLy\": 0.521,\n",
    "                        \"in_pixels\": False,\n",
    "                        \"text\": \"29Mar\",\n",
    "                        \"label\": None,\n",
    "                    },\n",
    "                ],\n",
    "            },\n",
    "            [\n",
    "                {\"value\": [\"ANSWER\"]},\n",
    "                {\"value\": \"report\"}\n",
    "            ],\n",
    "        ],\n",
    "        [\n",
    "            {\n",
    "                \"image\": filepaths[3],\n",
    "                \"bboxes\": [\n",
    "                    {\n",
    "                        \"TLx\": 0.078,\n",
    "                        \"TLy\": 0.121,\n",
    "                        \"TRx\": 0.166,\n",
    "                        \"TRy\": 0.121,\n",
    "                        \"BRx\": 0.166,\n",
    "                        \"BRy\": 0.135,\n",
    "                        \"BLx\": 0.078,\n",
    "                        \"BLy\": 0.135,\n",
    "                        \"in_pixels\": False,\n",
    "                        \"text\": \"SUBJECT:\",\n",
    "                        \"label\": None,\n",
    "                    },\n",
    "                ],\n",
    "            },\n",
    "            [\n",
    "                {\"value\": [\"QUESTION\"]},\n",
    "                {\"value\": \"report\"}\n",
    "            ],\n",
    "        ],\n",
    "    ],\n",
    "    \"test\": [\n",
    "        [\n",
    "            {\n",
    "                \"image\": filepaths[4],\n",
    "                \"bboxes\": [\n",
    "                    {\n",
    "                        \"TLx\": 0.779,\n",
    "                        \"TLy\": 0.084,\n",
    "                        \"TRx\": 0.84,\n",
    "                        \"TRy\": 0.084,\n",
    "                        \"BRx\": 0.84,\n",
    "                        \"BRy\": 0.095,\n",
    "                        \"BLx\": 0.779,\n",
    "                        \"BLy\": 0.095,\n",
    "                        \"in_pixels\": False,\n",
    "                        \"text\": \"Revision\",\n",
    "                        \"label\": None,\n",
    "                    },\n",
    "                ],\n",
    "            },\n",
    "            [\n",
    "                {\"value\": [\"QUESTION\"]},\n",
    "                {\"value\": \"specification\"}\n",
    "            ],\n",
    "        ]\n",
    "    ],\n",
    "}"
   ]
  },
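  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before loading, we can sanity-check the structure with plain Python (no SDK calls needed): each sample is a `(document, targets)` pair, and the number of targets must match the number of tasks declared in `config`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that every sample provides one target per configured task.\n",
    "n_tasks = len(dataset[\"config\"][1])\n",
    "for split in (\"train\", \"test\"):\n",
    "    for _document, targets in dataset[split]:\n",
    "        assert len(targets) == n_tasks, f\"{split}: expected {n_tasks} targets per sample\"\n",
    "print(f\"{n_tasks} tasks, {len(dataset['train'])} train / {len(dataset['test'])} test samples\")"
   ]
  },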
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DatasetLoader` can be used as usual; it automatically detects all tasks from the dataset given."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = DatasetLoader(dataset)\n",
    "\n",
    "# Or create a DatasetLoader from a file\n",
    "# dl = DatasetLoader(\"path/to/my-dataset.json\")"
   ]
  },
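  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To double-check which tasks the loader picked up, generic Python introspection works without assuming any particular SDK attribute (the exact accessor may differ; consult the autotransformers documentation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generic introspection only: we do not assume a specific attribute name,\n",
    "# we just list anything task-related that the loader exposes.\n",
    "print(type(dl).__name__)\n",
    "print([name for name in dir(dl) if \"task\" in name.lower()])"
   ]
  },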
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In this example, we only train for one epoch to finish fast. \n",
    "# In reality, you want to set this to a higher value for better results.\n",
    "config = [\n",
    "    (\"engine/stop_condition/type\", \"MaxEpochs\"),\n",
    "    (\"engine/stop_condition/value\", 1),\n",
    "]\n",
    "at = AutoTransformer(config)\n",
    "\n",
    "at.init(dataset_loader=dl, model_name_or_path=\"DocumentModelV4\", path=\".models/example05\")\n",
    "at.train(dl)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We try out the AutoTransformer by passing it a path to an image. Note that the prediction is now a 2-tuple, since the AutoTransformer has 2 tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a random document path for testing\n",
    "test_document = filepaths[5]\n",
    "document, (prediction_die, prediction_class) = at(test_document)"
   ]
  },
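  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before drawing, we can inspect the raw predictions. A minimal sketch: `prediction_die.value` aligns with `document.bboxes` (the `draw_bboxes` call below relies on exactly this), and we *assume* each bbox exposes a `text` attribute mirroring the dataset schema above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each bounding box with its predicted DIE label.\n",
    "# NOTE: `text` as a bbox attribute is an assumption based on the dataset schema;\n",
    "# we fall back to printing the bbox object itself if it is named differently.\n",
    "for bbox, label in zip(document.bboxes, prediction_die.value):\n",
    "    print(getattr(bbox, \"text\", bbox), \"->\", label)"
   ]
  },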
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "print(\"Class: \",prediction_class.value)\n",
    "with_bboxes = draw_bboxes(document.image, bboxes=document.bboxes, texts=prediction_die.value)\n",
    "display(with_bboxes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "83ee97f1e4ad98a710574577955c6720418d3d8f987616cd4f238f891737d017"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}