|
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList |
|
import torch |
|
from utils.simple_bleu import simple_score |
|
import torch |
|
|
|
# Hugging Face Hub id of the Gugugo Korean<->English translation model used by
# the translate_* helpers in this file.
repo = "squarelike/Gugugo-koen-7B-V1.1"

# Loaded once at import time. Weights are kept in bfloat16 and
# device_map='auto' lets transformers/accelerate decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    repo,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(repo)
|
|
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the checked sequence ends with any stop sequence.

    Only row 0 of ``input_ids`` is inspected. NOTE(review): under beam search
    this is presumably the current top-ranked running beam, so generation ends
    when that beam emits a stop sequence -- confirm against the installed
    transformers version if per-beam stopping is needed.
    """

    def __init__(self, stops=(), encounters=1):
        """Store the stop sequences.

        Args:
            stops: Iterable of 1-D token-id sequences (tensors or lists of
                ints). A 2-D tensor is also accepted; it is iterated row by row.
            encounters: Accepted for backward compatibility and stored on the
                instance, but not used by ``__call__``.
        """
        super().__init__()
        # Copy into a list so the caller's container is never aliased.
        self.stops = list(stops)
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if row 0 of *input_ids* ends with any stop sequence.

        ``scores`` is unused. ``**kwargs`` mirrors the transformers base-class
        signature so extra keyword arguments are tolerated.
        """
        sequence = input_ids[0]
        length = sequence.shape[-1]
        for stop in self.stops:
            width = len(stop)
            # Skip empty stops and stops longer than the sequence: the slice
            # below would have a different length, and the element-wise
            # comparison would either raise or broadcast into a false match.
            if width == 0 or width > length:
                continue
            tail = sequence[-width:]
            # as_tensor accepts plain lists too and places the stop ids on the
            # same device as the generated tokens.
            target = torch.as_tensor(stop, device=tail.device)
            if torch.all(target == tail).item():
                return True
        return False
|
|
|
|
|
# Token-id sequences that end a model turn. Each row shares the trailing ids
# (45107, 29958); they presumably encode the "</끝>" end marker after different
# preceding tokens -- NOTE(review): re-derive from the tokenizer if the model
# changes. Each sequence is listed once: a repeated row only added an extra
# comparison per generation step without changing the result.
stop_words_ids = torch.tensor(
    [
        [829, 45107, 29958],
        [1533, 45107, 29958],
        [21106, 45107, 29958],
    ]
).to("cuda")

# Shared by every generate() call in this module.
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
|
|
|
|
# End-of-turn marker of the prompt format. The model emits it after its
# translation; the token ids in ``stop_words_ids`` presumably encode it.
_END_MARKER = "</끝>"

# Speaker headers of the prompt format ("Korean: " / "English: ").
_HEADERS = ("### 한국어: ", "### 영어: ")


def _extract_translation(decoded):
    """Return the translation contained in decoded continuation text.

    Keeps only the text before the first end marker (the model may start a new
    turn after it), drops any literal ``</s>`` and stray speaker headers, and
    strips surrounding whitespace.
    """
    text = decoded.partition(_END_MARKER)[0].replace("</s>", "")
    for header in _HEADERS:
        text = text.replace(header, "")
    return text.strip()


def generate(prompt, *, max_new_tokens=2048, num_beams=5):
    """Complete *prompt* with the model and return only the new translation.

    Args:
        prompt: Full prompt string ending with an open target-language turn.
        max_new_tokens: Upper bound on generated tokens (default 2048).
        num_beams: Beam width for beam search (default 5).

    Returns:
        The generated translation, without the prompt, end marker, special
        tokens or speaker headers.
    """
    inputs = tokenizer(
        prompt,
        return_tensors='pt',
        return_token_type_ids=False,
    ).to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        # temperature is intentionally not passed: without do_sample=True it
        # has no effect on beam search and transformers only warns about it.
        stopping_criteria=stopping_criteria,
    )
    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated after them. Removing the prompt by string replacement
    # breaks whenever decode() does not reproduce the prompt verbatim, and the
    # source text then leaks into the result.
    prompt_length = inputs["input_ids"].shape[-1]
    decoded = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    return _extract_translation(decoded)
|
|
|
|
|
def translate_ko2en(text):
    """Translate Korean *text* to English.

    Builds the model's prompt: a Korean turn ("### 한국어: ...") closed by the
    "</끝>" end marker, followed by an open English turn ("### 영어:") that the
    model completes. The headers and marker must be exact UTF-8 Korean;
    mis-encoded variants do not match the format the model expects.
    """
    prompt = f"### 한국어: {text}</끝>\n### 영어:"
    return generate(prompt)
|
|
|
|
|
def translate_en2ko(text):
    """Translate English *text* to Korean.

    Builds the model's prompt: an English turn ("### 영어: ...") closed by the
    "</끝>" end marker, followed by an open Korean turn ("### 한국어:") that the
    model completes. The headers and marker must be exact UTF-8 Korean;
    mis-encoded variants do not match the format the model expects.
    """
    prompt = f"### 영어: {text}</끝>\n### 한국어:"
    return generate(prompt)
|
|
|
|
|
def main():
    """Interactive round-trip check: Korean -> English -> Korean.

    Reads a line from stdin (prompt ``>``), translates it to English,
    translates that English back to Korean, and prints both translations plus
    ``simple_score(original, round_trip)``. That score is presumably a
    BLEU-style similarity from ``utils.simple_bleu`` (verify there). Blank
    lines are skipped. End-of-file (Ctrl-D) or Ctrl-C at the prompt exits the
    loop.

    Example from an earlier run. The input was a Korean paragraph saying there
    are over 30 million files and 25 billion (250억) tokens, and that it
    approaches Phi1.5's dataset composition but uses the open-source Mixtral
    8x7B and is licensed under Apache2.0:

        en_text We have 30 million files and 2.5 billion tokens. We approach
                Phi1.5's dataset composition, but we use the open-source model,
                Mixtral 8x7B, and we are licensed according to the Apache2.0
                license.
        score 0.6154733407407874

    Note that the English says "2.5 billion" for 250억 (25 billion), yet the
    Korean round trip came back with 250억. The round-trip score therefore
    does not catch that mistranslation.
    """
    while True:
        try:
            text = input('>')
        except (EOFError, KeyboardInterrupt):
            # Leave the loop cleanly instead of exiting with a traceback.
            print()
            break
        if not text.strip():
            continue

        en_text = translate_ko2en(text)
        ko_text = translate_en2ko(en_text)

        print('en_text', en_text)
        print('ko_text', ko_text)
        print('score', simple_score(text, ko_text))
|
|
|
if __name__ == '__main__':
    # Start the interactive loop only when run as a script, not on import.
    main()
|
|