File size: 6,801 Bytes
14fa848
30d04c6
 
197e844
7551cdd
 
14fa848
30d04c6
 
 
b5dd5bc
 
30d04c6
b5dd5bc
30d04c6
b5dd5bc
30d04c6
 
 
 
 
 
7551cdd
 
 
 
 
 
 
197e844
 
 
 
 
 
 
 
 
 
 
7551cdd
 
197e844
7551cdd
197e844
 
c89a5c0
 
 
 
 
 
 
100e8bf
 
 
 
 
7edd56d
100e8bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c89a5c0
100e8bf
 
 
 
c89a5c0
 
 
100e8bf
 
 
 
c89a5c0
100e8bf
 
 
 
 
 
197e844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c89a5c0
197e844
 
 
 
c89a5c0
 
 
197e844
 
 
 
 
c89a5c0
197e844
 
 
 
 
 
 
 
 
 
 
5c82e3e
 
 
 
 
 
197e844
30d04c6
 
 
 
7edd56d
c89a5c0
7edd56d
30d04c6
 
 
 
 
 
 
 
 
 
7551cdd
 
30d04c6
 
 
b5dd5bc
 
 
 
30d04c6
 
 
 
7551cdd
30d04c6
 
 
 
 
 
 
7551cdd
30d04c6
7551cdd
30d04c6
7551cdd
30d04c6
7551cdd
30d04c6
7551cdd
 
30d04c6
7551cdd
30d04c6
 
 
 
 
100e8bf
30d04c6
c89a5c0
197e844
 
 
30d04c6
5c82e3e
 
 
 
 
7551cdd
 
 
 
5c82e3e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
from pprint import pprint 


@st.cache_resource()
def load_trained_model():
    """Build and cache the token-classification pipeline.

    Loads the fine-tuned BERT checkpoint from the Hugging Face hub and wraps
    it in a transformers "ner" pipeline. Cached once per server process via
    st.cache_resource so the model is not re-downloaded on every rerun.

    Returns:
        transformers.Pipeline: a ready-to-call NER pipeline.
    """
    checkpoint = "LampOfSocrates/bert-cased-plodcw-sourav"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    mdl = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Surface the model's label inventory in the server logs for debugging.
    print(f"Can recognise the following labels {mdl.config.id2label}")

    return pipeline("ner", model=mdl, tokenizer=tok)


@st.cache_data()
def load_plod_cw_dataset():
    """Download the surrey-nlp/PLOD-CW dataset (cached by Streamlit)."""
    # Local import keeps `datasets` off the critical import path of the app.
    from datasets import load_dataset
    return load_dataset("surrey-nlp/PLOD-CW")

def load_random_examples(dataset_name, num_examples=5):
    """
    Load random examples from the test split of the PLOD-CW dataset.

    Args:
        dataset_name (str): The name of the dataset to load. NOTE(review):
            currently unused — the cached PLOD-CW loader is always used;
            kept for interface compatibility.
        num_examples (int): The number of random examples to load.
    Returns:
        pd.DataFrame: A two-row DataFrame (tokens row, ner_tags row) of the
        sampled examples.
    """
    # Load the (cached) dataset.
    dat = load_plod_cw_dataset()

    # Convert the held-out test split to a pandas DataFrame.
    df = pd.DataFrame(dat['test'])

    # BUG FIX: the sample size was previously hard-coded to n=1, silently
    # ignoring the documented num_examples argument; honor it now.
    random_examples = df.sample(n=num_examples)

    tokens = random_examples.tokens
    ner_tags = random_examples.ner_tags

    return pd.DataFrame((tokens, ner_tags))


def render_entities(tokens, entities):
    """
    Render a two-column Streamlit table pairing tokens with their entities.

    Args:
        tokens: the raw input text (shown in the "Token" column).
        entities: the pipeline predictions (shown in the "Entity" column).
    """
    # Inject the shared page theme (table colours, hover highlight, fonts).
    st.markdown("""
        <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
        </style>
        """, unsafe_allow_html=True)

    # Heading, blurb, then the table itself.
    st.title("Model predicted Token vs Entities Table")
    st.write("This table shows the entity corresponding to each token in a cool and chilled theme.")

    st.table({"Token": tokens, "Entity": entities})

def render_random_examples():
    """
    Render one random example from the PLOD-CW dataset in a Streamlit table.

    The current sample is kept in st.session_state so Streamlit reruns do
    not reshuffle it; a button lets the user request a fresh sample.
    """
    # Custom CSS for chilled and cool theme
    st.markdown("""
        <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
        </style>
        """, unsafe_allow_html=True)

    # Title and description (grammar fixed: one example is shown).
    st.title("Random Examples from PLOD-CW")
    st.write("This table shows 1 random example from the PLOD-CW dataset in a cool and chilled theme.")

    # Pass num_examples=1 explicitly so the sample size matches the text
    # above instead of relying on the loader's default behavior.
    if st.button('Show another set of random examples'):
        st.session_state['random_examples'] = load_random_examples("surrey-nlp/PLOD-CW", num_examples=1)

    # Load random examples if not already loaded
    if 'random_examples' not in st.session_state:
        st.session_state['random_examples'] = load_random_examples("surrey-nlp/PLOD-CW", num_examples=1)

    # Display the table
    st.table(st.session_state['random_examples'])
def predict_using_trained(sentence):
    """Run the cached NER pipeline over *sentence* and return its raw
    entity predictions."""
    ner = load_trained_model()
    return ner(sentence)

def prep_page():
    """Render the interactive NER page: a text box, colour-highlighted
    entity spans, a token/entity table, and a random dataset example."""
    model = load_trained_model()

    # NOTE: st.set_page_config must be the first Streamlit call if it is
    # ever re-enabled.
    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    # Free-text input from the user.
    text = st.text_area("Enter your sentence here:")

    # Only run inference once the user has typed something.
    if text:
        st.write("Entities recognized:")
        entities = model(text)

        pprint(entities)  # server-side debug dump of the raw predictions

        # Background colour per BIO tag; unknown tags fall back to gray.
        label_colors = {
            'B-LF': 'lightblue',
            'B-O': 'lightgreen',
            'B-AC': 'lightcoral',
            'I-LF': 'lightyellow'
        }

        def _entities_to_html(raw_text, found):
            """Wrap each predicted span in a coloured <div>; spaces in the
            surrounding plain text become <br> so tokens stack vertically."""
            pieces = ["<div>"]
            cursor = 0
            for ent in found:
                span_start, span_end = ent['start'], ent['end']
                shade = label_colors.get(ent['entity'], 'lightgray')
                # Plain text between the previous entity and this one.
                pieces.append(raw_text[cursor:span_start].replace(" ", "<br>"))
                # The highlighted entity span itself.
                pieces.append(
                    f'<div style="background-color: {shade}; padding: 5px; '
                    f'border-radius: 3px; margin: 5px 0;">'
                    f'{raw_text[span_start:span_end]}</div>'
                )
                cursor = span_end
            # Any trailing text after the last entity.
            pieces.append(raw_text[cursor:].replace(" ", "<br>"))
            pieces.append("</div>")
            return "".join(pieces)

        # Show the highlighted text, then the token/entity table.
        st.markdown(_entities_to_html(text, entities), unsafe_allow_html=True)

        render_entities(text, entities)

    render_random_examples()



if __name__ == '__main__':

    # Two entry modes: "?api&sentence=..." returns raw predictions,
    # anything else renders the interactive page.
    params = st.query_params
    if 'api' in params:
        sent = params.get('sentence')
        result = {"sentence": sent, "entities": predict_using_trained(sent)}
        pprint(result)  # mirror the response to the server log

        st.write(result)
    else:
        prep_page()