import os

# Set cache locations before the heavy imports: huggingface_hub and numba
# resolve these environment variables when they are first imported, so setting
# them later (e.g. inside nllb()) would have no effect.
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"  # redirect numba's on-disk cache to a writable dir
os.environ["NUMBA_DISABLE_JIT"] = "1"               # disable numba JIT compilation entirely
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache/huggingface"  # legacy alias for older transformers releases

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def nllb():
    """
    Load and return the NLLB (No Language Left Behind) model and tokenizer.

    Loads the facebook/nllb-200-distilled-1.3B checkpoint from the Hugging Face
    Hub. The model is moved to a GPU if one is available, otherwise it stays on
    the CPU.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
            - model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
            - tokenizer (transformers.AutoTokenizer): The loaded tokenizer.

    Example usage:
        model, tokenizer = nllb()
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer and model (cached under HF_HOME, set at import time above)
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)
    return model, tokenizer
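
# NLLB addresses target languages with FLORES-200 codes such as "eng_Latn" and
# "spa_Latn" rather than bare ISO codes; nllb_translate() below maps its short
# "en"/"es" arguments onto them.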

def nllb_translate(model, tokenizer, article, language):
    """
    Translate an article using the NLLB model and tokenizer.

    Args:
        model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for translation.
            Example: model, tokenizer = nllb()
        tokenizer (transformers.AutoTokenizer): The tokenizer to use with the NLLB model.
            Example: model, tokenizer = nllb()
        article (str): The article text to be translated.
            Example: "This is a sample article."
        language (str): The target language code. Must be either "es" (Spanish)
            or "en" (English).
            Example: "es"

    Returns:
        str: The translated text, or "Translation failed" if an error occurred.
            Example: "Este es un artículo de muestra."
    """
    try:
        # Tokenize the text
        inputs = tokenizer(article, return_tensors="pt")
        # Move the tokenized inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Map the short language code onto the FLORES-200 code NLLB expects
        if language == "es":
            target_code = "spa_Latn"
        elif language == "en":
            target_code = "eng_Latn"
        else:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # convert_tokens_to_ids replaces tokenizer.lang_code_to_id, which was
        # removed from the NLLB tokenizer in recent transformers releases
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_code),
            max_length=30,  # note: translations longer than 30 tokens are truncated
        )

        # Decode the translation
        text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation failed"