# Author: davidkim205
# Commit message: "Upload folder using huggingface_hub"
# Commit: 577164e (verified)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils.simple_bleu import simple_score
import torch
# Hugging Face checkpoint id, defined once so the model and the tokenizer are
# always loaded from the same checkpoint.
MODEL_NAME = "facebook/nllb-200-distilled-1.3B"

# Loaded once at import time and shared by translate_ko2en / translate_en2ko.
# Weights are loaded in bfloat16; device_map="auto" delegates device placement
# to transformers/accelerate.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def _translate(text, src_lang, tgt_lang, max_new_tokens=2048):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the shared NLLB model.

    Args:
        text: Input text as a single string.
        src_lang: NLLB language code of ``text`` (e.g. ``"kor_Hang"``).
        tgt_lang: NLLB language code to translate into (e.g. ``"eng_Latn"``).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded translation, with special tokens stripped.
    """
    # The NLLB tokenizer prefixes every input with a source-language tag taken
    # from ``tokenizer.src_lang`` (default ``eng_Latn``). Set it explicitly for
    # this call, then restore the previous value so the shared module-level
    # tokenizer is left as it was found.
    previous_src_lang = tokenizer.src_lang
    tokenizer.src_lang = src_lang
    try:
        inputs = tokenizer([text], return_tensors="pt", padding=True)
    finally:
        tokenizer.src_lang = previous_src_lang

    translated_tokens = model.generate(
        **inputs.to(model.device),
        # Force the decoder to start with the target-language token.
        # ``convert_tokens_to_ids`` replaces ``tokenizer.lang_code_to_id``,
        # which is deprecated and removed in recent transformers releases.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]


def translate_ko2en(text, *, max_new_tokens=2048):
    """Translate Korean ``text`` into English.

    Args:
        text: Korean input string.
        max_new_tokens: Upper bound on generated tokens. This matches
            translate_en2ko, so long inputs are not cut short by the model's
            default generation length.

    Returns:
        The English translation.
    """
    return _translate(text, "kor_Hang", "eng_Latn", max_new_tokens=max_new_tokens)


def translate_en2ko(text, *, max_new_tokens=2048):
    """Translate English ``text`` into Korean.

    Args:
        text: English input string.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The Korean translation.
    """
    return _translate(text, "eng_Latn", "kor_Hang", max_new_tokens=max_new_tokens)
def main():
    """Run an interactive Korean round-trip translation loop.

    For each line read from stdin, the loop:

    1. Translates the line Korean -> English.
    2. Translates the result English -> Korean.
    3. Prints the intermediate English, the back-translated Korean, and
       ``simple_score`` (from ``utils.simple_bleu``) comparing the original
       input with the back-translation.

    Blank lines are skipped. EOF (Ctrl-D) or Ctrl-C at the prompt ends the
    loop cleanly instead of exiting with a traceback.
    """
    while True:
        try:
            text = input('>')
        except (EOFError, KeyboardInterrupt):
            # Input closed or user interrupted: finish the prompt line and stop.
            print()
            break
        if not text.strip():
            # Nothing to translate; don't spend a model call on empty input.
            continue
        en_text = translate_ko2en(text)
        ko_text = translate_en2ko(en_text)
        print('en_text', en_text)
        print('ko_text', ko_text)
        print('score', simple_score(text, ko_text))
"""
>>? 3천만 개가 넘는 파일과 250억 개의 토큰이 있습니다. Phi1.5의 데이터 세트 구성에 접근하지만 오픈 소스 모델인 Mixtral 8x7B를 사용하고 Apache2.0 라이선스에 따라 라이선스가 부여됩니다.
en_text There are over 30 million files and 250 billion tokens. Phi1.5's data set configuration is accessible but uses the open source model Mixtral 8x7B and is licensed under the Apache 2.0 license.
ko_text 300만 개 이상의 파일과 25억 개의 토큰이 있습니다. Phi1.5의 데이터 세트 구성은 액세스 가능하지만 오픈 소스 모델 Mixtral 8x7B를 사용하고 Apache 2.0 라이선스에 따라 라이선스됩니다.
score 0.3090015909429233
"""
if __name__ == "__main__":
    # Start the interactive stdin loop only when run as a script, so importing
    # this module to reuse the translate_* helpers does not block on input().
    main()