|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
from utils.simple_bleu import simple_score |
|
import torch |
|
|
|
# MADLAD-400 10B multilingual translation checkpoint (T5 architecture).
model_name = 'jbochi/madlad400-10b-mt'

# bfloat16 halves memory vs fp32; device_map="auto" shards/places the model
# across available accelerators via accelerate. NOTE(review): assumes a GPU
# with enough memory for a 10B-parameter model — confirm target hardware.
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = T5Tokenizer.from_pretrained(model_name)
|
|
|
|
|
def translate_ko2en(text, max_new_tokens=2048):
    """Translate Korean *text* to English with the MADLAD-400 model.

    Args:
        text: Korean source string.
        max_new_tokens: Cap on generated tokens (default 2048, matching the
            previously hard-coded limit; exposed for shorter/longer outputs).

    Returns:
        The English translation with special tokens stripped.
    """
    # MADLAD-400 selects the target language via a "<2xx>" prefix token.
    prompt = f"<2en> {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
    # outputs[0]: single sequence (no beam/batch expansion requested).
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def translate_en2ko(text, max_new_tokens=2048):
    """Translate English *text* to Korean with the MADLAD-400 model.

    Args:
        text: English source string.
        max_new_tokens: Cap on generated tokens (default 2048, matching the
            previously hard-coded limit; exposed for shorter/longer outputs).

    Returns:
        The Korean translation with special tokens stripped.
    """
    # MADLAD-400 selects the target language via a "<2xx>" prefix token.
    prompt = f"<2ko> {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
    # outputs[0]: single sequence (no beam/batch expansion requested).
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def main():
    """Interactive round-trip demo: Korean -> English -> Korean, BLEU-scored.

    Reads lines from stdin forever; exits cleanly on Ctrl-D or Ctrl-C.
    """
    while True:
        try:
            text = input('>')
        except (EOFError, KeyboardInterrupt):
            # Previously input() raised an unhandled traceback on Ctrl-D /
            # Ctrl-C; exit the loop cleanly instead.
            break
        en_text = translate_ko2en(text)
        ko_text = translate_en2ko(en_text)
        print('en_text', en_text)
        print('ko_text', ko_text)
        # Round-trip similarity of the back-translation vs the original input.
        print('score', simple_score(text, ko_text))
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|