aboutsummaryrefslogtreecommitdiffstats
path: root/WMD_retrieval.py
diff options
context:
space:
mode:
author Yigit Sever 2019-09-22 00:22:54 +0300
committer Yigit Sever 2019-09-22 00:22:54 +0300
commit 3c135215db79fac37ebede465db567395fa5daa5 (patch)
tree cd5934c183b42f6c1d1ab71148c536c4eeed6455 /WMD_retrieval.py
parent 73df12b6787304e6982e4543d63700beabc88085 (diff)
download Evaluating-Dictionary-Alignment-3c135215db79fac37ebede465db567395fa5daa5.tar.gz
Evaluating-Dictionary-Alignment-3c135215db79fac37ebede465db567395fa5daa5.tar.bz2
Evaluating-Dictionary-Alignment-3c135215db79fac37ebede465db567395fa5daa5.zip
Clean up WMD_retrieval
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r-- WMD_retrieval.py 9
1 files changed, 4 insertions, 5 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index 32f3b5d..fbf1b57 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -6,7 +6,6 @@ from sklearn.preprocessing import normalize
6from Wasserstein_Distance import Wasserstein_Retriever 6from Wasserstein_Distance import Wasserstein_Retriever
7from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k 7from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k
8import csv 8import csv
9import sys
10 9
11def main(args): 10def main(args):
12 11
@@ -95,13 +94,13 @@ def main(args):
95 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) 94 clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
96 clf.fit(X_train_idf[:instances], np.ones(instances)) 95 clf.fit(X_train_idf[:instances], np.ones(instances))
97 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) 96 dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
98 mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) 97 mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
99 percentage = p_at_1 * 100 98 percentage = p_at_one * 100
100 99
101 if (not batch): 100 if (not batch):
102 print(f'MRR: {mrr} | Precision @ 1: {p_at_1}') 101 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
103 else: 102 else:
104 fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{mrr}', f'{p_at_1}', f'{percentage}'] 103 fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}']
105 with open(f'{metric}_retrieval_result.csv', 'a') as f: 104 with open(f'{metric}_retrieval_result.csv', 'a') as f:
106 writer = csv.writer(f) 105 writer = csv.writer(f)
107 writer.writerow(fields) 106 writer.writerow(fields)