diff options
-rw-r--r-- | WMD_matching.py | 52 |
1 files changed, 26 insertions, 26 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index 3b8b1a9..0b81696 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
@@ -7,22 +7,6 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | |||
7 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
8 | from Wass_Matcher import Wasserstein_Matcher | 8 | from Wass_Matcher import Wasserstein_Matcher |
9 | 9 | ||
10 | if __name__ == "__main__": | ||
11 | |||
12 | parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance') | ||
13 | parser.add_argument('source_lang', help='source language short name') | ||
14 | parser.add_argument('target_lang', help='target language short name') | ||
15 | parser.add_argument('source_vector', help='path of the source vector') | ||
16 | parser.add_argument('target_vector', help='path of the target vector') | ||
17 | parser.add_argument('source_defs', help='path of the source definitions') | ||
18 | parser.add_argument('target_defs', help='path of the target definitions') | ||
19 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | ||
20 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | ||
21 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | ||
22 | args = parser.parse_args() | ||
23 | |||
24 | main(args) | ||
25 | |||
26 | def load_embeddings(path, dimension=300): | 10 | def load_embeddings(path, dimension=300): |
27 | """ | 11 | """ |
28 | Loads the embeddings from a word2vec formatted file. | 12 | Loads the embeddings from a word2vec formatted file. |
@@ -88,6 +72,7 @@ def mrr_precision_at_k(golden, preds, k_list=[1,]): | |||
88 | 72 | ||
89 | def main(args): | 73 | def main(args): |
90 | 74 | ||
75 | numpy.seterr(divide='ignore') # POT has issues with divide by zero errors | ||
91 | source_lang = args.source_lang | 76 | source_lang = args.source_lang |
92 | target_lang = args.target_lang | 77 | target_lang = args.target_lang |
93 | 78 | ||
@@ -162,7 +147,7 @@ def main(args): | |||
162 | if (not batch): | 147 | if (not batch): |
163 | print(f'WMD - tfidf: {source_lang} - {target_lang}') | 148 | print(f'WMD - tfidf: {source_lang} - {target_lang}') |
164 | 149 | ||
165 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14) | 150 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) |
166 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 151 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
167 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 152 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
168 | result = zip(row_ind, col_ind) | 153 | result = zip(row_ind, col_ind) |
@@ -183,20 +168,35 @@ def main(args): | |||
183 | if (not batch): | 168 | if (not batch): |
184 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') | 169 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') |
185 | 170 | ||
186 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) | 171 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) |
187 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 172 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
188 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 173 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
189 | 174 | ||
190 | result = zip(row_ind, col_ind) | 175 | result = zip(row_ind, col_ind) |
191 | hit_one = len([x for x,y in result if x == y]) | 176 | hit_one = len([x for x,y in result if x == y]) |
177 | percentage = hit_one / instances * 100 | ||
192 | 178 | ||
193 | if (not batch): | 179 | if (not batch): |
194 | print(f'{hit_one} definitions have been mapped correctly') | 180 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') |
195 | 181 | ||
182 | if (batch): | ||
183 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
184 | with open('sinkhorn_matching_result.csv', 'a') as f: | ||
185 | writer = csv.writer(f) | ||
186 | writer.writerow(fields) | ||
196 | 187 | ||
197 | if (batch): | 188 | if __name__ == "__main__": |
198 | percentage = hit_one / instances * 100 | 189 | |
199 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | 190 | parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance') |
200 | with open('sinkhorn_matching_result.csv', 'a') as f: | 191 | parser.add_argument('source_lang', help='source language short name') |
201 | writer = csv.writer(f) | 192 | parser.add_argument('target_lang', help='target language short name') |
202 | writer.writerow(fields) | 193 | parser.add_argument('source_vector', help='path of the source vector') |
194 | parser.add_argument('target_vector', help='path of the target vector') | ||
195 | parser.add_argument('source_defs', help='path of the source definitions') | ||
196 | parser.add_argument('target_defs', help='path of the target definitions') | ||
197 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | ||
198 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | ||
199 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | ||
200 | args = parser.parse_args() | ||
201 | |||
202 | main(args) | ||