Multitask Models

So far, each AutoTransformer we trained had only one task to solve, for example classifying a document. AutoTransformers can, however, combine multiple tasks in a single model, which saves space and computation time. This example shows how to train a model that performs both single-label classification and information extraction on a document.
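A common way to get these savings is to run one shared encoder per document and put a small head for each task on top of it, an approach often called hard parameter sharing. The toy pure-Python sketch below illustrates that idea only. The names `SharedEncoder`, `extraction_head` and `classification_head` are hypothetical, and this example does not describe how AutoTransformers actually shares computation between tasks.

```python
# Generic sketch of hard parameter sharing: one shared encoder feeds two
# lightweight task heads. Illustration only, NOT the AutoTransformers internals.
from typing import List, Tuple


class SharedEncoder:
    """Toy encoder that turns each token into a 2-d feature vector."""

    def __init__(self) -> None:
        self.calls = 0  # counts how often the (expensive) encoder runs

    def __call__(self, tokens: List[str]) -> List[List[float]]:
        self.calls += 1
        # Features: token length and whether the token ends with a colon
        return [[float(len(t)), float(t.endswith(":"))] for t in tokens]


def extraction_head(features: List[List[float]]) -> List[str]:
    """Token-level head: one label per token (information extraction)."""
    return ["QUESTION" if f[1] == 1.0 else "O" for f in features]


def classification_head(features: List[List[float]]) -> str:
    """Document-level head: pool over all tokens, then pick one class."""
    mean_length = sum(f[0] for f in features) / len(features)
    return "report" if mean_length > 5 else "suggestion"


def multitask_predict(encoder: SharedEncoder, tokens: List[str]) -> Tuple[List[str], str]:
    features = encoder(tokens)  # the shared part runs only once ...
    # ... and both heads reuse its output
    return extraction_head(features), classification_head(features)


encoder = SharedEncoder()
labels, doc_class = multitask_predict(encoder, ["SUBJECT:", "Revision", "29Mar"])
print(labels, doc_class, encoder.calls)  # ['QUESTION', 'O', 'O'] report 1
```

Two separate single-task models would each need their own encoder pass. Here `encoder.calls` stays at 1 even though two predictions are produced.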

[ ]:
from datasets import load_dataset

from autotransformers import AutoTransformer, DatasetLoader
from autotransformers.utils.bbox_utils import draw_bboxes

We use the same document data as in the DIE (document information extraction) example, but extend it with a second task: document classification. Each sample therefore has 2 separate labels, one for the information extraction and one for the document classification.
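To make the label layout concrete: each sample is a pair `[document, targets]`, where `targets` holds one `{"value": ...}` entry per task, in the order the tasks are listed in the config. The hypothetical helper `split_by_task` below is not part of AutoTransformers. It recovers one single-task view per task from such samples, which shows that a multitask sample simply bundles the labels of two single-task datasets.

```python
# Hypothetical helper (not part of AutoTransformers): splits multitask samples
# into one single-task view per task. Each sample is [document, targets], where
# targets holds one {"value": ...} entry per task, in task order.
from typing import Dict, List


def split_by_task(samples: List[list], task_ids: List[str]) -> Dict[str, List[tuple]]:
    views: Dict[str, List[tuple]] = {task_id: [] for task_id in task_ids}
    for document, targets in samples:
        if len(targets) != len(task_ids):
            raise ValueError(f"expected {len(task_ids)} targets, got {len(targets)}")
        # Pair the i-th target with the i-th task
        for task_id, target in zip(task_ids, targets):
            views[task_id].append((document, target["value"]))
    return views


samples = [
    [{"image": "doc0.png"}, [{"value": ["QUESTION"]}, {"value": "suggestion"}]],
    [{"image": "doc1.png"}, [{"value": ["QUESTION"]}, {"value": "specification"}]],
]
views = split_by_task(samples, ["task1", "task2"])
print([value for _, value in views["task2"]])  # ['suggestion', 'specification']
```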

[ ]:
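funsd = load_dataset("nielsr/funsd", split="train")
# We only need the image paths of the first 6 documents:
# 4 for training, 1 for testing and 1 for trying out the trained model.
filepaths = funsd[:6]["image_path"]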
[ ]:
# In this dataset, each document has only a single bounding box to keep the
# example small. In practice, a document usually has hundreds of bounding boxes.
dataset = {
    "meta": {
        "name": "example_multitask",
        "version": "1.0.0",
        "created_with": "wizard",
    },
    "config": [
        {
            "ocr": "google",
            "domain": "document",
            "type": "IDocument",
        },
        [
            {
                "task_id": "task1",
                "classes": ["O", "HEADER", "QUESTION", "ANSWER"],
                "none_class_id": 0,
                "special_token_id": -100,
                "type": "TInformationExtraction",
            },
            #############################
            ## New classification task ##
            {
                "task_id": "task2",
                "classes": ["suggestion", "specification", "report"],
                "type": "TSingleClassification",
            },
            #############################
        ],
    ],
    "train": [
        [
            {
                "image": filepaths[0],
                "bboxes": [
                    {
                        "TLx": 0.287,
                        "TLy": 0.316,
                        "TRx": 0.295,
                        "TRy": 0.316,
                        "BRx": 0.295,
                        "BRy": 0.327,
                        "BLx": 0.287,
                        "BLy": 0.327,
                        "in_pixels": False,
                        "text": ":",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},  # Target value for the DIE task
                {"value": "suggestion"},  # Target value for the classification task
            ],
        ],
        [
            {
                "image": filepaths[1],
                "bboxes": [
                    {
                        "TLx": 0.099,
                        "TLy": 0.129,
                        "TRx": 0.154,
                        "TRy": 0.129,
                        "BRx": 0.154,
                        "BRy": 0.139,
                        "BLx": 0.099,
                        "BLy": 0.139,
                        "in_pixels": False,
                        "text": "Brand:",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "specification"},
            ],
        ],
        [
            {
                "image": filepaths[2],
                "bboxes": [
                    {
                        "TLx": 0.423,
                        "TLy": 0.497,
                        "TRx": 0.501,
                        "TRy": 0.497,
                        "BRx": 0.501,
                        "BRy": 0.521,
                        "BLx": 0.423,
                        "BLy": 0.521,
                        "in_pixels": False,
                        "text": "29Mar",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["ANSWER"]},
                {"value": "report"}
            ],
        ],
        [
            {
                "image": filepaths[3],
                "bboxes": [
                    {
                        "TLx": 0.078,
                        "TLy": 0.121,
                        "TRx": 0.166,
                        "TRy": 0.121,
                        "BRx": 0.166,
                        "BRy": 0.135,
                        "BLx": 0.078,
                        "BLy": 0.135,
                        "in_pixels": False,
                        "text": "SUBJECT:",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "report"}
            ],
        ],
    ],
    "test": [
        [
            {
                "image": filepaths[4],
                "bboxes": [
                    {
                        "TLx": 0.779,
                        "TLy": 0.084,
                        "TRx": 0.84,
                        "TRy": 0.084,
                        "BRx": 0.84,
                        "BRy": 0.095,
                        "BLx": 0.779,
                        "BLy": 0.095,
                        "in_pixels": False,
                        "text": "Revision",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "specification"}
            ],
        ]
    ],
}
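
Each bounding box is given by its four corners (top-left, top-right, bottom-right, bottom-left). With `"in_pixels": False`, the coordinates appear to be fractions of the page size, since all values above lie between 0 and 1. Under that assumption, the hypothetical helper `bbox_to_pixels` below converts a box to absolute pixel coordinates, for example to crop or inspect a region. The page size used in the example is made up.

```python
# Hypothetical helper (not part of AutoTransformers). Assumption: with
# "in_pixels": False, every corner coordinate is a fraction of the page
# width (x) or height (y).
from typing import Any, Dict


def bbox_to_pixels(bbox: Dict[str, Any], width: int, height: int) -> Dict[str, Any]:
    if bbox["in_pixels"]:
        return dict(bbox)  # already absolute, nothing to convert
    converted = dict(bbox)  # copy, so "text", "label", ... stay unchanged
    for corner in ("TL", "TR", "BR", "BL"):
        converted[corner + "x"] = round(bbox[corner + "x"] * width)
        converted[corner + "y"] = round(bbox[corner + "y"] * height)
    converted["in_pixels"] = True
    return converted


brand = {
    "TLx": 0.099, "TLy": 0.129, "TRx": 0.154, "TRy": 0.129,
    "BRx": 0.154, "BRy": 0.139, "BLx": 0.099, "BLy": 0.139,
    "in_pixels": False, "text": "Brand:", "label": None,
}
pixel_box = bbox_to_pixels(brand, width=800, height=1000)  # made-up page size
print(pixel_box["TLx"], pixel_box["TLy"], pixel_box["BRx"], pixel_box["BRy"])
```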

The DatasetLoader can be used as usual; it automatically detects all tasks in the given dataset.
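When a multitask dataset is assembled by hand, it can be worth checking it before loading. The sketch below is a hypothetical pre-check, not the DatasetLoader's actual implementation. It reads the task list from `dataset["config"][1]` and verifies that every sample has one target per task and uses only known classes. It also assumes that an information extraction target has one label per bounding box, as in the samples above.

```python
# Hypothetical sanity check (not the DatasetLoader implementation).
from typing import Any, Dict, List


def detect_tasks(dataset: Dict[str, Any]) -> List[Dict[str, Any]]:
    # config[0] describes the input documents, config[1] lists the tasks
    return dataset["config"][1]


def validate(dataset: Dict[str, Any]) -> None:
    tasks = detect_tasks(dataset)
    for split in ("train", "test"):
        for index, (document, targets) in enumerate(dataset.get(split, [])):
            where = f"{split}[{index}]"
            if len(targets) != len(tasks):
                raise ValueError(f"{where}: expected {len(tasks)} targets, got {len(targets)}")
            for task, target in zip(tasks, targets):
                value = target["value"]
                if task["type"] == "TInformationExtraction":
                    # Assumption: one label per bounding box
                    if len(value) != len(document["bboxes"]):
                        raise ValueError(f"{where}: {task['task_id']} needs one label per bbox")
                    values = value
                else:
                    values = [value]
                unknown = set(values) - set(task["classes"])
                if unknown:
                    raise ValueError(f"{where}: {task['task_id']} has unknown classes {sorted(unknown)}")


tiny = {
    "config": [
        {"domain": "document", "type": "IDocument"},
        [
            {"task_id": "task1", "type": "TInformationExtraction",
             "classes": ["O", "HEADER", "QUESTION", "ANSWER"]},
            {"task_id": "task2", "type": "TSingleClassification",
             "classes": ["suggestion", "specification", "report"]},
        ],
    ],
    "train": [
        [{"bboxes": [{"text": "Brand:"}]}, [{"value": ["QUESTION"]}, {"value": "specification"}]],
    ],
}
validate(tiny)  # passes silently
```

Calling `validate(dataset)` on the dataset defined above also passes without errors.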

[ ]:
dl = DatasetLoader(dataset)

# Or create a DatasetLoader from a file
# dl = DatasetLoader("path/to/my-dataset.json")
[ ]:
# In this example, we only train for one epoch to finish fast.
# In reality, you want to set this to a higher value for better results.
config = [
    ("engine/stop_condition/type", "MaxEpochs"),
    ("engine/stop_condition/value", 1),
]
at = AutoTransformer(config)

at.init(dataset_loader=dl, model_name_or_path="DocumentModelV4", path=".models/example05")
at.train(dl)
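The training configuration is a list of `(path, value)` pairs whose slash-separated keys address nested settings. The hypothetical helper `to_nested` below only illustrates how such paths map onto a settings tree. This example does not cover how AutoTransformers parses them internally.

```python
# Hypothetical helper: reads slash-separated config keys as paths into a
# nested settings tree (illustration only, not AutoTransformers' parser).
from typing import Any, Dict, List, Tuple


def to_nested(config: List[Tuple[str, Any]]) -> Dict[str, Any]:
    tree: Dict[str, Any] = {}
    for path, value in config:
        *parents, leaf = path.split("/")
        node = tree
        for key in parents:
            node = node.setdefault(key, {})  # create intermediate levels on demand
        node[leaf] = value
    return tree


example_config = [
    ("engine/stop_condition/type", "MaxEpochs"),
    ("engine/stop_condition/value", 1),
]
print(to_nested(example_config))
# {'engine': {'stop_condition': {'type': 'MaxEpochs', 'value': 1}}}
```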

We try out the AutoTransformer by passing it the path to an image. Since the AutoTransformer now has 2 tasks, the prediction is a 2-tuple with one entry per task, in the same order as the tasks in the dataset config.

[ ]:
# Use a document that is in neither the train nor the test split
test_document = filepaths[5]
document, (prediction_die, prediction_class) = at(test_document)
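Positional unpacking gets error-prone once a model has more than two tasks. One alternative is to key each prediction by its `task_id`. This assumes, as the unpacking above suggests, that the predictions follow the task order of the dataset config. `Prediction` and `predictions_by_task` are hypothetical. The sketch relies only on the `.value` attribute used in this example.

```python
# Hypothetical sketch: `Prediction` is a stand-in for the objects returned by
# `at(...)`; only the `.value` attribute is assumed.
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence


@dataclass
class Prediction:
    value: Any


def predictions_by_task(task_ids: List[str], predictions: Sequence[Prediction]) -> Dict[str, Any]:
    # Assumes the predictions follow the task order of the dataset config
    if len(task_ids) != len(predictions):
        raise ValueError("need exactly one prediction per task")
    return {task_id: pred.value for task_id, pred in zip(task_ids, predictions)}


fake_output = (Prediction(["QUESTION"]), Prediction("specification"))
print(predictions_by_task(["task1", "task2"], fake_output))
# {'task1': ['QUESTION'], 'task2': 'specification'}
```

With the trained model, this would read `_, predictions = at(test_document)` followed by `predictions_by_task(["task1", "task2"], predictions)`.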
[ ]:
from IPython.display import display

print("Class:", prediction_class.value)
with_bboxes = draw_bboxes(document.image, bboxes=document.bboxes, texts=prediction_die.value)
display(with_bboxes)