diff options
Diffstat (limited to 'Wasserstein_Distance.py')
| -rw-r--r-- | Wasserstein_Distance.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py index 925eca3..5965ed5 100644 --- a/Wasserstein_Distance.py +++ b/Wasserstein_Distance.py | |||
| @@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier): | |||
| 138 | dist = self._pairwise_wmd(X) | 138 | dist = self._pairwise_wmd(X) |
| 139 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) | 139 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) |
| 140 | 140 | ||
| 141 | def align(self, X, n_neighbors=1) | ||
| 142 | """ | ||
| 143 | Wrapper function over kneighbors to return | ||
| 144 | precision at one and percentage values | ||
| 145 | |||
| 146 | """ | ||
| 147 | dist, preds = self.kneighbors(X, n_neighbors) | ||
| 148 | mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | ||
| 149 | percentage = p_at_one * 100 | ||
| 150 | return (p_at_one, percentage) | ||
| 151 | |||
| 141 | 152 | ||
| 142 | def load_embeddings(path, dimension=300): | 153 | def load_embeddings(path, dimension=300): |
| 143 | """ | 154 | """ |
