# French-TV-transcript-NER / inference_transcript_ner.py
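"""Run named-entity recognition over French TV transcripts.

Transcripts are read from a parquet file, split into chunks that fit the model's
context window, passed through the Pclanglais/French-TV-transcript-NER
token-classification pipeline, and the detected entities are written to
result_transcripts.tsv.
"""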
import re
import pandas as pd
from tqdm.auto import tqdm
from transformers import pipeline
from transformers import AutoTokenizer
model_checkpoint = "Pclanglais/French-TV-transcript-NER"
token_classifier = pipeline(
"token-classification", model=model_checkpoint, aggregation_strategy="simple"
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
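# Note: the pipeline runs on CPU by default; pass device=0 to pipeline(...) to use the first GPU.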
def split_text(text, max_tokens=500, separator="\n"):
    """Split text into chunks of at most max_tokens tokens, preferring separator boundaries."""
    # Split the text at separator boundaries (newlines, or the ¶ markers used in the main loop).
    parts = text.split(separator)
    chunks = []
    current_chunk = ""
    for part in parts:
        # Tentatively add the part to the current chunk.
        if current_chunk:
            temp_chunk = current_chunk + separator + part
        else:
            temp_chunk = part
        # Keep the part in the current chunk only while the token budget holds.
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)
    # If no separator was found and the text still exceeds max_tokens, split on whitespace.
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            # Cut roughly in half, moving forward to the next whitespace character.
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r"\s", long_text[split_point]):
                split_point += 1
            # Ensure split_point does not go out of range.
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            head = long_text[:split_point].strip()
            long_text = long_text[split_point:].strip()
            if len(tokenizer.tokenize(head)) > max_tokens:
                # The first half may itself still exceed the budget; split it recursively.
                chunks.extend(split_text(head, max_tokens=max_tokens, separator=separator))
            else:
                chunks.append(head)
        if long_text:
            chunks.append(long_text)
    return chunks
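# Illustrative sanity check (commented out, not part of the pipeline itself):
# sample = "PRÉSENTATEUR : Bonsoir à tous. ¶ INVITÉ : Bonsoir." * 100
# for i, chunk in enumerate(split_text(sample, max_tokens=500, separator=" ¶ ")):
#     print(i, len(tokenizer.tokenize(chunk)))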
complete_data = pd.read_parquet("[file with transcripts]")
print(complete_data)
classified_list = []
list_prompt = []
list_file = []
list_id = []
text_id = 1
for index, row in complete_data.iterrows():
    prompt, current_file = str(row["corrected_text"]), row["identifier"]
    # Replace newlines with a visible ¶ marker so line breaks survive tokenization
    # and can be restored after classification.
    prompt = re.sub("\n", " ¶ ", prompt)
    # Tokenize the prompt and check whether it exceeds the 500-token budget.
    num_tokens = len(tokenizer.tokenize(prompt))
    if num_tokens > 500:
        # Split the prompt into chunks at the ¶ markers (the original line breaks).
        chunks = split_text(prompt, max_tokens=500, separator=" ¶ ")
        for chunk in chunks:
            list_file.append(current_file)
            list_prompt.append(chunk)
            list_id.append(text_id)
    else:
        list_file.append(current_file)
        list_prompt.append(prompt)
        list_id.append(text_id)
    text_id += 1
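# list_prompt, list_file and list_id stay index-aligned: position i holds the text sent to the
# model, the transcript it came from, and its running text_id, so results can be joined back later.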
full_classification = []
batch_size = 4
# The pipeline returns one result per input text, so the progress bar counts prompts.
for out in tqdm(token_classifier(list_prompt, batch_size=batch_size), total=len(list_prompt)):
    full_classification.append(out)
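# Each entry of full_classification is a list of entity dicts with the keys
# "entity_group", "score", "word", "start" and "end" (from aggregation_strategy="simple").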
for id_row, classification in enumerate(full_classification):
    try:
        df = pd.DataFrame(classification)
        df["identifier"] = list_file[id_row]
        df["text_id"] = list_id[id_row]
        # Restore the original line breaks that were encoded as ¶ markers.
        df["word"] = df["word"].replace(" ¶ ", " \n ", regex=True)
        print(df)
        classified_list.append(df)
    except Exception:
        # Skip chunks where the pipeline returned no entities (empty result, no "word" column).
        pass
classified_list = pd.concat(classified_list)
# Display the combined DataFrame
print(classified_list)
classified_list.to_csv("result_transcripts.tsv", sep="\t")
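# The results can be reloaded later with pd.read_csv("result_transcripts.tsv", sep="\t").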