File size: 2,567 Bytes
14fa848
30d04c6
 
 
14fa848
 
30d04c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import html

import streamlit as st
import wandb
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Demo widget: choose a number on a slider and display its square.
x = st.slider('Select a value')
x_squared = x * x
st.write(x, 'squared is', x_squared)

@st.cache_resource()
def load_trained_model():
    """Build the token-classification pipeline used by the app.

    The tokenizer and the model are both fetched from the
    ``LampOfSocrates/bert-base-cased-sourav`` checkpoint, and the model's
    id-to-label mapping is printed to stdout. The function is decorated
    with ``st.cache_resource``.

    Returns:
        A ``transformers`` ``"ner"`` pipeline wrapping the loaded model and
        tokenizer.
    """
    checkpoint = "LampOfSocrates/bert-base-cased-sourav"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Show which label each class id maps to.
    print(ner_model.config.id2label)

    return pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)

# Highlight colour for each entity type shown on the page.
_LABEL_COLORS = {
    'ORG': 'lightblue',
    'PER': 'lightgreen',
    'LOC': 'lightcoral',
    'MISC': 'lightyellow',
}
# Colour used for any label that has no entry in the table above.
_DEFAULT_COLOR = 'lightgray'


def _color_for_label(label):
    """Return the highlight colour for an entity *label*.

    The exact label is tried first. If it is not in the colour table and it
    carries a ``B-``/``I-`` tagging prefix (e.g. ``B-ORG``), the prefix is
    dropped and the bare entity type is looked up instead. Labels that still
    do not match get the default colour.
    """
    if label in _LABEL_COLORS:
        return _LABEL_COLORS[label]
    if label.startswith(('B-', 'I-')):
        label = label[2:]
    return _LABEL_COLORS.get(label, _DEFAULT_COLOR)


def _render_entity_html(text, entities):
    """Return *text* as HTML with each entity span wrapped in ``<mark>``.

    Args:
        text: The raw sentence typed by the user.
        entities: Iterable of dicts carrying an ``entity`` label and
            ``start``/``end`` character offsets into *text*.

    Returns:
        An HTML string in which every piece of user text has been
        HTML-escaped and each usable entity span is highlighted.
    """
    parts = []
    cursor = 0
    for entity in entities:
        start = entity.get('start')
        end = entity.get('end')
        # Skip spans without offsets (slicing with None would silently copy
        # the whole string) and spans that begin before the text already
        # emitted (they would duplicate output).
        if start is None or end is None or start < cursor:
            continue

        color = _color_for_label(entity['entity'])
        # Plain text between the previous entity and this one.
        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f'<mark style="background-color: {color}; border-radius: 3px;">'
            f'{html.escape(text[start:end])}</mark>'
        )
        cursor = end

    # Whatever follows the last entity.
    parts.append(html.escape(text[cursor:]))
    return ''.join(parts)


def prep_page():
    """Render the NER demo page.

    Loads the pipeline via ``load_trained_model()``, draws the title, the
    instructions and a text area, and, when the text area is non-empty, runs
    the pipeline on its content and displays the sentence with recognised
    entities highlighted.
    """
    model = load_trained_model()

    # Streamlit app
    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    # Text input
    text = st.text_area("Enter your sentence here:")
    if not text:
        return

    # Perform NER and display results
    st.write("Entities recognized:")
    entities = model(text)

    # The markup is rendered as raw HTML, which is why the helper escapes
    # every fragment of user-supplied text before embedding it.
    st.markdown(_render_entity_html(text, entities), unsafe_allow_html=True)



if __name__ == '__main__':
    # Entry point when the file is executed as a script: build the page.
    # The earlier call to load_model_from_wandb() (and the print of its
    # result) was removed because no such function is defined or imported
    # in this file, so it raised NameError before the page could render.
    prep_page()