From 40d725329a1fb3699dd3b2c870fc8213db9d75eb Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Sun, 22 Sep 2019 02:02:35 +0300 Subject: Unified WMD/SNK matching & retrieval --- WMD_matching.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) (limited to 'WMD_matching.py') diff --git a/WMD_matching.py b/WMD_matching.py index ea496b8..2755d15 100644 --- a/WMD_matching.py +++ b/WMD_matching.py @@ -6,7 +6,7 @@ import numpy as np from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer from sklearn.preprocessing import normalize -from Wasserstein_Distance import (Wasserstein_Matcher, +from Wasserstein_Distance import (WassersteinMatcher, clean_corpus_using_embeddings_vocabulary, load_embeddings) @@ -103,16 +103,13 @@ def main(args): if not batch: print(f'{metric}: {source_lang} - {target_lang}') - clf = Wasserstein_Matcher(W_embed=W_common, - n_neighbors=5, - n_jobs=14, - sinkhorn=(metric == 'snk')) + clf = WassersteinMatcher(W_embed=W_common, + n_neighbors=5, + n_jobs=14, + sinkhorn=(metric == 'snk')) clf.fit(X_train_idf[:instances], np.ones(instances)) - row_ind, col_ind, _ = clf.kneighbors(X_test_idf[:instances], - n_neighbors=instances) - result = zip(row_ind, col_ind) - p_at_one = len([x for x, y in result if x == y]) - percentage = p_at_one / instances * 100 + p_at_one, percentage = clf.align(X_test_idf[:instances], + n_neighbors=instances) if not batch: print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') -- cgit v1.2.3-70-g09d2