import subprocess
import sys

# Install the fasttext dependency at import time so the module also works in
# environments that do not have it preinstalled (e.g. when loaded as custom code).
subprocess.check_call([sys.executable, "-m", "pip", "install", "fasttext"])

import fasttext
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from huggingface_hub import hf_hub_download
class FastTextConfig(PretrainedConfig):
    model_type = "fasttext-language-identification"

    def __init__(
        self,
        repo_id: str = "facebook/fasttext-language-identification",
        top_k: int = 1,
        **kwargs,
    ):
        self.repo_id = repo_id
        self.top_k = top_k
        super().__init__(**kwargs)
class FastTextModel(PreTrainedModel):
    config_class = FastTextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = FastText(config.repo_id)

    def forward(self, text: str, k: int = None):
        # Fall back to the top_k stored in the config when no explicit k is passed.
        if k is None:
            k = self.config.top_k
        return self.model(text, k=k)
class FastText(nn.Module):
    def __init__(self, repo_id: str, filename: str = "model.bin", *args, **kwargs) -> None:
        super().__init__()
        self.ft = fasttext.load_model(
            hf_hub_download(repo_id=repo_id, filename=filename)
        )
        # Expose the fastText input matrix (vocabulary size x embedding dim) as
        # an nn.Embedding; from_pretrained copies the trained vectors into the
        # layer instead of leaving it randomly initialized.
        word_vectors = torch.from_numpy(self.ft.get_input_matrix())
        self.embeddings = nn.Embedding.from_pretrained(word_vectors)

    def forward(self, text: str, k: int = 1):
        # predict returns a (labels, probabilities) pair; text may also be a
        # list of strings, in which case per-string lists are returned.
        return self.ft.predict(text, k=k)
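

# Minimal usage sketch, assuming the file is run directly: instantiate the
# wrapper and identify the language of a sentence. The example text and the
# top_k value below are illustrative assumptions, not part of the model repo.
if __name__ == "__main__":
    config = FastTextConfig(top_k=3)
    model = FastTextModel(config)
    # Returns the 3 most likely language labels with their probabilities.
    labels, probs = model("Bonjour tout le monde")
    print(labels, probs)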