diff options
| -rw-r--r-- | WMD_matching.py | 52 |
1 file changed, 26 insertions, 26 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index 3b8b1a9..0b81696 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
| @@ -7,22 +7,6 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | |||
| 7 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
| 8 | from Wass_Matcher import Wasserstein_Matcher | 8 | from Wass_Matcher import Wasserstein_Matcher |
| 9 | 9 | ||
| 10 | if __name__ == "__main__": | ||
| 11 | |||
| 12 | parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance') | ||
| 13 | parser.add_argument('source_lang', help='source language short name') | ||
| 14 | parser.add_argument('target_lang', help='target language short name') | ||
| 15 | parser.add_argument('source_vector', help='path of the source vector') | ||
| 16 | parser.add_argument('target_vector', help='path of the target vector') | ||
| 17 | parser.add_argument('source_defs', help='path of the source definitions') | ||
| 18 | parser.add_argument('target_defs', help='path of the target definitions') | ||
| 19 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | ||
| 20 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | ||
| 21 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | ||
| 22 | args = parser.parse_args() | ||
| 23 | |||
| 24 | main(args) | ||
| 25 | |||
| 26 | def load_embeddings(path, dimension=300): | 10 | def load_embeddings(path, dimension=300): |
| 27 | """ | 11 | """ |
| 28 | Loads the embeddings from a word2vec formatted file. | 12 | Loads the embeddings from a word2vec formatted file. |
| @@ -88,6 +72,7 @@ def mrr_precision_at_k(golden, preds, k_list=[1,]): | |||
| 88 | 72 | ||
| 89 | def main(args): | 73 | def main(args): |
| 90 | 74 | ||
| 75 | numpy.seterr(divide='ignore') # POT has issues with divide by zero errors | ||
| 91 | source_lang = args.source_lang | 76 | source_lang = args.source_lang |
| 92 | target_lang = args.target_lang | 77 | target_lang = args.target_lang |
| 93 | 78 | ||
| @@ -162,7 +147,7 @@ def main(args): | |||
| 162 | if (not batch): | 147 | if (not batch): |
| 163 | print(f'WMD - tfidf: {source_lang} - {target_lang}') | 148 | print(f'WMD - tfidf: {source_lang} - {target_lang}') |
| 164 | 149 | ||
| 165 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14) | 150 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) |
| 166 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 151 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
| 167 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 152 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
| 168 | result = zip(row_ind, col_ind) | 153 | result = zip(row_ind, col_ind) |
| @@ -183,20 +168,35 @@ def main(args): | |||
| 183 | if (not batch): | 168 | if (not batch): |
| 184 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') | 169 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') |
| 185 | 170 | ||
| 186 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) | 171 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) |
| 187 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 172 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
| 188 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 173 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
| 189 | 174 | ||
| 190 | result = zip(row_ind, col_ind) | 175 | result = zip(row_ind, col_ind) |
| 191 | hit_one = len([x for x,y in result if x == y]) | 176 | hit_one = len([x for x,y in result if x == y]) |
| 177 | percentage = hit_one / instances * 100 | ||
| 192 | 178 | ||
| 193 | if (not batch): | 179 | if (not batch): |
| 194 | print(f'{hit_one} definitions have been mapped correctly') | 180 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') |
| 195 | 181 | ||
| 182 | if (batch): | ||
| 183 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
| 184 | with open('sinkhorn_matching_result.csv', 'a') as f: | ||
| 185 | writer = csv.writer(f) | ||
| 186 | writer.writerow(fields) | ||
| 196 | 187 | ||
| 197 | if (batch): | 188 | if __name__ == "__main__": |
| 198 | percentage = hit_one / instances * 100 | 189 | |
| 199 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | 190 | parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance') |
| 200 | with open('sinkhorn_matching_result.csv', 'a') as f: | 191 | parser.add_argument('source_lang', help='source language short name') |
| 201 | writer = csv.writer(f) | 192 | parser.add_argument('target_lang', help='target language short name') |
| 202 | writer.writerow(fields) | 193 | parser.add_argument('source_vector', help='path of the source vector') |
| 194 | parser.add_argument('target_vector', help='path of the target vector') | ||
| 195 | parser.add_argument('source_defs', help='path of the source definitions') | ||
| 196 | parser.add_argument('target_defs', help='path of the target definitions') | ||
| 197 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | ||
| 198 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | ||
| 199 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | ||
| 200 | args = parser.parse_args() | ||
| 201 | |||
| 202 | main(args) | ||
