diff options
| author | Yigit Sever | 2019-09-22 00:49:38 +0300 |
|---|---|---|
| committer | Yigit Sever | 2019-09-22 00:49:38 +0300 |
| commit | 998011d64694da562a33e02092df1a7b57e84b98 (patch) | |
| tree | 8e6bceaaa3289b091d3d32b8fa5442052ef3291f | |
| parent | 22cc3c22d317c98940e5f38715266cf9757880b2 (diff) | |
| download | Evaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.tar.gz Evaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.tar.bz2 Evaluating-Dictionary-Alignment-998011d64694da562a33e02092df1a7b57e84b98.zip | |
Preparing for merging two WMD scripts
* Added align() to return p_at_one, percentage for
both metrics
| -rw-r--r-- | WMD_retrieval.py | 7 | ||||
| -rw-r--r-- | Wasserstein_Distance.py | 11 |
2 files changed, 15 insertions, 3 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index fbf1b57..f32372f 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
| @@ -93,9 +93,10 @@ def main(args): | |||
| 93 | 93 | ||
| 94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
| 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
| 96 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 96 | # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
| 97 | mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | 97 | # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) |
| 98 | percentage = p_at_one * 100 | 98 | # percentage = p_at_one * 100 |
| 99 | p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances) | ||
| 99 | 100 | ||
| 100 | if (not batch): | 101 | if (not batch): |
| 101 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') | 102 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py index 925eca3..5965ed5 100644 --- a/Wasserstein_Distance.py +++ b/Wasserstein_Distance.py | |||
| @@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier): | |||
| 138 | dist = self._pairwise_wmd(X) | 138 | dist = self._pairwise_wmd(X) |
| 139 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) | 139 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) |
| 140 | 140 | ||
| 141 | def align(self, X, n_neighbors=1): | ||
| 142 | """ | ||
| 143 | Wrapper function over kneighbors to return | ||
| 144 | precision at one and percentage values | ||
| 145 | |||
| 146 | """ | ||
| 147 | dist, preds = self.kneighbors(X, n_neighbors) | ||
| 148 | mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | ||
| 149 | percentage = p_at_one * 100 | ||
| 150 | return (p_at_one, percentage) | ||
| 151 | |||
| 141 | 152 | ||
| 142 | def load_embeddings(path, dimension=300): | 153 | def load_embeddings(path, dimension=300): |
| 143 | """ | 154 | """ |
