import subprocess
import sys
from typing import List, Optional, Union

# Install fastText only when it is not already importable, rather than
# invoking pip on every import of this module.
try:
    import fasttext
except ImportError:
    # NOTE(review): installing packages at import time needs network access and
    # a writable environment; declaring `fasttext` as a dependency is preferable.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "fasttext"])
    import fasttext

import torch
from torch import nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from huggingface_hub import hf_hub_download


class FastTextConfig(PretrainedConfig):
    """Configuration for the fastText language-identification wrapper.

    Args:
        repo_id: Hugging Face Hub repository holding the fastText binary.
        top_k: Default number of labels ``FastTextModel.forward`` returns
            when ``k`` is not given explicitly.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "fasttext-language-identification"

    def __init__(
        self,
        repo_id: str = "facebook/fasttext-language-identification",
        top_k: int = 1,
        **kwargs,
    ):
        self.repo_id = repo_id
        self.top_k = top_k
        super().__init__(**kwargs)


class FastTextModel(PreTrainedModel):
    """``PreTrainedModel`` facade that delegates prediction to ``FastText``."""

    config_class = FastTextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = FastText(config.repo_id)

    def forward(self, words: Union[str, List[str]], k: Optional[int] = None) -> tuple:
        """Predict the top-``k`` labels for ``words``.

        Args:
            words: A single line of text or a list of lines.
            k: Number of labels to return. When ``None``, ``config.top_k``
                is used, so the configured value is actually honoured.

        Returns:
            The ``(labels, probabilities)`` result of fastText's ``predict``.
        """
        if k is None:
            k = self.config.top_k
        return self.model(words, k=k)


class FastText(nn.Module):
    """``nn.Module`` around a fastText model downloaded from the Hub.

    Args:
        repo_id: Hub repository containing the fastText binary.
        filename: Name of the binary inside that repository.
        *args, **kwargs: Accepted for call compatibility; unused.
    """

    def __init__(self, repo_id: str, filename: str = "model.bin", *args, **kwargs) -> None:
        super().__init__()
        self.ft = fasttext.load_model(
            hf_hub_download(repo_id=repo_id, filename=filename)
        )
        # fastText's input matrix: one row per input entry, one column per
        # embedding dimension. (Rows may include subword-hash buckets as well
        # as vocabulary words — verify against the fastText docs.)
        word_vectors = torch.from_numpy(self.ft.get_input_matrix())
        num_embeddings, embedding_dim = word_vectors.shape
        self.embeddings = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        # nn.Embedding starts out randomly initialised; seed it with the
        # pretrained vectors so the layer matches the loaded model. no_grad keeps
        # the weight a trainable leaf parameter (requires_grad stays True).
        with torch.no_grad():
            self.embeddings.weight.copy_(word_vectors)

    def forward(self, text: Union[str, List[str]], k: int = 1) -> tuple:
        """Return fastText's top-``k`` predictions for ``text``.

        Args:
            text: One line of text, or a list of lines. fastText's ``predict``
                rejects strings that contain a newline.
            k: Number of labels to return per line.

        Returns:
            The ``(labels, probabilities)`` pair produced by ``predict``. For
            list input, each element holds per-line results.
        """
        return self.ft.predict(text, k=k)