ko-translation-leaderbaord / translation_self.py
davidkim205's picture
Upload folder using huggingface_hub
577164e verified
import argparse
import json
import os
from tqdm import tqdm
from utils.simple_bleu import simple_score
def load_json(filename):
    """Load a JSON or JSON Lines file.

    Args:
        filename: Path of the file to read. A ``.jsonl`` extension selects
            line-by-line parsing; any other extension is parsed as a single
            JSON document.

    Returns:
        For ``.jsonl`` files, a list with one parsed object per non-blank
        line. Otherwise, whatever ``json.load`` returns for the document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        if os.path.splitext(filename)[1] != ".jsonl":
            return json.load(f)
        # Skip blank lines (e.g. a trailing newline or hand-edited gaps);
        # json.loads would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]
def save_json(json_data, filename, option="a"):
    """Write data to a JSON or JSON Lines file, creating parent directories.

    Spaces anywhere in ``filename`` (directory part included) are replaced
    with underscores before the path is used.

    Args:
        json_data: For a ``.jsonl`` target, an iterable of JSON-serialisable
            objects written one per line. Otherwise a single
            JSON-serialisable object written with 4-space indentation.
        filename: Destination path.
        option: Mode passed to ``open``; defaults to ``"a"`` (append).
            NOTE: appending to a non-``.jsonl`` file concatenates JSON
            documents, which is not a valid single JSON file.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    # Normalise the path first so the directory we create is the same one
    # the file is opened in.
    filename = filename.replace(" ", "_")
    directory = os.path.dirname(filename)
    # A bare filename has no directory part; os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, option, encoding="utf-8") as f:
        if not filename.endswith(".jsonl"):
            json.dump(json_data, f, ensure_ascii=False, indent=4)
        else:
            for data in json_data:
                f.write(json.dumps(data, ensure_ascii=False) + "\n")
# Instruction text that marks an English sentence to be translated into
# Korean. NOTE: these two literals must be real UTF-8 Korean; a mojibake
# copy (UTF-8 bytes decoded with a Greek single-byte codec) can never match
# the prompts and makes every record look like a Korean source.
_EN2KO_MARKER = "한글로 번역하세요."
# Tail of the instruction prefix; the text to translate follows it.
_PROMPT_TAIL = "번역하세요.\n"

_IRIS_MODEL = "davidkim205/iris-7b"

# Supported --model values mapped to the project module that provides
# translate_en2ko / translate_ko2en for that model.
_MODEL_MODULES = {
    "squarelike/Gugugo-koen-7B-V1.1": "models.gugugo",
    "jbochi/madlad400-10b-mt": "models.madlad400",
    "facebook/mbart-large-50-many-to-many-mmt": "models.mbart50",
    "facebook/nllb-200-distilled-1.3B": "models.nllb200",
    "Unbabel/TowerInstruct-7B-v0.1": "models.TowerInstruct",
    "maywell/Synatra-7B-v0.3-Translation": "models.synatra",
    _IRIS_MODEL: "models.iris_7b",
}


def _load_translators(model, model_path=None):
    """Import the backend for *model* and return (translate_en2ko, translate_ko2en)."""
    import importlib

    module = importlib.import_module(_MODEL_MODULES[model])
    # Only the iris backend exposes load_model for a custom checkpoint path.
    # NOTE(review): assumes --model_path applies to iris-7b only, since
    # load_model is imported from models.iris_7b -- confirm.
    if model == _IRIS_MODEL and model_path:
        module.load_model(model_path)
    return module.translate_en2ko, module.translate_ko2en


def _split_prompt(prompt):
    """Return (text_to_translate, source_lang) extracted from a prompt string."""
    cur_lang = "en" if _EN2KO_MARKER in prompt else "ko"
    # If the tail is absent, split() yields the whole prompt unchanged.
    return prompt.split(_PROMPT_TAIL, 1)[-1], cur_lang


def _translate(text, cur_lang, translate_en2ko, translate_ko2en):
    """Translate *text* out of *cur_lang*; return "" (and warn) on failure."""
    import sys

    translator = translate_en2ko if cur_lang == "en" else translate_ko2en
    try:
        return translator(text)
    except Exception as e:
        # Best-effort: one failing sample must not abort the whole run, but
        # the failure is reported instead of being silently dropped.
        print(f"translation failed ({cur_lang}): {e!r}", file=sys.stderr)
        return ""


def _result_path(args):
    """Return the output file path chosen by --output, --model and --model_path."""
    if args.output:
        return args.output
    filename = args.model.split("/")[-1]
    if args.model_path:
        model_num = args.model_path.split("/")[-1]
        return f"results_self/{filename}-{model_num}-result.jsonl"
    return f"results_self/{filename}-result.jsonl"


def main():
    """Run a round-trip translation benchmark and append results to a file.

    Each input record's first conversation turn is translated into the other
    language and then back again; the back-translation is scored against the
    original text with ``simple_score``. One result record per input is
    printed and appended to the output file as it is produced.
    """
    parser = argparse.ArgumentParser("argument")
    parser.add_argument(
        "--input_file",
        default="./data/komt-1810k-test.jsonl",
        type=str,
        help="input_file",
    )
    parser.add_argument(
        "--model_path",
        default=None,
        type=str,
        help="model path",
    )
    parser.add_argument("--output", default="", type=str, help="output file")
    parser.add_argument(
        "--model", default="davidkim205/iris-7b", type=str, help="model"
    )
    args = parser.parse_args()

    # Fail fast: an unknown model would otherwise leave the translate
    # functions undefined and produce only empty translations.
    if args.model not in _MODEL_MODULES:
        parser.error(
            f"unsupported model {args.model!r}; "
            f"choose one of: {', '.join(_MODEL_MODULES)}"
        )

    json_data = load_json(args.input_file)
    translate_en2ko, translate_ko2en = _load_translators(
        args.model, args.model_path
    )
    # The destination depends only on the CLI arguments, so compute it once.
    output = _result_path(args)

    for index, data in tqdm(enumerate(json_data), total=len(json_data)):
        chat = data["conversations"]
        src = data["src"]
        reference, cur_lang = _split_prompt(chat[0]["value"])

        # Forward translation, then back into the original language.
        generation1 = _translate(
            reference, cur_lang, translate_en2ko, translate_ko2en
        )
        next_lang = "ko" if cur_lang == "en" else "en"
        generation2 = _translate(
            generation1, next_lang, translate_en2ko, translate_ko2en
        )

        bleu = round(simple_score(reference, generation2), 3)
        result = {
            "index": index,
            "reference": reference,
            "generation": generation2,
            "generation1": generation1,
            "bleu": bleu,
            "lang": cur_lang,
            "model": args.model,
            "src": src,
            "conversations": chat,
        }
        print(json.dumps(result, ensure_ascii=False, indent=2))
        # Written per record so partial results survive an interrupted run.
        save_json([result], output)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()