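# Calculates how well different vocabularies cover the benchmark datasets:
# "Own" (built from the dataset alone), "Last" (the pretrained vocabulary
# as-is) and "Full" (the pretrained vocabulary extended with dataset words).
# Writes one row per dataset to a tab-separated CSV.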
from __future__ import print_function
import pickle
import json
import csv
import sys
from io import open

# Allow us to import the torchmoji directory
from os.path import dirname, abspath
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage
try:
    unicode        # Python 2
except NameError:
    unicode = str  # Python 3

IS_PYTHON2 = sys.version_info[0] == 2
OUTPUT_PATH = 'coverage.csv'
DATASET_PATHS = [
    '../data/Olympic/raw.pickle',
    '../data/PsychExp/raw.pickle',
    '../data/SCv1/raw.pickle',
    '../data/SCv2-GEN/raw.pickle',
    '../data/SE0714/raw.pickle',
    #'../data/SE1604/raw.pickle', # Excluded due to Twitter's ToS
    '../data/SS-Twitter/raw.pickle',
    '../data/SS-Youtube/raw.pickle',
]
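
# Each raw.pickle is expected to hold a dict with the dataset texts and
# metadata plus train/validation/test index lists, as unpacked below.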
# Token-to-index vocabulary used by the pretrained model
with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)
results = []
for p in DATASET_PATHS:
    coverage_result = [p]
    print('Calculating coverage for {}'.format(p))
    with open(p, 'rb') as f:
        if IS_PYTHON2:
            s = pickle.load(f)
        else:
            s = pickle.load(f, fix_imports=True)

    # Decode the texts to unicode, falling back to an explicit utf-8
    # decode for datasets pickled as bytes
    try:
        s['texts'] = [unicode(x) for x in s['texts']]
    except UnicodeDecodeError:
        s['texts'] = [x.decode('utf-8') for x in s['texts']]
    # "Own": no pretrained vocabulary; build one from up to 10000 words
    # taken from the dataset itself
    st = SentenceTokenizer({}, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=10000)
    coverage_result.append(coverage(tests[2]))
    # "Last": the pretrained vocabulary as-is, with no dataset-specific words
    st = SentenceTokenizer(vocab, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=0)
    coverage_result.append(coverage(tests[2]))
    # "Full": the pretrained vocabulary extended with up to 10000 words
    # from the dataset
    st = SentenceTokenizer(vocab, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=10000)
    coverage_result.append(coverage(tests[2]))

    results.append(coverage_result)
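
# Write one row per dataset: name followed by the three coverage scores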
# csv expects a binary stream on Python 2 but a text stream on Python 3
if IS_PYTHON2:
    csvfile = open(OUTPUT_PATH, 'wb')
else:
    csvfile = open(OUTPUT_PATH, 'w', newline='')
with csvfile:
    writer = csv.writer(csvfile, delimiter='\t', lineterminator='\n')
    writer.writerow(['Dataset', 'Own', 'Last', 'Full'])
    for i, row in enumerate(results):
        try:
            writer.writerow(row)
        except Exception:
            print("Exception at row {}!".format(i))

print('Saved to {}'.format(OUTPUT_PATH))