author     Yigit Sever    2019-09-24 21:26:34 +0300
committer  Yigit Sever    2019-09-24 21:26:34 +0300
commit     49c6f58e51e12af691f7a1322137c64f46043b15
tree       cb3709cd77af5d0f6a1df3c0e1904d0a781a39e8  /WMD_retrieval.py
parent     5d9eab51b560d8cee828554cd2dd855037811e91
Use black linter for WMD
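
Note: black is a code formatter rather than a linter; reformatting a single
file is typically just `black WMD_retrieval.py`, which rewrites the file in
place. The exact invocation is not recorded in this commit.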
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--  WMD_retrieval.py  97
1 file changed, 45 insertions(+), 52 deletions(-)
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index 02f35be..cb72079 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -13,7 +13,7 @@ from Wasserstein_Distance import (WassersteinRetriever,
 
 def main(args):
 
-    np.seterr(divide='ignore')  # POT has issues with divide by zero errors
+    np.seterr(divide="ignore")  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -29,32 +29,24 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if mode == 'all':
-        runfor.extend(['wmd', 'snk'])
+    if mode == "all":
+        runfor.extend(["wmd", "snk"])
     else:
         runfor.append(mode)
 
     defs_source = [
-        line.rstrip('\n')
-        for line in open(source_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8")
     ]
     defs_target = [
-        line.rstrip('\n')
-        for line in open(target_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8")
     ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_source.keys()),
-        defs_source,
-        vectors_source,
-        source_lang,
+        set(vectors_source.keys()), defs_source, vectors_source, source_lang
     )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_target.keys()),
-        defs_target,
-        vectors_target,
-        target_lang,
+        set(vectors_target.keys()), defs_target, vectors_target, target_lang
     )
 
     take = args.instances
@@ -70,14 +62,15 @@ def main(args):
 
     if not batch:
         print(
-            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
         )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
     common = [
-        word for word in vec.get_feature_names()
+        word
+        for word in vec.get_feature_names()
         if word in clean_src_vectors or word in clean_target_vectors
     ]
     W_common = []
@@ -88,9 +81,7 @@ def main(args):
             W_common.append(np.array(clean_target_vectors[w]))
 
     if not batch:
-        print(
-            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
-        )
+        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -101,55 +92,57 @@
 
     for metric in runfor:
         if not batch:
-            print(f'{metric}: {source_lang} - {target_lang}')
+            print(f"{metric}: {source_lang} - {target_lang}")
 
-        clf = WassersteinRetriever(W_embed=W_common,
-                                   n_neighbors=5,
-                                   n_jobs=14,
-                                   sinkhorn=(metric == 'snk'))
+        clf = WassersteinRetriever(
+            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
+        )
         clf.fit(X_train_idf[:instances], np.ones(instances))
-        p_at_one, percentage = clf.align(X_test_idf[:instances],
-                                         n_neighbors=instances)
+        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
 
         if not batch:
-            print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
+            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
         else:
             fields = [
-                f'{source_lang}', f'{target_lang}', f'{instances}',
-                f'{p_at_one}', f'{percentage}'
+                f"{source_lang}",
+                f"{target_lang}",
+                f"{instances}",
+                f"{p_at_one}",
+                f"{percentage}",
             ]
-            with open(f'{metric}_retrieval_result.csv', 'a') as f:
+            with open(f"{metric}_retrieval_result.csv", "a") as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
 
 
 if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(
-        description='run retrieval using wmd or snk')
-    parser.add_argument('source_lang', help='source language short name')
-    parser.add_argument('target_lang', help='target language short name')
-    parser.add_argument('source_vector', help='path of the source vector')
-    parser.add_argument('target_vector', help='path of the target vector')
-    parser.add_argument('source_defs', help='path of the source definitions')
-    parser.add_argument('target_defs', help='path of the target definitions')
+    parser = argparse.ArgumentParser(description="run retrieval using wmd or snk")
+    parser.add_argument("source_lang", help="source language short name")
+    parser.add_argument("target_lang", help="target language short name")
+    parser.add_argument("source_vector", help="path of the source vector")
+    parser.add_argument("target_vector", help="path of the target vector")
+    parser.add_argument("source_defs", help="path of the source definitions")
+    parser.add_argument("target_defs", help="path of the target definitions")
     parser.add_argument(
-        '-b',
-        '--batch',
-        action='store_true',
-        help=
-        'running in batch (store results in csv) or running a single instance (output the results)'
+        "-b",
+        "--batch",
+        action="store_true",
+        help="running in batch (store results in csv) or running a single instance (output the results)",
     )
-    parser.add_argument('mode',
-                        choices=['all', 'wmd', 'snk'],
-                        default='all',
-                        help='which methods to run')
     parser.add_argument(
-        '-n',
-        '--instances',
-        help='number of instances in each language to retrieve',
+        "mode",
+        choices=["all", "wmd", "snk"],
+        default="all",
+        help="which methods to run",
+    )
+    parser.add_argument(
+        "-n",
+        "--instances",
+        help="number of instances in each language to retrieve",
         default=1000,
-        type=int)
+        type=int,
+    )
 
     args = parser.parse_args()
 
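
Usage sketch, following the parser above (the .vec and .defs paths are
placeholders, not files from the repository):

    python WMD_retrieval.py en tr en.vec tr.vec en.defs tr.defs all -n 1000

Add -b to append results to {metric}_retrieval_result.csv instead of
printing them.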
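
Note: the sinkhorn flag passed to WassersteinRetriever is resolved in
Wasserstein_Distance.py, which this diff does not touch. A minimal sketch of
what the flag typically selects, assuming POT's public API (ot.emd2 and
ot.sinkhorn2); the repository's actual class may differ:

    # Sketch only, not the repository's implementation.
    import ot  # POT: pip install POT

    def document_cost(p, q, M, sinkhorn=False, reg=0.1):
        """Transport cost between two normalized bag-of-words histograms
        p and q over the common vocabulary, given a word-to-word ground
        cost matrix M (e.g. M = ot.dist(W_common, W_common))."""
        if sinkhorn:
            # "snk" mode: entropic regularization, faster but approximate
            return ot.sinkhorn2(p, q, M, reg)
        # "wmd" mode: exact earth mover's distance
        return ot.emd2(p, q, M)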