From 3c135215db79fac37ebede465db567395fa5daa5 Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Sun, 22 Sep 2019 00:22:54 +0300 Subject: Clean up WMD_retrieval --- WMD_retrieval.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'WMD_retrieval.py') diff --git a/WMD_retrieval.py b/WMD_retrieval.py index 32f3b5d..fbf1b57 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py @@ -6,7 +6,6 @@ from sklearn.preprocessing import normalize from Wasserstein_Distance import Wasserstein_Retriever from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k import csv -import sys def main(args): @@ -95,13 +94,13 @@ def main(args): clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) clf.fit(X_train_idf[:instances], np.ones(instances)) dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) - mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) - percentage = p_at_1 * 100 + mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) + percentage = p_at_one * 100 if (not batch): - print(f'MRR: {mrr} | Precision @ 1: {p_at_1}') + print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') else: - fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{mrr}', f'{p_at_1}', f'{percentage}'] + fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] with open(f'{metric}_retrieval_result.csv', 'a') as f: writer = csv.writer(f) writer.writerow(fields) -- cgit v1.2.3-70-g09d2