diff options
Diffstat (limited to 'WMD_matching.py')
-rw-r--r-- | WMD_matching.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index ea496b8..2755d15 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
@@ -6,7 +6,7 @@ import numpy as np | |||
6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer |
7 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
8 | 8 | ||
9 | from Wasserstein_Distance import (Wasserstein_Matcher, | 9 | from Wasserstein_Distance import (WassersteinMatcher, |
10 | clean_corpus_using_embeddings_vocabulary, | 10 | clean_corpus_using_embeddings_vocabulary, |
11 | load_embeddings) | 11 | load_embeddings) |
12 | 12 | ||
@@ -103,16 +103,13 @@ def main(args): | |||
103 | if not batch: | 103 | if not batch: |
104 | print(f'{metric}: {source_lang} - {target_lang}') | 104 | print(f'{metric}: {source_lang} - {target_lang}') |
105 | 105 | ||
106 | clf = Wasserstein_Matcher(W_embed=W_common, | 106 | clf = WassersteinMatcher(W_embed=W_common, |
107 | n_neighbors=5, | 107 | n_neighbors=5, |
108 | n_jobs=14, | 108 | n_jobs=14, |
109 | sinkhorn=(metric == 'snk')) | 109 | sinkhorn=(metric == 'snk')) |
110 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 110 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
111 | row_ind, col_ind, _ = clf.kneighbors(X_test_idf[:instances], | 111 | p_at_one, percentage = clf.align(X_test_idf[:instances], |
112 | n_neighbors=instances) | 112 | n_neighbors=instances) |
113 | result = zip(row_ind, col_ind) | ||
114 | p_at_one = len([x for x, y in result if x == y]) | ||
115 | percentage = p_at_one / instances * 100 | ||
116 | 113 | ||
117 | if not batch: | 114 | if not batch: |
118 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') | 115 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |