---
license: gpl-3.0
language:
- en
---
# TL;DR

The abstract of the paper states that: 
> Charts are very popular for analyzing data, visualizing key insights and answering complex reasoning questions about data. To facilitate chart-based data analysis using natural language, several downstream tasks have been introduced recently such as chart question answering and chart summarization. However, most of the methods that solve these tasks use pretraining on language or vision-language tasks that do not attempt to explicitly model the structure of the charts (e.g., how data is visually encoded and how chart elements are related to each other). To address this, we first build a large corpus of charts covering a wide variety of topics and visual styles. We then present UniChart, a pretrained model for chart comprehension and reasoning. UniChart encodes the relevant text, data, and visual elements of charts and then uses a chart-grounded text decoder to generate the expected output in natural language. We propose several chart-specific pretraining tasks that include: (i) low-level tasks to extract the visual elements (e.g., bars, lines) and data from charts, and (ii) high-level tasks to acquire chart understanding and reasoning skills. We find that pretraining the model on a large corpus with chart-specific low- and high-level tasks followed by finetuning on three down-streaming tasks results in state-of-the-art performance on three downstream tasks.

# Web Demo
If you wish to quickly try our models, you can access our public web demos hosted on the Hugging Face Spaces platform with a friendly interface!

| Tasks  | Web Demo |
| ------------- | ------------- |
| Base Model (Best for Chart Summarization and Data Table Generation)  | [UniChart-Base](https://huggingface.co/spaces/ahmed-masry/UniChart-Base) |
| Chart Question Answering  | [UniChart-ChartQA](https://huggingface.co/spaces/ahmed-masry/UniChart-ChartQA) |

For the base model, the input prompt for chart summarization is **\<summarize_chart\>** and the input prompt for data table generation is **\<extract_data_table\>**.


# Inference
You can easily use our models for inference with the Hugging Face `transformers` library!
You just need to do the following:
1. Change _model_name_ to your preferred checkpoint.
2. Change _image_path_ to the path of your chart image on your system.
3. Write the _input_prompt_ for your preferred task as shown in the table below.

| Task  | Input Prompt |
| ------------- | ------------- |
| Chart Question Answering  | \<chartqa\> question \<s_answer\>  |
| Open Chart Question Answering  | \<opencqa\> question \<s_answer\>  |
| Chart Summarization  | \<summarize_chart\> \<s_answer\>  |
| Data Table Extraction  | \<extract_data_table\> \<s_answer\>  |
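The prompt formats in the table can be assembled programmatically. The following is a minimal sketch, not part of the released code: the helper name `build_prompt` and the short task keys (`"summarize"`, `"extract_table"`) are hypothetical, while the special tokens are the ones listed in the table.

```python
# Task tokens come from the table above; the dictionary keys and the
# helper itself are illustrative assumptions, not part of the UniChart code.
TASK_TOKENS = {
    "chartqa": "<chartqa>",
    "opencqa": "<opencqa>",
    "summarize": "<summarize_chart>",
    "extract_table": "<extract_data_table>",
}

# Only the two question-answering tasks take a question in the prompt.
QUESTION_TASKS = {"chartqa", "opencqa"}


def build_prompt(task, question=None):
    """Return '<task token> [question] <s_answer>' for the given task key."""
    if task not in TASK_TOKENS:
        raise ValueError(f"Unknown task {task!r}; expected one of {sorted(TASK_TOKENS)}")
    parts = [TASK_TOKENS[task]]
    if task in QUESTION_TASKS:
        if not question or not question.strip():
            raise ValueError(f"Task {task!r} requires a non-empty question")
        parts.append(question.strip())
    # Every prompt ends with the answer marker so the decoder starts generating there.
    parts.append("<s_answer>")
    return " ".join(parts)


print(build_prompt("chartqa", "What is the lowest value in blue bar?"))
print(build_prompt("summarize"))
```

The returned string can be passed as `input_prompt` to the inference code in this section.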

```
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

# Download a sample chart from the ChartQA validation set into the working directory.
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

model_name = "ahmed-masry/unichart-chartqa-960"
# Relative path, so it matches the download location above on any system.
image_path = "chart_example_1.png"
input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"

# Load the checkpoint and its processor, then move the model to GPU if one is available.
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Encode the prompt for the decoder and the chart image for the encoder.
image = Image.open(image_path).convert("RGB")
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

# Generate with beam search, continuing from the prompt tokens.
outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=4,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

# Decode, strip the EOS and padding tokens, and keep only the text after <s_answer>.
sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = sequence.split("<s_answer>")[1].strip()
print(sequence)
```
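
The last three lines of the snippet above assume that the decoded sequence always contains `<s_answer>`. If it does not, indexing the result of `split` raises an `IndexError`. A slightly more defensive version of that post-processing step could look like the sketch below. The function name `extract_answer` is hypothetical, the default `eos_token` and `pad_token` values are assumptions (in practice, pass `processor.tokenizer.eos_token` and `processor.tokenizer.pad_token`), and the decoded string is a made-up placeholder rather than real model output.

```python
def extract_answer(decoded, eos_token="</s>", pad_token="<pad>", answer_marker="<s_answer>"):
    """Strip special tokens and return the text after the answer marker.

    The default eos/pad strings are assumptions for illustration; pass the
    tokenizer's actual values when using this with a real checkpoint.
    """
    cleaned = decoded.replace(eos_token, "").replace(pad_token, "")
    # partition() never raises: 'found' is empty when the marker is missing.
    _, found, tail = cleaned.partition(answer_marker)
    return tail.strip() if found else cleaned.strip()


# Placeholder string standing in for processor.batch_decode(...)[0].
decoded = "<chartqa> What is the lowest value in blue bar? <s_answer> ANSWER</s><pad><pad>"
print(extract_answer(decoded))
```

When the marker is missing, this version falls back to returning the cleaned sequence instead of failing.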

# Contact
If you have any questions about this work, please contact **[Ahmed Masry](https://ahmedmasryku.github.io/)** using the following email addresses: **[email protected]** or **[email protected]**.

# Reference
Please cite our paper if you use our models or dataset in your research. 

```
@misc{masry2023unichart,
      title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning}, 
      author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty},
      year={2023},
      eprint={2305.14761},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```