#!/usr/bin/env python3
import argparse
import json
import os

import transformers

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # Fix vocab.json: every token mapped to id 3 other than "[UNK]" itself
    # is a collision, and gets moved to a unique, unused id.
    def fix_vocab(vocab):
        mask_id = 51960
        unused = mask_id + 1
        remapped = []
        fixed_vocab = {}
        for key, value in vocab.items():
            if value == 3 and key != "[UNK]":
                if key == "ĠĊ":
                    # Assign "ĠĊ" the id just below the mask token.
                    fixed_vocab[key] = mask_id - 1
                else:
                    # Queue the remaining collisions for consecutive ids
                    # past the mask token.
                    remapped.append((key, unused))
                    unused += 1
            else:
                fixed_vocab[key] = value
        for key, value in remapped:
            fixed_vocab[key] = value
        return fixed_vocab

    with open(os.path.join(args.input_path, "vocab.json"), "r", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)
    fixed_vocab = fix_vocab(vocab)
    with open(os.path.join(args.output_path, "vocab.json"), "w", encoding="utf-8") as vocab_file:
        json.dump(fixed_vocab, vocab_file, ensure_ascii=False, indent=None)
        print(file=vocab_file)  # trailing newline

    # Regenerate tokenizer.json from the fixed vocab.
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.output_path)
    tokenizer._tokenizer.save(os.path.join(args.output_path, "tokenizer.json"))
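
# Example invocation (the script name and paths are hypothetical; since the
# script loads the tokenizer back via AutoTokenizer.from_pretrained on the
# output directory, that directory presumably needs the remaining tokenizer
# files, e.g. merges.txt and tokenizer_config.json, copied there beforehand):
#
#   ./fix_vocab.py ./broken-tokenizer ./fixed-tokenizer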