path: root/WMD_matching.py
Diffstat (limited to 'WMD_matching.py')
-rw-r--r--  WMD_matching.py | 52 ++++++++++++++++++++++++++--------------------------
1 file changed, 26 insertions(+), 26 deletions(-)
diff --git a/WMD_matching.py b/WMD_matching.py
index 3b8b1a9..0b81696 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -7,22 +7,6 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
 from Wass_Matcher import Wasserstein_Matcher
 
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance')
-    parser.add_argument('source_lang', help='source language short name')
-    parser.add_argument('target_lang', help='target language short name')
-    parser.add_argument('source_vector', help='path of the source vector')
-    parser.add_argument('target_vector', help='path of the target vector')
-    parser.add_argument('source_defs', help='path of the source definitions')
-    parser.add_argument('target_defs', help='path of the target definitions')
-    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
-    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
-    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
-    args = parser.parse_args()
-
-    main(args)
-
 def load_embeddings(path, dimension=300):
     """
     Loads the embeddings from a word2vec formatted file.
@@ -88,6 +72,7 @@ def mrr_precision_at_k(golden, preds, k_list=[1,]):
 
 def main(args):
 
+    numpy.seterr(divide='ignore') # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -162,7 +147,7 @@ def main(args):
     if (not batch):
         print(f'WMD - tfidf: {source_lang} - {target_lang}')
 
-    clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14)
+    clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14)
     clf.fit(X_train_idf[:instances], np.ones(instances))
     row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
     result = zip(row_ind, col_ind)
@@ -183,20 +168,35 @@ def main(args):
     if (not batch):
         print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}')
 
-    clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True)
+    clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True)
     clf.fit(X_train_idf[:instances], np.ones(instances))
     row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
 
     result = zip(row_ind, col_ind)
     hit_one = len([x for x,y in result if x == y])
+    percentage = hit_one / instances * 100
 
     if (not batch):
-        print(f'{hit_one} definitions have been mapped correctly')
+        print(f'{hit_one} definitions have been mapped correctly, {percentage}%')
 
-    if (batch):
-        percentage = hit_one / instances * 100
-        fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
-        with open('sinkhorn_matching_result.csv', 'a') as f:
-            writer = csv.writer(f)
-            writer.writerow(fields)
+    if (batch):
+        fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
+        with open('sinkhorn_matching_result.csv', 'a') as f:
+            writer = csv.writer(f)
+            writer.writerow(fields)
 
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance')
+    parser.add_argument('source_lang', help='source language short name')
+    parser.add_argument('target_lang', help='target language short name')
+    parser.add_argument('source_vector', help='path of the source vector')
+    parser.add_argument('target_vector', help='path of the target vector')
+    parser.add_argument('source_defs', help='path of the source definitions')
+    parser.add_argument('target_defs', help='path of the target definitions')
+    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
+    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
+    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
+    args = parser.parse_args()
+
+    main(args)
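
For context on the numpy.seterr(divide='ignore') line added at the top of main(): a minimal standalone sketch of what that setting does, using a log-of-zero toy example that is not from this repository.

    import numpy

    numpy.seterr(divide='ignore')      # same call the diff adds inside main()
    print(numpy.log(numpy.zeros(1)))   # [-inf], printed silently

    numpy.seterr(divide='warn')        # numpy's default behaviour
    print(numpy.log(numpy.zeros(1)))   # [-inf] plus a "divide by zero encountered in log" RuntimeWarning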
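
The Sinkhorn hunk now computes the hit percentage before branching on batch mode, so both the console message and the CSV row can report it. A minimal sketch of that bookkeeping, with toy index arrays and a hypothetical language pair standing in for the values the real script gets from Wasserstein_Matcher.kneighbors and argparse:

    import csv

    # Toy stand-ins; in WMD_matching.py these come from clf.kneighbors(...).
    row_ind = [0, 1, 2, 3]
    col_ind = [0, 2, 2, 3]
    instances = len(row_ind)
    source_lang, target_lang = 'en', 'it'   # hypothetical language pair

    # A definition counts as correctly mapped when it is matched to the
    # target definition at the same index.
    hit_one = len([x for x, y in zip(row_ind, col_ind) if x == y])
    percentage = hit_one / instances * 100

    print(f'{hit_one} definitions have been mapped correctly, {percentage}%')

    # Batch mode appends one row per run, as in the diff's CSV writer.
    fields = [source_lang, target_lang, instances, hit_one, percentage]
    with open('sinkhorn_matching_result.csv', 'a') as f:
        csv.writer(f).writerow(fields)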