diff options
-rw-r--r-- | WMD_matching.py | 48 |
1 files changed, 17 insertions, 31 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index 8a97389..59b64f9 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
@@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | |||
5 | from sklearn.preprocessing import normalize | 5 | from sklearn.preprocessing import normalize |
6 | from Wasserstein_Distance import Wasserstein_Matcher | 6 | from Wasserstein_Distance import Wasserstein_Matcher |
7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary | 7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary |
8 | import csv | ||
8 | 9 | ||
9 | def main(args): | 10 | def main(args): |
10 | 11 | ||
@@ -22,6 +23,12 @@ def main(args): | |||
22 | 23 | ||
23 | batch = args.batch | 24 | batch = args.batch |
24 | mode = args.mode | 25 | mode = args.mode |
26 | runfor = list() | ||
27 | |||
28 | if (mode == 'all'): | ||
29 | runfor.extend(['wmd','snk']) | ||
30 | else: | ||
31 | runfor.append(mode) | ||
25 | 32 | ||
26 | defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] | 33 | defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] |
27 | defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] | 34 | defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] |
@@ -80,47 +87,25 @@ def main(args): | |||
80 | X_train_tf = vect_tf.transform(clean_src_corpus) | 87 | X_train_tf = vect_tf.transform(clean_src_corpus) |
81 | X_test_tf = vect_tf.transform(clean_target_corpus) | 88 | X_test_tf = vect_tf.transform(clean_target_corpus) |
82 | 89 | ||
83 | if (mode == 'wmd' or mode == 'all'): | 90 | for metric in runfor: |
84 | if (not batch): | 91 | if (not batch): |
85 | print(f'WMD - tfidf: {source_lang} - {target_lang}') | 92 | print(f'{metric}: {source_lang} - {target_lang}') |
86 | 93 | ||
87 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) | 94 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
88 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 95 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
89 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 96 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
90 | result = zip(row_ind, col_ind) | 97 | result = zip(row_ind, col_ind) |
91 | hit_one = len([x for x,y in result if x == y]) | 98 | p_at_one = len([x for x,y in result if x == y]) |
92 | percentage = hit_one / instances * 100 | 99 | percentage = p_at_one / instances * 100 |
93 | 100 | ||
94 | if (not batch): | 101 | if (not batch): |
95 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') | 102 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |
96 | 103 | else: | |
97 | if (batch): | 104 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] |
98 | import csv | 105 | with open(f'{metric}_matching_results.csv', 'a') as f: |
99 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
100 | with open('wmd_matching_results.csv', 'a') as f: | ||
101 | writer = csv.writer(f) | 106 | writer = csv.writer(f) |
102 | writer.writerow(fields) | 107 | writer.writerow(fields) |
103 | 108 | ||
104 | if (mode == 'snk' or mode == 'all'): | ||
105 | if (not batch): | ||
106 | print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}') | ||
107 | |||
108 | clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True) | ||
109 | clf.fit(X_train_idf[:instances], np.ones(instances)) | ||
110 | row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | ||
111 | |||
112 | result = zip(row_ind, col_ind) | ||
113 | hit_one = len([x for x,y in result if x == y]) | ||
114 | percentage = hit_one / instances * 100 | ||
115 | |||
116 | if (not batch): | ||
117 | print(f'{hit_one} definitions have been mapped correctly, {percentage}%') | ||
118 | |||
119 | if (batch): | ||
120 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}'] | ||
121 | with open('sinkhorn_matching_result.csv', 'a') as f: | ||
122 | writer = csv.writer(f) | ||
123 | writer.writerow(fields) | ||
124 | 109 | ||
125 | if __name__ == "__main__": | 110 | if __name__ == "__main__": |
126 | 111 | ||
@@ -134,6 +119,7 @@ if __name__ == "__main__": | |||
134 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | 119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') |
135 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | 120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') |
136 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | 121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) |
122 | |||
137 | args = parser.parse_args() | 123 | args = parser.parse_args() |
138 | 124 | ||
139 | main(args) | 125 | main(args) |