From 4f399e8898afd937aa5e1f3cf73dd97dc6d130fd Mon Sep 17 00:00:00 2001
From: Yigit Sever
Date: Fri, 20 Sep 2019 22:01:25 +0300
Subject: Run ready WMD_Matcher

---
Reviewer notes (text between the three-dash line and the diff is not part
of the commit message):

What the patch does
  * Moves the argparse entry point from the top of WMD_matching.py to the
    bottom. At its old position main(args) was called before `def main`
    had been executed, which fails with NameError when run as a script.
  * Silences NumPy divide-by-zero warnings at the start of main().
  * Replaces WassersteinDistances with Wasserstein_Matcher (the name the
    module imports from Wass_Matcher) in the WMD and the Sinkhorn run.
  * Sinkhorn run: `percentage` is now computed unconditionally and shown
    in the printed summary; the reporting code appears to be re-indented.

Differences from commit 4f399e8 (added lines only, line counts unchanged)
  * numpy.seterr -> np.seterr. The context lines of main() use `np`
    (np.ones), so `np` is known to be bound; a bare `numpy` is not shown
    to be imported anywhere in this patch.
  * The 'mode' positional gains nargs='?'. argparse only applies `default`
    to a positional that may be omitted; as committed the argument was
    required and default='all' could never take effect.

Layout caveat
  * This copy of the patch had been flattened onto a few long lines.
    Indentation depth and the position of blank context lines below are
    reconstructed and must be checked against the blobs named on the
    `index` line before applying. The `index` line still names the
    original blobs, not the two amended lines.

 WMD_matching.py | 52 ++++++++++++++++++++++++++--------------------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/WMD_matching.py b/WMD_matching.py
index 3b8b1a9..0b81696 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -7,22 +7,6 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
 from Wass_Matcher import Wasserstein_Matcher
 
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance')
-    parser.add_argument('source_lang', help='source language short name')
-    parser.add_argument('target_lang', help='target language short name')
-    parser.add_argument('source_vector', help='path of the source vector')
-    parser.add_argument('target_vector', help='path of the target vector')
-    parser.add_argument('source_defs', help='path of the source definitions')
-    parser.add_argument('target_defs', help='path of the target definitions')
-    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
-    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
-    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
-    args = parser.parse_args()
-
-    main(args)
-
 def load_embeddings(path, dimension=300):
     """
     Loads the embeddings from a word2vec formatted file.
@@ -88,6 +72,7 @@ def mrr_precision_at_k(golden, preds, k_list=[1,]):
 
 def main(args):
 
+    np.seterr(divide='ignore')  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -162,7 +147,7 @@ def main(args):
         if (not batch):
             print(f'WMD - tfidf: {source_lang} - {target_lang}')
 
-        clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14)
+        clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14)
         clf.fit(X_train_idf[:instances], np.ones(instances))
         row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
         result = zip(row_ind, col_ind)
@@ -183,20 +168,35 @@ def main(args):
         if (not batch):
             print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}')
 
-        clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True)
+        clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True)
         clf.fit(X_train_idf[:instances], np.ones(instances))
         row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
         result = zip(row_ind, col_ind)
         hit_one = len([x for x,y in result if x == y])
+        percentage = hit_one / instances * 100
 
-    if (not batch):
-        print(f'{hit_one} definitions have been mapped correctly')
+        if (not batch):
+            print(f'{hit_one} definitions have been mapped correctly, {percentage}%')
+        if (batch):
+            fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
+            with open('sinkhorn_matching_result.csv', 'a') as f:
+                writer = csv.writer(f)
+                writer.writerow(fields)
 
-    if (batch):
-        percentage = hit_one / instances * 100
-        fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
-        with open('sinkhorn_matching_result.csv', 'a') as f:
-            writer = csv.writer(f)
-            writer.writerow(fields)
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance')
+    parser.add_argument('source_lang', help='source language short name')
+    parser.add_argument('target_lang', help='target language short name')
+    parser.add_argument('source_vector', help='path of the source vector')
+    parser.add_argument('target_vector', help='path of the target vector')
+    parser.add_argument('source_defs', help='path of the source definitions')
+    parser.add_argument('target_defs', help='path of the target definitions')
+    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
+    parser.add_argument('mode', nargs='?', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
+    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
+    args = parser.parse_args()
+
+    main(args)
 
 
-- 
cgit v1.2.3-70-g09d2