aboutsummaryrefslogtreecommitdiffstats
path: root/WMD_retrieval.py
diff options
context:
space:
mode:
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--WMD_retrieval.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index fbf1b57..f32372f 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -93,9 +93,10 @@ def main(args):
93 93
94 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) 94 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
95 clf.fit(X_train_idf[:instances], np.ones(instances)) 95 clf.fit(X_train_idf[:instances], np.ones(instances))
96 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) 96 # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
97 mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) 97 # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
98 percentage = p_at_one * 100 98 # percentage = p_at_one * 100
99 p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
99 100
100 if (not batch): 101 if (not batch):
101 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') 102 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')