aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYigit Sever2019-09-22 00:26:08 +0300
committerYigit Sever2019-09-22 00:26:08 +0300
commit22cc3c22d317c98940e5f38715266cf9757880b2 (patch)
treebe8743a32c75d148808e0aa6491323de7bb7faad
parent3c135215db79fac37ebede465db567395fa5daa5 (diff)
downloadEvaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.tar.gz
Evaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.tar.bz2
Evaluating-Dictionary-Alignment-22cc3c22d317c98940e5f38715266cf9757880b2.zip
Roll the methods to a for, like retrieval
-rw-r--r--WMD_matching.py48
1 files changed, 17 insertions, 31 deletions
diff --git a/WMD_matching.py b/WMD_matching.py
index 8a97389..59b64f9 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -5,6 +5,7 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
5from sklearn.preprocessing import normalize 5from sklearn.preprocessing import normalize
6from Wasserstein_Distance import Wasserstein_Matcher 6from Wasserstein_Distance import Wasserstein_Matcher
7from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary 7from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary
8import csv
8 9
9def main(args): 10def main(args):
10 11
@@ -22,6 +23,12 @@ def main(args):
22 23
23 batch = args.batch 24 batch = args.batch
24 mode = args.mode 25 mode = args.mode
26 runfor = list()
27
28 if (mode == 'all'):
29 runfor.extend(['wmd','snk'])
30 else:
31 runfor.append(mode)
25 32
26 defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] 33 defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')]
27 defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] 34 defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')]
@@ -80,47 +87,25 @@ def main(args):
80 X_train_tf = vect_tf.transform(clean_src_corpus) 87 X_train_tf = vect_tf.transform(clean_src_corpus)
81 X_test_tf = vect_tf.transform(clean_target_corpus) 88 X_test_tf = vect_tf.transform(clean_target_corpus)
82 89
83 if (mode == 'wmd' or mode == 'all'): 90 for metric in runfor:
84 if (not batch): 91 if (not batch):
85 print(f'WMD - tfidf: {source_lang} - {target_lang}') 92 print(f'{metric}: {source_lang} - {target_lang}')
86 93
87 clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14) 94 clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
88 clf.fit(X_train_idf[:instances], np.ones(instances)) 95 clf.fit(X_train_idf[:instances], np.ones(instances))
89 row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) 96 row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
90 result = zip(row_ind, col_ind) 97 result = zip(row_ind, col_ind)
91 hit_one = len([x for x,y in result if x == y]) 98 p_at_one = len([x for x,y in result if x == y])
92 percentage = hit_one / instances * 100 99 percentage = p_at_one / instances * 100
93 100
94 if (not batch): 101 if (not batch):
95 print(f'{hit_one} definitions have been mapped correctly, {percentage}%') 102 print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
96 103 else:
97 if (batch): 104 fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}']
98 import csv 105 with open(f'{metric}_matching_results.csv', 'a') as f:
99 fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
100 with open('wmd_matching_results.csv', 'a') as f:
101 writer = csv.writer(f) 106 writer = csv.writer(f)
102 writer.writerow(fields) 107 writer.writerow(fields)
103 108
104 if (mode == 'snk' or mode == 'all'):
105 if (not batch):
106 print(f'Sinkhorn - tfidf: {source_lang} - {target_lang}')
107
108 clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=True)
109 clf.fit(X_train_idf[:instances], np.ones(instances))
110 row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
111
112 result = zip(row_ind, col_ind)
113 hit_one = len([x for x,y in result if x == y])
114 percentage = hit_one / instances * 100
115
116 if (not batch):
117 print(f'{hit_one} definitions have been mapped correctly, {percentage}%')
118
119 if (batch):
120 fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{hit_one}', f'{percentage}']
121 with open('sinkhorn_matching_result.csv', 'a') as f:
122 writer = csv.writer(f)
123 writer.writerow(fields)
124 109
125if __name__ == "__main__": 110if __name__ == "__main__":
126 111
@@ -134,6 +119,7 @@ if __name__ == "__main__":
134 parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') 119 parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
135 parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') 120 parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
136 parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) 121 parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
122
137 args = parser.parse_args() 123 args = parser.parse_args()
138 124
139 main(args) 125 main(args)