From 22cc3c22d317c98940e5f38715266cf9757880b2 Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Sun, 22 Sep 2019 00:26:08 +0300 Subject: Roll the methods to a for, like retrieval --- WMD_matching.py | 48 +++++++++++++++++------------------------------- 1 file changed, 17 insertions(+), 31 deletions(-) (limited to 'WMD_matching.py') diff --git a/WMD_matching.py b/WMD_matching.py index 8a97389..59b64f9 100644 --- a/WMD_matching.py +++ b/WMD_matching.py @@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer from sklearn.preprocessing import normalize from Wasserstein_Distance import Wasserstein_Matcher from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary +import csv def main(args): @@ -22,6 +23,12 @@ def main(args): batch = args.batch mode = args.mode + runfor = list() + + if (mode == 'all'): + runfor.extend(['wmd','snk']) + else: + runfor.append(mode) defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] @@ -80,47 +87,25 @@ def main(args): X_train_tf = vect_tf.transform(clean_src_corpus) X_test_tf = vect_tf.transform(clean_target_corpus) - if (mode == 'wmd' or mode == 'all'): + for metric in runfor: if (not batch): - print(f'WMD - tfidf: {source_lang} - {target_lang}') + print(f'{metric}: {source_lang} - {target_lang}') - clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) + clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) clf.fit(X_train_idf[:instances], np.ones(instances)) row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) result = zip(row_ind, col_ind) - hit_one = len([x for x,y in result if x == y]) - percentage = hit_one / instances * 100 + p_at_one = len([x for x,y in result if x == y]) + percentage = p_at_one / instances * 100 if (not batch): - print(f'{hit_one} definitions have been mapped correctly, {percentage}%') - - if (batch): - import csv - fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] - with open('wmd_matching_results.csv', 'a') as f: + print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') + else: + fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] + with open(f'{metric}_matching_results.csv', 'a') as f: writer = csv.writer(f) writer.writerow(fields) - if (mode == 'snk' or mode == 'all'): - if (not batch): - print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') - - clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) - clf.fit(X_train_idf[:instances], np.ones(instances)) - row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) - - result = zip(row_ind, col_ind) - hit_one = len([x for x,y in result if x == y]) - percentage = hit_one / instances * 100 - - if (not batch): - print(f'{hit_one} definitions have been mapped correctly, {percentage}%') - - if (batch): - fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] - with open('sinkhorn_matching_result.csv', 'a') as f: - writer = csv.writer(f) - writer.writerow(fields) if __name__ == "__main__": @@ -134,6 +119,7 @@ if __name__ == "__main__": parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) + args = parser.parse_args() main(args) -- cgit v1.2.3-70-g09d2