Multitask Models
So far, our AutoTransformer has had only one task to solve, for example classifying a document. AutoTransformers can, however, combine multiple tasks in a single model, which saves space and computation time. This example shows how to train a model that performs both single-label classification and information extraction on a document.
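This page does not describe how AutoTransformers shares work between tasks internally. A common design, assumed here only for illustration, is hard parameter sharing. In that design a single encoder processes the document once, and each task adds only a small output head on top. The sketch below uses toy stand-ins to show why one forward pass can serve both tasks. The functions `shared_encoder`, `die_head` and `classification_head` are hypothetical.

```python
from typing import Callable, Dict, List

# Counts how often the (normally expensive) shared encoder runs.
encoder_calls = 0


def shared_encoder(tokens: List[str]) -> List[float]:
    """Hypothetical stand-in for a large document encoder: one feature per token."""
    global encoder_calls
    encoder_calls += 1
    return [float(len(token)) for token in tokens]


def die_head(features: List[float]) -> List[str]:
    """Hypothetical token-level head: one label per token."""
    return ["QUESTION" if f > 3 else "O" for f in features]


def classification_head(features: List[float]) -> str:
    """Hypothetical document-level head: one label per document."""
    return "report" if sum(features) > 10 else "suggestion"


def multitask_forward(tokens: List[str], heads: Dict[str, Callable]) -> Dict[str, object]:
    # Encode once, then let every task head reuse the same features.
    features = shared_encoder(tokens)
    return {task_id: head(features) for task_id, head in heads.items()}


heads = {"task1": die_head, "task2": classification_head}
outputs = multitask_forward(["SUBJECT:", "Brand:"], heads)
print(outputs)  # {'task1': ['QUESTION', 'QUESTION'], 'task2': 'report'}
print("encoder calls:", encoder_calls)  # encoder calls: 1
```

With two separate single-task models, the encoder would run twice per document and its weights would be stored twice. That duplication is what a combined model avoids under this assumed design.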
```python
from datasets import load_dataset
from autotransformers import AutoTransformer, DatasetLoader
from autotransformers.utils.bbox_utils import draw_bboxes
```
We use the same document data as in the DIE example, but extend it with a second task: document classification. Each sample therefore has two separate labels, one for information extraction and one for document classification.
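In the dataset below, each sample is a pair `[input, targets]`, where `targets` holds one entry per task. The following sketch assumes, as the example suggests, that targets are matched to tasks by position, in the order the tasks appear in the config. `pair_targets_with_tasks` is a hypothetical helper, not part of AutoTransformers. The image name `doc0.png` is a placeholder.

```python
from typing import Any, Dict, List

# Task definitions as they appear in the example config (abridged).
tasks = [
    {"task_id": "task1", "type": "TInformationExtraction"},
    {"task_id": "task2", "type": "TSingleClassification"},
]

# One sample: [input document, [target for task1, target for task2]].
sample = [
    {"image": "doc0.png", "bboxes": [{"text": ":"}]},  # placeholder image path
    [{"value": ["QUESTION"]}, {"value": "suggestion"}],
]


def pair_targets_with_tasks(tasks: List[Dict[str, Any]], sample: List[Any]) -> Dict[str, Any]:
    """Hypothetical helper: map each task_id to its target value by position."""
    _, targets = sample
    if len(targets) != len(tasks):
        raise ValueError(f"expected {len(tasks)} targets, got {len(targets)}")
    return {task["task_id"]: target["value"] for task, target in zip(tasks, targets)}


print(pair_targets_with_tasks(tasks, sample))
# {'task1': ['QUESTION'], 'task2': 'suggestion'}
```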
```python
funsd = load_dataset("nielsr/funsd", split="train")
filepaths = [sample["image_path"] for sample in funsd][:6]
```
```python
# In this dataset, we give only a single bounding box per document to keep the
# example small. In practice, a document usually has hundreds of bounding boxes.
dataset = {
    "meta": {
        "name": "example_multitask",
        "version": "1.0.0",
        "created_with": "wizard",
    },
    "config": [
        {
            "ocr": "google",
            "domain": "document",
            "type": "IDocument",
        },
        [
            {
                "task_id": "task1",
                "classes": ["O", "HEADER", "QUESTION", "ANSWER"],
                "none_class_id": 0,
                "special_token_id": -100,
                "type": "TInformationExtraction",
            },
            #############################
            ## New classification task ##
            {
                "task_id": "task2",
                "classes": ["suggestion", "specification", "report"],
                "type": "TSingleClassification",
            },
            #############################
        ],
    ],
    "train": [
        [
            {
                "image": filepaths[0],
                "bboxes": [
                    {
                        "TLx": 0.287,
                        "TLy": 0.316,
                        "TRx": 0.295,
                        "TRy": 0.316,
                        "BRx": 0.295,
                        "BRy": 0.327,
                        "BLx": 0.287,
                        "BLy": 0.327,
                        "in_pixels": False,
                        "text": ":",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},  # Target value for the DIE task
                {"value": "suggestion"},  # Target value for the classification task
            ],
        ],
        [
            {
                "image": filepaths[1],
                "bboxes": [
                    {
                        "TLx": 0.099,
                        "TLy": 0.129,
                        "TRx": 0.154,
                        "TRy": 0.129,
                        "BRx": 0.154,
                        "BRy": 0.139,
                        "BLx": 0.099,
                        "BLy": 0.139,
                        "in_pixels": False,
                        "text": "Brand:",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "specification"},
            ],
        ],
        [
            {
                "image": filepaths[2],
                "bboxes": [
                    {
                        "TLx": 0.423,
                        "TLy": 0.497,
                        "TRx": 0.501,
                        "TRy": 0.497,
                        "BRx": 0.501,
                        "BRy": 0.521,
                        "BLx": 0.423,
                        "BLy": 0.521,
                        "in_pixels": False,
                        "text": "29Mar",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["ANSWER"]},
                {"value": "report"},
            ],
        ],
        [
            {
                "image": filepaths[3],
                "bboxes": [
                    {
                        "TLx": 0.078,
                        "TLy": 0.121,
                        "TRx": 0.166,
                        "TRy": 0.121,
                        "BRx": 0.166,
                        "BRy": 0.135,
                        "BLx": 0.078,
                        "BLy": 0.135,
                        "in_pixels": False,
                        "text": "SUBJECT:",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "report"},
            ],
        ],
    ],
    "test": [
        [
            {
                "image": filepaths[4],
                "bboxes": [
                    {
                        "TLx": 0.779,
                        "TLy": 0.084,
                        "TRx": 0.84,
                        "TRy": 0.084,
                        "BRx": 0.84,
                        "BRy": 0.095,
                        "BLx": 0.779,
                        "BLy": 0.095,
                        "in_pixels": False,
                        "text": "Revision",
                        "label": None,
                    },
                ],
            },
            [
                {"value": ["QUESTION"]},
                {"value": "specification"},
            ],
        ],
    ],
}
```
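Each bounding box lists its four corners: top-left, top-right, bottom-right and bottom-left. With `"in_pixels": False`, the values appear to be fractions of the image width and height, so they lie between 0 and 1. That reading is an assumption based on the field name and the values above. The sketch below converts such a box to pixel coordinates for a hypothetical page that is 1000 pixels wide and 800 pixels high. It also runs a few sanity checks on a sample, assuming the DIE target holds one label per bounding box. `to_pixels` and `check_sample` are hypothetical helpers, and `doc.png` is a placeholder path.

```python
from typing import Dict, List, Tuple

CORNERS = ("TL", "TR", "BR", "BL")  # corner order used by the example bboxes


def to_pixels(bbox: Dict, width: int, height: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: return the four corners as (x, y) pixel coordinates."""
    if bbox["in_pixels"]:
        return [(round(bbox[c + "x"]), round(bbox[c + "y"])) for c in CORNERS]
    # Assumption: in_pixels=False means fractions of the image width/height.
    return [(round(bbox[c + "x"] * width), round(bbox[c + "y"] * height)) for c in CORNERS]


def check_sample(sample: List, classes: List[str]) -> List[str]:
    """Hypothetical sanity check for one [document, [die_target, class_target]] sample."""
    document, (die_target, class_target) = sample
    problems = []
    # Assumption: the DIE target holds one label per bounding box.
    if len(die_target["value"]) != len(document["bboxes"]):
        problems.append("DIE target length differs from the number of bboxes")
    for i, bbox in enumerate(document["bboxes"]):
        coords = [bbox[c + axis] for c in CORNERS for axis in "xy"]
        if not bbox["in_pixels"] and not all(0.0 <= v <= 1.0 for v in coords):
            problems.append(f"bbox {i}: normalized coordinate outside [0, 1]")
    if class_target["value"] not in classes:
        problems.append(f"unknown class {class_target['value']!r}")
    return problems


box = {"TLx": 0.078, "TLy": 0.121, "TRx": 0.166, "TRy": 0.121,
       "BRx": 0.166, "BRy": 0.135, "BLx": 0.078, "BLy": 0.135,
       "in_pixels": False, "text": "SUBJECT:", "label": None}
sample = [{"image": "doc.png", "bboxes": [box]}, [{"value": ["QUESTION"]}, {"value": "report"}]]

print(to_pixels(box, width=1000, height=800))  # [(78, 97), (166, 97), (166, 108), (78, 108)]
print(check_sample(sample, ["suggestion", "specification", "report"]))  # []
```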
The DatasetLoader can be used as usual; it automatically detects all tasks defined in the dataset.
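This example does not document how the DatasetLoader finds the tasks. Judging from the dataset layout, the task definitions are the list in the second entry of `config`, after the document settings. The sketch below reads them out under that assumption. `list_tasks` is hypothetical and is not the DatasetLoader's implementation.

```python
from typing import Any, List, Tuple

# Abridged copy of the example's "config" entry: [document settings, [task, task, ...]].
config = [
    {"ocr": "google", "domain": "document", "type": "IDocument"},
    [
        {"task_id": "task1", "classes": ["O", "HEADER", "QUESTION", "ANSWER"],
         "none_class_id": 0, "special_token_id": -100, "type": "TInformationExtraction"},
        {"task_id": "task2", "classes": ["suggestion", "specification", "report"],
         "type": "TSingleClassification"},
    ],
]


def list_tasks(config: List[Any]) -> List[Tuple[str, str, int]]:
    """Hypothetical helper: (task_id, task type, number of classes) for each task."""
    _input_settings, task_defs = config
    return [(t["task_id"], t["type"], len(t["classes"])) for t in task_defs]


for task_id, task_type, n_classes in list_tasks(config):
    print(f"{task_id}: {task_type} with {n_classes} classes")
# task1: TInformationExtraction with 4 classes
# task2: TSingleClassification with 3 classes
```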
```python
dl = DatasetLoader(dataset)
# Or create a DatasetLoader from a file
# dl = DatasetLoader("path/to/my-dataset.json")
```
```python
# In this example, we train for only one epoch so that it finishes quickly.
# In practice, set this to a higher value for better results.
config = [
    ("engine/stop_condition/type", "MaxEpochs"),
    ("engine/stop_condition/value", 1),
]
at = AutoTransformer(config)
at.init(dataset_loader=dl, model_name_or_path="DocumentModelV4", path=".models/example05")
at.train(dl)
```
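The training config is a list of `(path, value)` pairs in which `/` appears to separate nested keys. The sketch below expands such pairs into a nested dictionary to make that hierarchy visible. It is purely illustrative and is not how AutoTransformers parses its config. `nest_config` is a hypothetical helper.

```python
from typing import Any, Dict, List, Tuple


def nest_config(pairs: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: turn ("a/b/c", value) pairs into {"a": {"b": {"c": value}}}."""
    root: Dict[str, Any] = {}
    for path, value in pairs:
        *parents, leaf = path.split("/")
        node = root
        for key in parents:
            # Create intermediate levels on first use and reuse them afterwards.
            node = node.setdefault(key, {})
        node[leaf] = value
    return root


config = [
    ("engine/stop_condition/type", "MaxEpochs"),
    ("engine/stop_condition/value", 1),
]
print(nest_config(config))
# {'engine': {'stop_condition': {'type': 'MaxEpochs', 'value': 1}}}
```

Read this way, both entries configure the same stop condition: its type, `MaxEpochs`, and its value, the epoch count.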
We try out the AutoTransformer by passing it the path to an image. Note that the prediction is now a 2-tuple with one entry per task: the information extraction result followed by the classification result.
```python
# Use a document that is in neither the training nor the test split
test_document = filepaths[5]
document, (prediction_die, prediction_class) = at(test_document)
```
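Because the predictions come back as a tuple in task order, it can help to label them with their task ids. The sketch below zips a prediction tuple with the task ids from the config. The `Prediction` class is a hypothetical stand-in for the objects returned by the AutoTransformer; this example only reads their `.value` attribute. The prediction values are made up to match the shape of the example's outputs. They are not model results.

```python
from dataclasses import dataclass
from typing import Any, Dict, Sequence, Tuple


@dataclass
class Prediction:
    """Hypothetical stand-in: the example only reads a prediction's .value."""
    value: Any


def predictions_by_task(task_ids: Sequence[str], predictions: Tuple[Prediction, ...]) -> Dict[str, Any]:
    # Assumption: predictions are ordered like the tasks in the dataset config.
    if len(task_ids) != len(predictions):
        raise ValueError("one prediction per task expected")
    return {task_id: pred.value for task_id, pred in zip(task_ids, predictions)}


# Made-up values shaped like the example's outputs (DIE labels per bbox, one class).
fake_predictions = (Prediction(["QUESTION", "ANSWER"]), Prediction("report"))
print(predictions_by_task(["task1", "task2"], fake_predictions))
# {'task1': ['QUESTION', 'ANSWER'], 'task2': 'report'}
```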
```python
from IPython.display import display

print("Class:", prediction_class.value)
with_bboxes = draw_bboxes(document.image, bboxes=document.bboxes, texts=prediction_die.value)
display(with_bboxes)
```