aboutsummaryrefslogtreecommitdiffstats
path: root/Wasserstein_Distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'Wasserstein_Distance.py')
-rw-r--r--Wasserstein_Distance.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 925eca3..5965ed5 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier):
138 dist = self._pairwise_wmd(X) 138 dist = self._pairwise_wmd(X)
139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) 139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors)
140 140
141 def align(self, X, n_neighbors=1)
142 """
143 Wrapper function over kneighbors to return
144 precision at one and percentage values
145
146 """
147 dist, preds = self.kneighbors(X, n_neighbors)
148 mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
149 percentage = p_at_one * 100
150 return (p_at_one, percentage)
151
141 152
142def load_embeddings(path, dimension=300): 153def load_embeddings(path, dimension=300):
143 """ 154 """