"""Chatbots using API-based services.""" | |
from __future__ import annotations

import os
import re
from dataclasses import dataclass

import config
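
# NOTE: `config` is a local module (not shown here). From its usage below,
# `config.model_configs` is assumed to map a model preset name to an object
# with a `path` attribute (a subdirectory under "evaluation/system-outputs")
# and an optional `post_processors` list of callables applied to each output
# line. The exact structure of that module is an assumption.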


@dataclass
class GptMtInstance:
    """An instance from the GPT-MT dataset.

    Attributes:
        data: The input sentence.
        label: The output sentence.
        doc_id: The document ID.
        lang_pair: The language pair.
    """

    data: str
    label: str
    doc_id: str
    lang_pair: str


def process_data(
    input_dir: str,
    lang_pairs: list[str],
) -> list[GptMtInstance]:
    """Load the test data for the given language pairs.

    Args:
        input_dir: Directory containing the GPT-MT `evaluation` data.
        lang_pairs: Language pairs such as "deen", written as the concatenation
            of the two-letter source and target language codes.

    Returns:
        One GptMtInstance per test sentence, in file order.
    """
    # Load the data
    data: list[GptMtInstance] = []
    eval_dir = os.path.join(input_dir, "evaluation", "testset")
    for lang_pair in lang_pairs:
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        src_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{src_lang}"
        )
        trg_file = os.path.join(
            eval_dir, "wmt-testset", lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        doc_file = os.path.join(
            eval_dir,
            "wmt-testset-docids",
            lang_pair,
            f"test.{src_lang}-{trg_lang}.docids",
        )
        with open(src_file, "r") as src_in, open(trg_file, "r") as trg_in, open(
            doc_file, "r"
        ) as doc_in:
            for src_line, trg_line, doc_line in zip(src_in, trg_in, doc_in):
                data.append(
                    GptMtInstance(
                        src_line.strip(), trg_line.strip(), doc_line.strip(), lang_pair
                    )
                )
    return data
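
# Directory layout assumed by process_data, inferred from the path construction
# above (shown for a hypothetical lang_pair "deen", i.e. German-English):
#
#   <input_dir>/evaluation/testset/wmt-testset/deen/test.de-en.de
#   <input_dir>/evaluation/testset/wmt-testset/deen/test.de-en.en
#   <input_dir>/evaluation/testset/wmt-testset-docids/deen/test.de-en.docids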


def remove_leading_language(line: str) -> str:
    """Remove a language name at the beginning of the string.

    Some zero-shot models output the name of the language at the beginning of the
    string. This is a manual post-processing function that removes the language name
    (partly as an example of how you can do simple fixes to issues that come up during
    analysis using Zeno).

    Args:
        line: The line to process.

    Returns:
        The line with the language name removed.
    """
    return re.sub(
        r"^(English|Japanese|Chinese|Hausa|Icelandic|French|German|Russian|Ukrainian): ",
        "",
        line,
    )
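
# For example (a hypothetical model output, not taken from the dataset):
#   remove_leading_language("German: Guten Morgen.")  ->  "Guten Morgen."
# Lines that do not start with one of the listed language names followed by
# ": " are returned unchanged.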


def process_output(
    input_dir: str,
    lang_pairs: list[str],
    model_preset: str,
) -> list[str]:
    """Load model outputs for the given language pairs.

    Args:
        input_dir: Directory containing the GPT-MT `evaluation` data.
        lang_pairs: Language pairs such as "deen".
        model_preset: Key into config.model_configs selecting the system whose
            outputs should be loaded.

    Returns:
        The post-processed output lines, in file order across the language pairs.
    """
    # Load the data
    data: list[str] = []
    model_config = config.model_configs[model_preset]
    model_path = model_config.path
    system_dir = os.path.join(input_dir, "evaluation", "system-outputs", model_path)
    for lang_pair in lang_pairs:
        src_lang, trg_lang = lang_pair[:2], lang_pair[2:]
        sys_file = os.path.join(
            system_dir, lang_pair, f"test.{src_lang}-{trg_lang}.{trg_lang}"
        )
        with open(sys_file, "r") as sys_in:
            for sys_line in sys_in:
                sys_line = sys_line.strip()
                if model_config.post_processors is not None:
                    for postprocessor in model_config.post_processors:
                        sys_line = postprocessor(sys_line)
                data.append(sys_line)
    return data
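

# Minimal usage sketch. The repository root "gpt-mt", the language pair "deen",
# and the preset name "text-davinci-003" are placeholders/assumptions; adjust
# them to match your local checkout and the keys in config.model_configs.
if __name__ == "__main__":
    instances = process_data("gpt-mt", ["deen"])
    outputs = process_output("gpt-mt", ["deen"], "text-davinci-003")
    assert len(instances) == len(outputs), "data/output length mismatch"
    print(f"Loaded {len(instances)} instances and {len(outputs)} system outputs.")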