{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Multitask Models\n", "\n", "So far, each AutoTransformer had only one task to solve, for example classifying a document. However, AutoTransformers can also **combine multiple tasks in a single model**, saving space and computation time. This example shows how to train a model that performs both single-label classification and information extraction on a document." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "from autotransformers import AutoTransformer, DatasetLoader\n", "from autotransformers.utils.bbox_utils import draw_bboxes" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We will use the same document data as in the document information extraction (DIE) example, but extend it with a second task: document classification. Each sample thus has 2 separate labels, one for the information extraction and one for the document classification." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "funsd = load_dataset(\"nielsr/funsd\", split=\"train\")\n", "filepaths = [sample[\"image_path\"] for sample in funsd][:6]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# In this dataset, we only give a single bounding box per document in order to\n", "# keep the example small. 
In practice, a document usually has 100s of bounding boxes.\n", "dataset = {\n", " \"meta\": {\n", " \"name\": \"example_multitask\",\n", " \"version\": \"1.0.0\",\n", " \"created_with\": \"wizard\",\n", " },\n", " \"config\": [\n", " {\n", " \"ocr\": \"google\",\n", " \"domain\": \"document\",\n", " \"type\": \"IDocument\",\n", " },\n", " [\n", " {\n", " \"task_id\": \"task1\",\n", " \"classes\": [\"O\", \"HEADER\", \"QUESTION\", \"ANSWER\"],\n", " \"none_class_id\": 0,\n", " \"special_token_id\": -100,\n", " \"type\": \"TInformationExtraction\",\n", " },\n", " #############################\n", " ## New classification task ##\n", " {\n", " \"task_id\": \"task2\",\n", " \"classes\": [\"suggestion\", \"specification\", \"report\"],\n", " \"type\": \"TSingleClassification\",\n", " },\n", " #############################\n", " ],\n", " ],\n", " \"train\": [\n", " [\n", " {\n", " \"image\": filepaths[0],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.287,\n", " \"TLy\": 0.316,\n", " \"TRx\": 0.295,\n", " \"TRy\": 0.316,\n", " \"BRx\": 0.295,\n", " \"BRy\": 0.327,\n", " \"BLx\": 0.287,\n", " \"BLy\": 0.327,\n", " \"in_pixels\": False,\n", " \"text\": \":\",\n", " \"label\": None,\n", " },\n", " ],\n", " },\n", " [\n", " {\"value\": [\"QUESTION\"]}, # Target value for the DIE task\n", " {\"value\": \"suggestion\"}, # Target value for the classification task\n", " ],\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[1],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.099,\n", " \"TLy\": 0.129,\n", " \"TRx\": 0.154,\n", " \"TRy\": 0.129,\n", " \"BRx\": 0.154,\n", " \"BRy\": 0.139,\n", " \"BLx\": 0.099,\n", " \"BLy\": 0.139,\n", " \"in_pixels\": False,\n", " \"text\": \"Brand:\",\n", " \"label\": None,\n", " },\n", " ],\n", " },\n", " [\n", " {\"value\": [\"QUESTION\"]},\n", " {\"value\": \"specification\"},\n", " ],\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[2],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.423,\n", " \"TLy\": 0.497,\n", " \"TRx\": 0.501,\n", " 
\"TRy\": 0.497,\n", " \"BRx\": 0.501,\n", " \"BRy\": 0.521,\n", " \"BLx\": 0.423,\n", " \"BLy\": 0.521,\n", " \"in_pixels\": False,\n", " \"text\": \"29Mar\",\n", " \"label\": None,\n", " },\n", " ],\n", " },\n", " [\n", " {\"value\": [\"ANSWER\"]},\n", " {\"value\": \"report\"}\n", " ],\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[3],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.078,\n", " \"TLy\": 0.121,\n", " \"TRx\": 0.166,\n", " \"TRy\": 0.121,\n", " \"BRx\": 0.166,\n", " \"BRy\": 0.135,\n", " \"BLx\": 0.078,\n", " \"BLy\": 0.135,\n", " \"in_pixels\": False,\n", " \"text\": \"SUBJECT:\",\n", " \"label\": None,\n", " },\n", " ],\n", " },\n", " [\n", " {\"value\": [\"QUESTION\"]},\n", " {\"value\": \"report\"}\n", " ],\n", " ],\n", " ],\n", " \"test\": [\n", " [\n", " {\n", " \"image\": filepaths[4],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.779,\n", " \"TLy\": 0.084,\n", " \"TRx\": 0.84,\n", " \"TRy\": 0.084,\n", " \"BRx\": 0.84,\n", " \"BRy\": 0.095,\n", " \"BLx\": 0.779,\n", " \"BLy\": 0.095,\n", " \"in_pixels\": False,\n", " \"text\": \"Revision\",\n", " \"label\": None,\n", " },\n", " ],\n", " },\n", " [\n", " {\"value\": [\"QUESTION\"]},\n", " {\"value\": \"specification\"}\n", " ],\n", " ]\n", " ],\n", "}" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The `DatasetLoader` can be used as usual; it automatically detects all tasks from the dataset given." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = DatasetLoader(dataset)\n", "\n", "# Or create a DatasetLoader from a file\n", "# dl = DatasetLoader(\"path/to/my-dataset.json\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# In this example, we only train for one epoch to finish fast. 
\n", "# In practice, set this to a higher value for better results.\n", "config = [\n", " (\"engine/stop_condition/type\", \"MaxEpochs\"),\n", " (\"engine/stop_condition/value\", 1),\n", "]\n", "at = AutoTransformer(config)\n", "\n", "at.init(dataset_loader=dl, model_name_or_path=\"DocumentModelV4\", path=\".models/example05\")\n", "at.train(dl)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We try out the trained AutoTransformer by passing it the path to an image. Note that the prediction is now a 2-tuple, since the AutoTransformer has 2 tasks: the first element is the information extraction result and the second is the classification result." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Use the sixth document, which is part of neither the train nor the test split above\n", "test_document = filepaths[5]\n", "document, (prediction_die, prediction_class) = at(test_document)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display\n", "\n", "print(\"Class:\", prediction_class.value)\n", "with_bboxes = draw_bboxes(document.image, bboxes=document.bboxes, texts=prediction_die.value)\n", "display(with_bboxes)" ] } ], "metadata": { "kernelspec": { "display_name": "env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "83ee97f1e4ad98a710574577955c6720418d3d8f987616cd4f238f891737d017" } } }, "nbformat": 4, "nbformat_minor": 2 }