diff options
| author | Yigit Sever | 2019-09-22 00:26:08 +0300 |
|---|---|---|
| committer | Yigit Sever | 2019-09-22 00:26:08 +0300 |
| commit | 22cc3c22d317c98940e5f38715266cf9757880b2 (patch) | |
| tree | be8743a32c75d148808e0aa6491323de7bb7faad /WMD_matching.py | |
| parent | 3c135215db79fac37ebede465db567395fa5daa5 (diff) | |
| download | Evaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.tar.gz Evaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.tar.bz2 Evaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.zip | |
Roll the methods into a for loop, like retrieval
Diffstat (limited to 'WMD_matching.py')
| -rw-r--r-- | WMD_matching.py | 48 |
1 file changed, 17 insertions, 31 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index 8a97389..59b64f9 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
| @@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | |||
| 5 | from sklearn.preprocessing import normalize | 5 | from sklearn.preprocessing import normalize |
| 6 | from Wasserstein_Distance import Wasserstein_Matcher | 6 | from Wasserstein_Distance import Wasserstein_Matcher |
| 7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary | 7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary |
| 8 | import csv | ||
| 8 | 9 | ||
| 9 | def main(args): | 10 | def main(args): |
| 10 | 11 | ||
| @@ -22,6 +23,12 @@ def main(args): | |||
| 22 | 23 | ||
| 23 | batch = args.batch | 24 | batch = args.batch |
| 24 | mode = args.mode | 25 | mode = args.mode |
| 26 | runfor = list() | ||
| 27 | |||
| 28 | if (mode == 'all'): | ||
| 29 | runfor.extend(['wmd','snk']) | ||
| 30 | else: | ||
| 31 | runfor.append(mode) | ||
| 25 | 32 | ||
| 26 | defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] | 33 | defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] |
| 27 | defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] | 34 | defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] |
| @@ -80,47 +87,25 @@ def main(args): | |||
| 80 | X_train_tf = vect_tf.transform(clean_src_corpus) | 87 | X_train_tf = vect_tf.transform(clean_src_corpus) |
| 81 | X_test_tf = vect_tf.transform(clean_target_corpus) | 88 | X_test_tf = vect_tf.transform(clean_target_corpus) |
| 82 | 89 | ||
| 83 | if (mode == 'wmd' or mode == 'all'): | 90 | for metric in runfor: |
| 84 | if (not batch): | 91 | if (not batch): |
| 85 | print(f'WMD - tfidf: {source_lang} - {target_lang}') | 92 | print(f'{metric}: {source_lang} - {target_lang}') |
| 86 | 93 | ||
| 87 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) | 94 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
| 88 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
| 89 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 96 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
| 90 | result = zip(row_ind, col_ind) | 97 | result = zip(row_ind, col_ind) |
| 91 | hit_one = len([x for x,y in result if x == y]) | 98 | p_at_one = len([x for x,y in result if x == y]) |
| 92 | percentage = hit_one / instances * 100 | 99 | percentage = p_at_one / instances * 100 |
| 93 | 100 | ||
| 94 | if (not batch): | 101 | if (not batch): |
| 95 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') | 102 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |
| 96 | 103 | else: | |
| 97 | if (batch): | 104 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] |
| 98 | import csv | 105 | with open(f'{metric}_matching_results.csv', 'a') as f: |
| 99 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
| 100 | with open('wmd_matching_results.csv', 'a') as f: | ||
| 101 | writer = csv.writer(f) | 106 | writer = csv.writer(f) |
| 102 | writer.writerow(fields) | 107 | writer.writerow(fields) |
| 103 | 108 | ||
| 104 | if (mode == 'snk' or mode == 'all'): | ||
| 105 | if (not batch): | ||
| 106 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') | ||
| 107 | |||
| 108 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) | ||
| 109 | clf.fit(X_train_idf[:instances], np.ones(instances)) | ||
| 110 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | ||
| 111 | |||
| 112 | result = zip(row_ind, col_ind) | ||
| 113 | hit_one = len([x for x,y in result if x == y]) | ||
| 114 | percentage = hit_one / instances * 100 | ||
| 115 | |||
| 116 | if (not batch): | ||
| 117 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') | ||
| 118 | |||
| 119 | if (batch): | ||
| 120 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
| 121 | with open('sinkhorn_matching_result.csv', 'a') as f: | ||
| 122 | writer = csv.writer(f) | ||
| 123 | writer.writerow(fields) | ||
| 124 | 109 | ||
| 125 | if __name__ == "__main__": | 110 | if __name__ == "__main__": |
| 126 | 111 | ||
| @@ -134,6 +119,7 @@ if __name__ == "__main__": | |||
| 134 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | 119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') |
| 135 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | 120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') |
| 136 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | 121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) |
| 122 | |||
| 137 | args = parser.parse_args() | 123 | args = parser.parse_args() |
| 138 | 124 | ||
| 139 | main(args) | 125 | main(args) |
