File size: 5,804 Bytes
13cf51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dd29e
13cf51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dd29e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543ef5a
6c3bf25
543ef5a
07ae0e9
543ef5a
47dd29e
07ae0e9
47dd29e
688fdc8
47dd29e
 
 
 
 
688fdc8
47dd29e
 
 
 
 
 
f3b05ad
 
9013578
35dc99e
f3b05ad
35dc99e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    squad_convert_examples_to_features
)

from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import streamlit as st
import gradio as gr
import json
import torch
import time
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Hugging Face Hub id of the question-answering checkpoint loaded by load_model().
model_checkpoint: str = "akdeniz27/roberta-base-cuad"

@st.cache(allow_output_mutation=True)
def _load_qa_components(model_path, do_lower_case):
    """Load the config, tokenizer and QA model for *model_path* once.

    Streamlit re-executes the whole script on every interaction, so without
    this cache each click of "Run" would reload the checkpoint. Outputs are
    marked mutable (as in ``load_model``) so Streamlit does not hash them.

    Args:
        model_path: Hub id or local path of a question-answering checkpoint.
        do_lower_case: Lower-casing flag forwarded to the tokenizer; it must
            match the flag later given to ``compute_predictions_logits``.

    Returns:
        ``(tokenizer, model, device)`` with the model already on *device*.
    """
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, do_lower_case=do_lower_case, use_fast=False)
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, config=config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return tokenizer, model, device


def run_prediction(question_texts, context_text, model_path):
    """Answer every question in *question_texts* against one context.

    Each question becomes a SQuAD v2 style example sharing *context_text*.
    The features are scored by the QA model in batches, and the logits are
    decoded with ``compute_predictions_logits``.

    Args:
        question_texts: Sequence of question strings. Each question's index
            (as a string) becomes its ``qas_id``.
        context_text: The document (contract) text to search for answers.
        model_path: Hub id or local path of a question-answering checkpoint.

    Returns:
        The value of ``compute_predictions_logits``: a mapping from each
        ``qas_id`` to its predicted answer text. Because
        ``version_2_with_negative=True``, a "no answer" prediction is
        presumably an empty string; confirm against the transformers docs.
    """
    max_seq_length = 512
    doc_stride = 256
    n_best_size = 1
    max_query_length = 64
    max_answer_length = 512
    # Used for BOTH the tokenizer and compute_predictions_logits: the latter
    # re-tokenizes the original context with this flag to map the predicted
    # tokens back onto the source text, so the two must agree.
    do_lower_case = False
    null_score_diff_threshold = 0.0

    tokenizer, model, device = _load_qa_components(model_path, do_lower_case)
    model.eval()

    examples = [
        SquadExample(
            qas_id=str(i),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            answers=None,
        )
        for i, question_text in enumerate(question_texts)
    ]

    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )
    eval_dataloader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=10)

    # Some architectures (e.g. RoBERTa, DistilBERT) do not take segment ids;
    # DistilBERT's forward() rejects the argument, so pass it only when the
    # tokenizer declares it as a model input.
    use_token_type_ids = "token_type_ids" in tokenizer.model_input_names

    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
        if use_token_type_ids:
            inputs["token_type_ids"] = batch[2]
        with torch.no_grad():
            outputs = model(**inputs)
        # Read logits by name instead of unpacking to_tuple(), which would
        # break if the model were configured to return extra outputs.
        start_logits = outputs.start_logits.detach().cpu().tolist()
        end_logits = outputs.end_logits.detach().cpu().tolist()
        # batch[3] holds each row's index into `features`.
        for row, feature_index in enumerate(batch[3].tolist()):
            feature = features[feature_index]
            all_results.append(SquadResult(
                int(feature.unique_id), start_logits[row], end_logits[row]))

    return compute_predictions_logits(
        all_examples=examples,
        all_features=features,
        all_results=all_results,
        n_best_size=n_best_size,
        max_answer_length=max_answer_length,
        do_lower_case=do_lower_case,
        output_prediction_file=None,
        output_nbest_file=None,
        output_null_log_odds_file=None,
        verbose_logging=False,
        version_2_with_negative=True,
        null_score_diff_threshold=null_score_diff_threshold,
        tokenizer=tokenizer
    )

@st.cache(allow_output_mutation=True)
def load_model():
    """Fetch the question-answering checkpoint named by ``model_checkpoint``.

    Wrapped in ``st.cache`` so Streamlit reuses the result across reruns.
    Outputs are flagged as mutable so Streamlit does not hash the returned
    objects.

    Returns:
        A ``(model, tokenizer)`` pair. The tokenizer is the slow (non-"fast")
        implementation.
    """
    checkpoint = model_checkpoint
    return (
        AutoModelForQuestionAnswering.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint, use_fast=False),
    )
    
@st.cache(allow_output_mutation=True)
def load_questions():
    """Return the question texts listed for the first contract in test.json.

    The file is read as SQuAD-style JSON. Only
    ``data[0]['paragraphs'][0]['qas']`` is consulted, and the ``question``
    field of each entry is returned in file order.

    Returns:
        A list of question strings.
    """
    # JSON is UTF-8 by specification; being explicit avoids depending on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    qas = data['data'][0]['paragraphs'][0]['qas']
    return [qa['question'] for qa in qas]
	
@st.cache(allow_output_mutation=True)
def load_contracts():
    """Return the context text of every entry in test.json.

    For each item of ``data``, the ``context`` of its first paragraph is
    taken, and every run of whitespace (spaces, tabs, newlines) is collapsed
    to a single space.

    Returns:
        A list of contract strings, in file order.
    """
    # JSON is UTF-8 by specification; being explicit avoids depending on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return [
        ' '.join(entry['paragraphs'][0]['context'].split())
        for entry in data['data']
    ]
	
# NOTE(review): `model`/`tokenizer` are not used below; run_prediction loads
# its own copy. The call presumably just warms the checkpoint download.
model, tokenizer = load_model()
questions = load_questions()
contracts = load_contracts()
contract = contracts[0]

st.header("📚 Question Answering in Contract Understanding Atticus Dataset (CUAD)")
st.image("contract_review.png")

selected_question = st.selectbox('📑 Choose one of the queries from the CUAD dataset or 📝 write a legal contract and see if the model can answer correctly: ', questions)

# The first CUAD question is always sent along with the selected one, and its
# answer (qas_id "0") is skipped when rendering below.
# NOTE(review): the reason for this extra question is not visible here;
# confirm before removing it.
question_set = [questions[0], selected_question]
contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))

if contract_type == "Sample Contract":
    # Bound the slider by the number of loaded contracts. The implicit
    # default range (0-100) could index past the end of `contracts`.
    sample_contract_num = st.slider(
        "Select Sample Contract #", min_value=0, max_value=len(contracts) - 1)
    contract = contracts[sample_contract_num]
    with st.expander(f"Sample Contract #{sample_contract_num}"):
        st.write(contract)
else:
    contract = st.text_area("Input New Contract", "", height=256)

run_clicked = st.button("Run", key=None)
# Reject empty or whitespace-only contracts; there is nothing to search.
if run_clicked and contract.strip():
    predictions = run_prediction(question_set, contract, model_checkpoint)

    for qas_id, answer in predictions.items():
        if qas_id == "0":
            continue
        st.write(f"Question: {question_set[int(qas_id)]}\n\nAnswer: {answer}\n\n")


st.write("🤗")
st.write("Based on Streamlit code of https://huggingface.co/spaces/akdeniz27/contract-understanding-atticus-dataset-demo")
st.write("Model: akdeniz27/roberta-base-cuad")
st.write("Project: https://www.atticusprojectai.org/cuad")