aboutsummaryrefslogtreecommitdiffstats
path: root/WMD_retrieval.py
diff options
context:
space:
mode:
authorYigit Sever2019-09-22 02:02:35 +0300
committerYigit Sever2019-09-22 02:02:35 +0300
commit40d725329a1fb3699dd3b2c870fc8213db9d75eb (patch)
tree826c039fe8b95b87c78490d6809d20e3bb61322f /WMD_retrieval.py
parent2936635892e17031c37facfd2115e8cfd6633222 (diff)
downloadEvaluating-Dictionary-Alignment-40d725329a1fb3699dd3b2c870fc8213db9d75eb.tar.gz
Evaluating-Dictionary-Alignment-40d725329a1fb3699dd3b2c870fc8213db9d75eb.tar.bz2
Evaluating-Dictionary-Alignment-40d725329a1fb3699dd3b2c870fc8213db9d75eb.zip
Unified WMD/SNK matching & retrieval
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--WMD_retrieval.py15
1 files changed, 6 insertions, 9 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index 3328023..02f35be 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -6,7 +6,7 @@ import numpy as np
6from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 6from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
7from sklearn.preprocessing import normalize 7from sklearn.preprocessing import normalize
8 8
9from Wasserstein_Distance import (Wasserstein_Retriever, 9from Wasserstein_Distance import (WassersteinRetriever,
10 clean_corpus_using_embeddings_vocabulary, 10 clean_corpus_using_embeddings_vocabulary,
11 load_embeddings) 11 load_embeddings)
12 12
@@ -101,16 +101,13 @@ def main(args):
101 101
102 for metric in runfor: 102 for metric in runfor:
103 if not batch: 103 if not batch:
104 print(f'{metric} - tfidf: {source_lang} - {target_lang}') 104 print(f'{metric}: {source_lang} - {target_lang}')
105 105
106 clf = Wasserstein_Retriever(W_embed=W_common, 106 clf = WassersteinRetriever(W_embed=W_common,
107 n_neighbors=5, 107 n_neighbors=5,
108 n_jobs=14, 108 n_jobs=14,
109 sinkhorn=(metric == 'snk')) 109 sinkhorn=(metric == 'snk'))
110 clf.fit(X_train_idf[:instances], np.ones(instances)) 110 clf.fit(X_train_idf[:instances], np.ones(instances))
111 # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
112 # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
113 # percentage = p_at_one * 100
114 p_at_one, percentage = clf.align(X_test_idf[:instances], 111 p_at_one, percentage = clf.align(X_test_idf[:instances],
115 n_neighbors=instances) 112 n_neighbors=instances)
116 113