aboutsummaryrefslogtreecommitdiffstats
path: root/Wasserstein_Distance.py
diff options
context:
space:
mode:
authorYigit Sever2019-09-22 00:49:38 +0300
committerYigit Sever2019-09-22 00:49:38 +0300
commit998011d64694da562a33e02092df1a7b57e84b98 (patch)
tree8e6bceaaa3289b091d3d32b8fa5442052ef3291f /Wasserstein_Distance.py
parent22cc3c22d317c98940e5f38715266cf9757880b2 (diff)
downloadEvaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.tar.gz
Evaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.tar.bz2
Evaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.zip
Preparing for merging two WMD scripts
* Added align() to return p_at_one, percentage for both metrics
diffstat (limited to 'Wasserstein_Distance.py')
-rw-r--r--Wasserstein_Distance.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 925eca3..5965ed5 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier):
138 dist = self._pairwise_wmd(X) 138 dist = self._pairwise_wmd(X)
139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) 139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors)
140 140
141 def align(self, X, n_neighbors=1)
142 """
143 Wrapper function over kneighbors to return
144 precision at one and percentage values
145
146 """
147 dist, preds = self.kneighbors(X, n_neighbors)
148 mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
149 percentage = p_at_one * 100
150 return (p_at_one, percentage)
151
141 152
142def load_embeddings(path, dimension=300): 153def load_embeddings(path, dimension=300):
143 """ 154 """