import subprocess
import sys

# Install the fasttext bindings at runtime (notebook-style bootstrap).
subprocess.check_call([sys.executable, "-m", "pip", "install", "fasttext"])

from typing import List

import fasttext
import torch
from torch import nn

from huggingface_hub import hf_hub_download
from transformers import PretrainedConfig, PreTrainedModel


class FastTextConfig(PretrainedConfig):
    model_type = "fasttext-language-identification"

    def __init__(
        self,
        repo_id: str = "facebook/fasttext-language-identification",
        top_k: int = 1,
        **kwargs,
    ):
        # repo_id points at the Hub repository that hosts the fastText binary;
        # top_k is the default number of language labels returned per prediction.
        self.repo_id = repo_id
        self.top_k = top_k
        super().__init__(**kwargs)


class FastTextModel(PreTrainedModel):
    config_class = FastTextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = FastText(config.repo_id)

    def forward(self, words: List[str], k: int = None):
        # Delegate to the wrapped fastText module; when k is not given,
        # fall back to the top_k stored in the config.
        if k is None:
            k = self.config.top_k
        return self.model(words, k=k)


class FastText(nn.Module):
    def __init__(self, repo_id: str, filename: str = "model.bin", *args, **kwargs) -> None:
        super().__init__()
        # Download the fastText binary from the Hub and load it with the native library.
        self.ft = fasttext.load_model(
            hf_hub_download(repo_id=repo_id, filename=filename)
        )

        # Expose the fastText input matrix as an nn.Embedding so the wrapper
        # carries real torch parameters; copy the vectors in so they are not
        # left at their random initialisation.
        word_vectors = torch.from_numpy(self.ft.get_input_matrix())
        num_embeddings = word_vectors.size(0)
        embedding_dim = word_vectors.size(1)
        self.embeddings = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.embeddings.weight.data.copy_(word_vectors)

    def forward(self, text: str, k: int = 1):
        # fastText's predict() returns a (labels, probabilities) tuple,
        # e.g. (('__label__eng_Latn',), array([0.99...])) for k=1.
        return self.ft.predict(text, k=k)
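

# A minimal usage sketch, not part of the original listing. It assumes network
# access to the default "facebook/fasttext-language-identification" repository,
# whose "model.bin" file is what hf_hub_download fetches in FastText.__init__ above.
if __name__ == "__main__":
    config = FastTextConfig(top_k=3)
    model = FastTextModel(config)
    # k falls back to config.top_k, so the three most likely language labels
    # (and their probabilities) come back for the single input sentence.
    labels, scores = model(["Bonjour, comment allez-vous ?"])
    print(labels, scores)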