# Spaces: Build error  (hosting-page status banner captured with the source; not part of the script)
import json | |
from tqdm import tqdm | |
import re | |
import fire | |
def tokenize_caption(input_json: str,
                     keep_punctuation: bool = False,
                     host_address: str = None,
                     character_level: bool = False,
                     zh: bool = True,
                     output_json: str = None) -> None:
    """Tokenize every caption in a caption JSON file and store the result.

    Each caption entry gains a ``"tokens"`` field. The updated data is written
    to ``output_json`` or, when that is not given, back to ``input_json``.

    Args:
        input_json (str): Preprocessed json file. Structure like this:
            {
                'audios': [
                    {
                        'audio_id': 'xxx',
                        'captions': [
                            {
                                'caption': 'xxx',
                                'cap_id': 'xxx'
                            }
                        ]
                    },
                    ...
                ]
            }
        keep_punctuation (bool): Only consulted when ``zh`` is true. When
            false, characters listed in ``zhon.hanzi.punctuation`` are removed
            from the caption before tokenizing.
        host_address (str): URL passed to NLTK's ``CoreNLPParser``. Only used
            when ``zh`` is true and ``character_level`` is false. When empty,
            the parser's own default URL is used.
        character_level (bool): Only consulted when ``zh`` is true. When true,
            each character of the caption becomes one token and no parser is
            contacted.
        zh (bool): When true, use the CoreNLP / character pipeline and store
            tokens as a space-joined string. When false, tokenize all captions
            with ``pycocoevalcap``'s ``PTBTokenizer`` and store what it returns.
        output_json (str): Destination file. ``input_json`` is overwritten
            when this is empty.

    Returns:
        None. The result is written to disk.
    """
    # JSON text is UTF-8; be explicit so the locale cannot change the result.
    with open(input_json, "r", encoding="utf-8") as reader:
        data = json.load(reader)["audios"]

    if zh:
        # Compile the punctuation filter once instead of once per caption.
        punct_pattern = None
        if not keep_punctuation:
            from zhon.hanzi import punctuation
            punct_pattern = re.compile("[{}]".format(punctuation))

        # The CoreNLP client is only needed for word-level segmentation.
        parser = None
        if not character_level:
            from nltk.parse.corenlp import CoreNLPParser
            # Passing None would override the parser's default server URL,
            # so only forward host_address when one was actually given.
            if host_address:
                parser = CoreNLPParser(host_address)
            else:
                parser = CoreNLPParser()

        for audio_item in tqdm(data, leave=False, ascii=True):
            for cap_item in audio_item["captions"]:
                caption = cap_item["caption"]
                if punct_pattern is not None:
                    caption = punct_pattern.sub("", caption)
                if character_level:
                    tokens = list(caption)
                else:
                    tokens = list(parser.tokenize(caption))
                cap_item["tokens"] = " ".join(tokens)
    else:
        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

        # PTBTokenizer works on a mapping id -> list of caption records, so
        # gather every caption first and tokenize them in a single call.
        captions = {}
        for audio_item in data:
            audio_id = audio_item["audio_id"]
            captions[audio_id] = [
                {
                    "audio_id": audio_id,
                    "id": cap_idx,
                    "caption": cap_item["caption"],
                }
                for cap_idx, cap_item in enumerate(audio_item["captions"])
            ]
        tokenized = PTBTokenizer().tokenize(captions)

        for audio_item in tqdm(data, leave=False, ascii=True):
            audio_id = audio_item["audio_id"]
            for cap_idx, cap_item in enumerate(audio_item["captions"]):
                # Looked up per caption so audios without captions never
                # touch the tokenizer output.
                cap_item["tokens"] = tokenized[audio_id][cap_idx]

    destination = output_json if output_json else input_json
    # ensure_ascii is disabled for Chinese so characters are written verbatim
    # rather than as \uXXXX escapes; hence the explicit UTF-8 encoding.
    with open(destination, "w", encoding="utf-8") as writer:
        json.dump({"audios": data}, writer, indent=4, ensure_ascii=not zh)
if __name__ == "__main__":
    # Script entry point: hand tokenize_caption to python-fire, which is
    # given the function object and drives it from the command line.
    fire.Fire(tokenize_caption)