import transformers import re from transformers import AutoTokenizer, pipeline import torch import html import gradio as gr import tempfile import os import pandas as pd # Define the device device = "cuda" if torch.cuda.is_available() else "cpu" # Load models editorial_model = "LLMDH/Estienne" bibliography_model = "PleIAs/Bibliography-Formatter" bibliography_style = "PleIAs/Bibliography-Classifier" tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512) editorial_classifier = pipeline( "token-classification", model=editorial_model, aggregation_strategy="simple", device=device ) bibliography_classifier = pipeline( "token-classification", model=bibliography_model, aggregation_strategy="simple", device=device ) # Helper functions def preprocess_text(text): text = re.sub(r'<[^>]+>', '', text) text = re.sub(r'\n', ' ', text) text = re.sub(r'\s+', ' ', text) return text.strip() def split_text(text, max_tokens=500): parts = text.split("\n") chunks = [] current_chunk = "" for part in parts: temp_chunk = current_chunk + "\n" + part if current_chunk else part num_tokens = len(tokenizer.tokenize(temp_chunk)) if num_tokens <= max_tokens: current_chunk = temp_chunk else: if current_chunk: chunks.append(current_chunk) current_chunk = part if current_chunk: chunks.append(current_chunk) if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens: long_text = chunks[0] chunks = [] while len(tokenizer.tokenize(long_text)) > max_tokens: split_point = len(long_text) // 2 while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]): split_point += 1 if split_point >= len(long_text): split_point = len(long_text) - 1 chunks.append(long_text[:split_point].strip()) long_text = long_text[split_point:].strip() if long_text: chunks.append(long_text) return chunks def disambiguate_bibtex_ids(bibtex_entries): id_count = {} disambiguated_entries = [] for entry in bibtex_entries: # Extract the current ID match = re.search(r'@\w+{(\w+),', entry) if not match: disambiguated_entries.append(entry) continue original_id = match.group(1) # Check if this ID has been seen before if original_id in id_count: id_count[original_id] += 1 new_id = f"{original_id}{chr(96 + id_count[original_id])}" # 'a', 'b', 'c', etc. new_entry = re.sub(r'(@\w+{)(\w+)(,)', f'\\1{new_id}\\3', entry, 1) disambiguated_entries.append(new_entry) else: id_count[original_id] = 0 disambiguated_entries.append(entry) return disambiguated_entries def remove_punctuation(text): return re.sub(r'[^\w\s]', '', text) def extract_year(text): year_match = re.search(r'\b(\d{4})\b', text) return year_match.group(1) if year_match else None def create_bibtex_entry(data): if 'journal' in data: entry_type = 'article' elif 'booktitle' in data: entry_type = 'inproceedings' else: entry_type = 'book' none_content = data.pop('none', '') year = extract_year(none_content) if year and 'year' not in data: data['year'] = year if "year" in data: match_year = re.search(r'(\d{4})', data['year']) if match_year: data['year'] = match_year.group(1) year = data['year'] else: data.pop('year', '') #Pages conformity. if 'pages' in data: match = re.search(r'(\d+(-\d+)?)', data['pages']) if match: data['pages'] = match.group(1) else: data.pop('pages', '') author_words = data.get('author', '').split() first_author = author_words[0] if author_words else 'unknown' bibtex_id = f"{first_author}{year}" if year else first_author bibtex_id = remove_punctuation(bibtex_id.lower()) bibtex = f"@{entry_type}{{{bibtex_id},\n" for key, value in data.items(): if value.strip(): if key in ['volume', 'year']: value = remove_punctuation(value) if key == 'pages': value = value.replace('p. ', '') if key != "separator": bibtex += f" {key.lower()} = {{{value.strip()}}},\n" bibtex = bibtex.rstrip(',\n') + "\n}" return bibtex def save_bibtex(bibtex_content): with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.bib') as temp_file: temp_file.write(bibtex_content) return temp_file.name class CombinedProcessor: def process(self, user_message): #Precaution to reinforce bibliography detection. editorial_text = "Bibliography\n" + user_message #Our fix for the lack of newline in deberta editorial_text = re.sub("\n", " ¶ ", editorial_text) print(editorial_text) num_tokens = len(tokenizer.tokenize(editorial_text)) batch_prompts = split_text(editorial_text, max_tokens=500) if num_tokens > 500 else [editorial_text] editorial_out = editorial_classifier(batch_prompts) editorial_df = pd.concat([pd.DataFrame(classification) for classification in editorial_out]) # Filter out only bibliography entries bibliography_entries = editorial_df[editorial_df['entity_group'] == 'bibliography']['word'].tolist() bibtex_entries = [] list_style = [] for entry in bibliography_entries: print(entry) entry = re.sub(r'- ?[\n¶] ?', r'', entry) entry = re.sub(r' ?[\n¶] ?', r' ', entry) #style = pd.DataFrame(style_classifier(entry, truncation=True, padding=True, top_k=1)) #list_style.append(style) entry = re.sub(r'\s*([;:,\.])\s*', r' \1 ', entry) #print(entry) bib_out = bibliography_classifier(entry) bib_df = pd.DataFrame(bib_out) bibtex_data = {} current_entity = None for _, row in bib_df.iterrows(): entity_group = row['entity_group'] word = row['word'] if entity_group != 'None': if entity_group in bibtex_data: print(entity_group) if entity_group == "author": bibtex_data[entity_group] += ', ' + word else: bibtex_data[entity_group] += ' ' + word else: bibtex_data[entity_group] = word current_entity = entity_group else: if current_entity: if current_entity == "author": bibtex_data[current_entity] += ', ' + word else: bibtex_data[current_entity] += ' ' + word else: bibtex_data['None'] = bibtex_data.get('None', '') + ' ' + word bibtex_entry = create_bibtex_entry(bibtex_data) bibtex_entries.append(bibtex_entry) #list_style = pd.concat(list_style) #list_style = list_style.groupby('label')['score'].mean().sort_values(ascending=False).reset_index() #top_style = list_style.iloc[0]['label'] #top_style_score = list_style.iloc[0]['score'] # Create the style information string #style_info = f"Top bibliography style: {top_style} (Mean score: {top_style_score:.6f})" # Join BibTeX entries bibtex_content = "\n\n".join(bibtex_entries) #return style_info, bibtex_content return bibtex_content # Create the processor instance processor = CombinedProcessor() # Define the Gradio interface with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo: gr.HTML("""