From 998011d64694da562a33e02092df1a7b57e84b98 Mon Sep 17 00:00:00 2001
From: Yigit Sever
Date: Sun, 22 Sep 2019 00:49:38 +0300
Subject: Preparing for merging two WMD scripts

* Added align() to return p_at_one, percentage for both metrics
---
 WMD_retrieval.py        |  7 ++++---
 Wasserstein_Distance.py | 11 +++++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index fbf1b57..f32372f 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -93,9 +93,10 @@ def main(args):
 
         clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
         clf.fit(X_train_idf[:instances], np.ones(instances))
-        dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
-        mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
-        percentage = p_at_one * 100
+        # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
+        # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
+        # percentage = p_at_one * 100
+        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
 
         if (not batch):
             print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 925eca3..5965ed5 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier):
         dist = self._pairwise_wmd(X)
         return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors)
 
+    def align(self, X, n_neighbors=1):
+        """
+        Wrapper function over kneighbors to return
+        precision at one and percentage values
+
+        """
+        dist, preds = self.kneighbors(X, n_neighbors)
+        mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
+        percentage = p_at_one * 100
+        return (p_at_one, percentage)
+
 
 def load_embeddings(path, dimension=300):
     """
-- 
cgit v1.2.3-70-g09d2