diff options
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r-- | WMD_retrieval.py | 15 |
1 file changed, 6 insertions, 9 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index 3328023..02f35be 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -6,7 +6,7 @@ import numpy as np | |||
6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer |
7 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
8 | 8 | ||
9 | from Wasserstein_Distance import (Wasserstein_Retriever, | 9 | from Wasserstein_Distance import (WassersteinRetriever, |
10 | clean_corpus_using_embeddings_vocabulary, | 10 | clean_corpus_using_embeddings_vocabulary, |
11 | load_embeddings) | 11 | load_embeddings) |
12 | 12 | ||
@@ -101,16 +101,13 @@ def main(args): | |||
101 | 101 | ||
102 | for metric in runfor: | 102 | for metric in runfor: |
103 | if not batch: | 103 | if not batch: |
104 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') | 104 | print(f'{metric}: {source_lang} - {target_lang}') |
105 | 105 | ||
106 | clf = Wasserstein_Retriever(W_embed=W_common, | 106 | clf = WassersteinRetriever(W_embed=W_common, |
107 | n_neighbors=5, | 107 | n_neighbors=5, |
108 | n_jobs=14, | 108 | n_jobs=14, |
109 | sinkhorn=(metric == 'snk')) | 109 | sinkhorn=(metric == 'snk')) |
110 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 110 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
111 | # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | ||
112 | # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | ||
113 | # percentage = p_at_one * 100 | ||
114 | p_at_one, percentage = clf.align(X_test_idf[:instances], | 111 | p_at_one, percentage = clf.align(X_test_idf[:instances], |
115 | n_neighbors=instances) | 112 | n_neighbors=instances) |
116 | 113 | ||