# fasttext-language-identification / custom_fasttext.py
# Author: Hiveurban
# Commit: "Upload model" (fcced0c, verified)
import importlib
import subprocess
import sys
from typing import List

# Install fastText only when it is not importable, so loading this module does
# not launch pip (and require network access) on every import.
try:
    import fasttext
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", 'fasttext'])
    # Clear the import system's finder caches so the freshly installed package
    # is discovered by the retry below.
    importlib.invalidate_caches()
    import fasttext

import torch
from torch import nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from huggingface_hub import hf_hub_download
class FastTextConfig(PretrainedConfig):
    """Configuration for the fastText language-identification wrapper.

    Args:
        repo_id: Hugging Face Hub repository the fastText binary is fetched from.
        top_k: Stored on the config as ``top_k``. ``FastTextModel`` in this file
            does not read it; its ``forward`` takes its own ``k`` argument.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fasttext-language-identification"

    def __init__(
        self,
        repo_id: str = "facebook/fasttext-language-identification",
        top_k: int = 1,
        **kwargs
    ):
        # Record the custom fields first, then let the base class handle the rest.
        # NOTE(review): some transformers versions set a generation default named
        # `top_k` inside PretrainedConfig.__init__, which may overwrite the value
        # assigned here. Confirm against the pinned transformers version.
        self.repo_id, self.top_k = repo_id, top_k
        super().__init__(**kwargs)
class FastTextModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing a fastText classifier.

    The underlying model is a ``FastText`` module built from ``config.repo_id``.
    """

    config_class = FastTextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = FastText(config.repo_id)

    def forward(self, words: List[str], k=1) -> tuple:
        """Predict the top-``k`` labels for ``words``.

        Args:
            words: Text to classify. It is handed unchanged to ``FastText.forward``,
                which passes it on to fastText's ``predict``.
            k: Number of labels to return per input. ``config.top_k`` is not
                consulted here.

        Returns:
            fastText's ``(labels, probabilities)`` pair, as returned by
            ``FastText.forward``. This is not a flat list of strings.
        """
        return self.model(words, k=k)
class FastText(nn.Module):
    """Thin ``nn.Module`` around a fastText model downloaded from the Hub.

    Args:
        repo_id: Hub repository that contains the fastText binary.
        filename: Name of the binary inside ``repo_id``.
        *args, **kwargs: Accepted for call compatibility; ignored.
    """

    def __init__(self, repo_id: str, filename: str = "model.bin", *args, **kwargs) -> None:
        super().__init__()
        # hf_hub_download yields a local path to the downloaded (or cached) file.
        self.ft = fasttext.load_model(
            hf_hub_download(repo_id=repo_id, filename=filename)
        )
        # Mirror fastText's input matrix as a torch parameter. Build it from the
        # real vectors instead of a randomly initialised table of the same shape.
        # freeze=False keeps requires_grad=True, matching a plain nn.Embedding.
        word_vectors = torch.from_numpy(self.ft.get_input_matrix())
        self.embeddings = nn.Embedding.from_pretrained(word_vectors, freeze=False)

    def forward(self, text: str, k=1) -> tuple:
        """Return fastText's top-``k`` predictions for ``text``.

        Args:
            text: Input handed directly to fastText's ``predict``.
            k: Number of labels to return.

        Returns:
            The ``(labels, probabilities)`` pair produced by fastText's ``predict``.
            ``self.embeddings`` is not used here.
        """
        # NOTE(review): fastText's single-string predict path has been reported to
        # fail under NumPy >= 2 (it calls np.array(..., copy=False)). Confirm the
        # pinned NumPy version.
        return self.ft.predict(text, k=k)