diff options
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r-- | WMD_retrieval.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index fbf1b57..f32372f 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -93,9 +93,10 @@ def main(args): | |||
93 | 93 | ||
94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
95 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
96 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 96 | # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
97 | mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | 97 | # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) |
98 | percentage = p_at_one * 100 | 98 | # percentage = p_at_one * 100 |
99 | p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances) | ||
99 | 100 | ||
100 | if (not batch): | 101 | if (not batch): |
101 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') | 102 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |