aboutsummaryrefslogtreecommitdiffstats
path: root/WMD_retrieval.py
diff options
context:
space:
mode:
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--WMD_retrieval.py15
1 files changed, 6 insertions, 9 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index 3328023..02f35be 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -6,7 +6,7 @@ import numpy as np
6from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 6from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
7from sklearn.preprocessing import normalize 7from sklearn.preprocessing import normalize
8 8
9from Wasserstein_Distance import (Wasserstein_Retriever, 9from Wasserstein_Distance import (WassersteinRetriever,
10 clean_corpus_using_embeddings_vocabulary, 10 clean_corpus_using_embeddings_vocabulary,
11 load_embeddings) 11 load_embeddings)
12 12
@@ -101,16 +101,13 @@ def main(args):
101 101
102 for metric in runfor: 102 for metric in runfor:
103 if not batch: 103 if not batch:
104 print(f'{metric} - tfidf: {source_lang} - {target_lang}') 104 print(f'{metric}: {source_lang} - {target_lang}')
105 105
106 clf = Wasserstein_Retriever(W_embed=W_common, 106 clf = WassersteinRetriever(W_embed=W_common,
107 n_neighbors=5, 107 n_neighbors=5,
108 n_jobs=14, 108 n_jobs=14,
109 sinkhorn=(metric == 'snk')) 109 sinkhorn=(metric == 'snk'))
110 clf.fit(X_train_idf[:instances], np.ones(instances)) 110 clf.fit(X_train_idf[:instances], np.ones(instances))
111 # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
112 # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
113 # percentage = p_at_one * 100
114 p_at_one, percentage = clf.align(X_test_idf[:instances], 111 p_at_one, percentage = clf.align(X_test_idf[:instances],
115 n_neighbors=instances) 112 n_neighbors=instances)
116 113