{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Document Classification (DC)\n", "\n", "*Note: This notebook assumes that you're familar with the basic usage of AutoTransformers. Go to the \"Getting started\" notebook to learn the basics.*\n", "\n", "This example demonstrates how to classify pages of a document using AutoTransformers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install the Huggingface datasets package\n", "%pip install datasets" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "from autotransformers import AutoTransformer, DatasetLoader" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We will use images from the [funsd](https://huggingface.co/datasets/nielsr/funsd) dataset from Huggingface, which contains scans of various types of business documents. First, we download the dataset and get some image paths. Please note that we set some random classes for the pages just to demonstrate how DC with AT works." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 4.54k/4.54k [00:00<00:00, 10.6MB/s]\n", "Downloading data: 100%|██████████| 16.8M/16.8M [00:02<00:00, 6.51MB/s]\n", "Generating train split: 149 examples [00:00, 158.81 examples/s]\n", "Generating test split: 50 examples [00:00, 161.39 examples/s]\n" ] } ], "source": [ "funsd = load_dataset(\"nielsr/funsd\", split=\"train\")\n", "filepaths = [sample[\"image_path\"] for sample in funsd][:6]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As in the previous tutorials, we show a minimal dataset example below. For documents, each input consists of an image path and a list of bounding boxes around each word detected by an OCR scanner. The bounding boxes are given in the dict format used by `ocr_wrapper`. \n", "\n", "In document classification, each page gets a label. Thus, the labels in the dataset below are simply a string (compared to DIE where each BBox must be labeled)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# In this dataset, we only give a single bounding box per document in order to\n", "# keep the example small. 
In practice, a document usually has 100s of bounding boxes.\n", "dataset = {\n", " \"meta\": {\n", " \"name\": \"example_dc\",\n", " \"version\": \"1.0.0\",\n", " \"created_with\": \"wizard\",\n", " },\n", " \"config\": [\n", " {\n", " \"ocr\": \"google\",\n", " \"domain\": \"document\",\n", " \"type\": \"IDocument\",\n", " },\n", " {\n", " \"task_id\": \"dc_single\", \n", " \"classes\": [\"class_0\", \"class_1\"], \n", " \"type\": \"TSingleClassification\"\n", " },\n", " ],\n", " \"train\": [\n", " [\n", " {\n", " \"image\": filepaths[0],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.287,\n", " \"TLy\": 0.316,\n", " \"TRx\": 0.295,\n", " \"TRy\": 0.316,\n", " \"BRx\": 0.295,\n", " \"BRy\": 0.327,\n", " \"BLx\": 0.287,\n", " \"BLy\": 0.327,\n", " \"original_width\": 762,\n", " \"original_height\": 1000,\n", " },\n", " ],\n", " \"texts\": [\n", " \":\"\n", " ]\n", " },\n", " { \"value\": \"class_0\"},\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[1],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.099,\n", " \"TLy\": 0.129,\n", " \"TRx\": 0.154,\n", " \"TRy\": 0.129,\n", " \"BRx\": 0.154,\n", " \"BRy\": 0.139,\n", " \"BLx\": 0.099,\n", " \"BLy\": 0.139,\n", " \"original_width\": 762,\n", " \"original_height\": 1000,\n", " },\n", " ],\n", " \"texts\": [\n", " \"Brand:\"\n", " ]\n", " },\n", " {\"value\": \"class_1\"},\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[2],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.423,\n", " \"TLy\": 0.497,\n", " \"TRx\": 0.501,\n", " \"TRy\": 0.497,\n", " \"BRx\": 0.501,\n", " \"BRy\": 0.521,\n", " \"BLx\": 0.423,\n", " \"BLy\": 0.521,\n", " \"original_width\": 762,\n", " \"original_height\": 1000,\n", " },\n", " ],\n", " \"texts\": [\n", " \"29Mar\"\n", " ]\n", " },\n", " {\"value\": \"class_0\"},\n", " ],\n", " [\n", " {\n", " \"image\": filepaths[3],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.078,\n", " \"TLy\": 0.121,\n", " \"TRx\": 0.166,\n", " \"TRy\": 0.121,\n", " \"BRx\": 0.166,\n", " \"BRy\": 0.135,\n", " \"BLx\": 0.078,\n", " \"BLy\": 0.135,\n", " \"original_width\": 762,\n", " \"original_height\": 1000,\n", " },\n", " ],\n", " \"texts\": [\n", " \"SUBJECT:\"\n", " ]\n", " },\n", " {\"value\": \"class_1\"},\n", " ],\n", " ],\n", " \"test\": [\n", " [\n", " {\n", " \"image\": filepaths[4],\n", " \"bboxes\": [\n", " {\n", " \"TLx\": 0.779,\n", " \"TLy\": 0.084,\n", " \"TRx\": 0.84,\n", " \"TRy\": 0.084,\n", " \"BRx\": 0.84,\n", " \"BRy\": 0.095,\n", " \"BLx\": 0.779,\n", " \"BLy\": 0.095,\n", " \"original_width\": 762,\n", " \"original_height\": 1000,\n", " },\n", " ],\n", " \"texts\": [\n", " \"Revision\"\n", " ]\n", " },\n", " {\"value\": \"class_1\"},\n", " ]\n", " ],\n", "}" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "From here on, training a document model is as easy as training a text model in the previous examples: Just create a DatasetLoader and start training." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "dl = DatasetLoader(dataset)\n", "\n", "# Or create a DatasetLoader from a file\n", "# dl = DatasetLoader(\"path/to/my-dataset.json\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Training with documents takes significantly longer than for text-only data, since the AutoTransformer has to learn from both the text and the image. 
{ "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Training with documents takes significantly longer than training on text-only data, since the AutoTransformer has to learn from both the text and the image. Therefore, don't be surprised if the AutoTransformer trained in this short example does not perform very well.\n", "\n", "We use the `DocumentModelV4` base model for this example (a DocumentModel trained by DeepOpinion). It is a decent choice for document classification and requires a moderate amount of resources due to its internal architecture." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Metric Value\n", "-------- -------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 6.06it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.65it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.99it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 9.11it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 9.02it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.80it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.57it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.83it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.64it/s]\n", "Train: 100%|██████████| 1/1 [00:00<00:00, 8.65it/s]\n" ] } ], "source": [
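"# Create the AutoTransformer and train it on our documents.\n", "# NOTE: the exact training calls are shown in the \"Getting started\" notebook;\n", "# the lines below are an assumed sketch of that interface, not verified API.\n", "at = AutoTransformer()  # assumed constructor\n", "at.init(dataset_loader=dl, path=\"models/example_dc\")  # assumed init signature\n", "at.train()  # assumed training entry point" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 2 }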