aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- WMD_retrieval.py 7
-rw-r--r-- Wasserstein_Distance.py 11
2 files changed, 15 insertions, 3 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index fbf1b57..f32372f 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -93,9 +93,10 @@ def main(args):
93 93
94 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) 94 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
95 clf.fit(X_train_idf[:instances], np.ones(instances)) 95 clf.fit(X_train_idf[:instances], np.ones(instances))
96 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) 96 # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
97 mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) 97 # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
98 percentage = p_at_one * 100 98 # percentage = p_at_one * 100
99 p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
99 100
100 if (not batch): 101 if (not batch):
101 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') 102 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 925eca3..5965ed5 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -138,6 +138,17 @@ class Wasserstein_Retriever(KNeighborsClassifier):
138 dist = self._pairwise_wmd(X) 138 dist = self._pairwise_wmd(X)
139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) 139 return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors)
140 140
def align(self, X, n_neighbors=1):
    """
    Wrapper function over kneighbors to return
    precision at one and percentage values.

    Parameters
    ----------
    X :
        Query documents, forwarded unchanged to ``self.kneighbors``.
        (Presumably in the same representation that was passed to
        ``fit`` -- confirm against the caller.)
    n_neighbors : int, default 1
        Number of neighbours requested from ``self.kneighbors``.

    Returns
    -------
    tuple
        ``(p_at_one, percentage)`` where ``percentage`` is
        ``p_at_one * 100``.
    """
    # Only the neighbour indices are scored; the distances are discarded.
    _, preds = self.kneighbors(X, n_neighbors)
    # The i-th query row is scored against target index i.
    targets = list(range(len(preds)))
    # NOTE(review): mrr_precision_at_k must be importable from this module;
    # its definition is not visible here -- confirm. It returns a pair whose
    # first element (bound to ``mrr`` previously) is not used by this method.
    _, p_at_one = mrr_precision_at_k(targets, preds)
    return (p_at_one, p_at_one * 100)
151
141 152
142def load_embeddings(path, dimension=300): 153def load_embeddings(path, dimension=300):
143 """ 154 """