diff options
-rw-r--r-- | WMD_retrieval.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index 32f3b5d..fbf1b57 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -6,7 +6,6 @@ from sklearn.preprocessing import normalize | |||
6 | from Wasserstein_Distance import Wasserstein_Retriever | 6 | from Wasserstein_Distance import Wasserstein_Retriever |
7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k | 7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k |
8 | import csv | 8 | import csv |
9 | import sys | ||
10 | 9 | ||
11 | def main(args): | 10 | def main(args): |
12 | 11 | ||
@@ -95,13 +94,13 @@ def main(args): | |||
95 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
96 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
97 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 96 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
98 | mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) | 97 | mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) |
99 | percentage = p_at_1 * 100 | 98 | percentage = p_at_one * 100 |
100 | 99 | ||
101 | if (not batch): | 100 | if (not batch): |
102 | print(f'MRR: {mrr} | Precision @ 1: {p_at_1}') | 101 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |
103 | else: | 102 | else: |
104 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{mrr}', f'{p_at_1}', f'{percentage}'] | 103 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] |
105 | with open(f'{metric}_retrieval_result.csv', 'a') as f: | 104 | with open(f'{metric}_retrieval_result.csv', 'a') as f: |
106 | writer = csv.writer(f) | 105 | writer = csv.writer(f) |
107 | writer.writerow(fields) | 106 | writer.writerow(fields) |