-rw-r--r--  WMD.py (renamed from WMD_matching.py)  |  94
-rw-r--r--  WMD_retrieval.py                       | 149
-rw-r--r--  Wasserstein_Distance.py                |   4
3 files changed, 60 insertions(+), 187 deletions(-)
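In brief: WMD_matching.py and WMD_retrieval.py are merged into a single WMD.py driven by a new positional paradigm argument (matching, retrieval, or all), and the helper clean_corpus_using_embeddings_vocabulary in Wasserstein_Distance.py is renamed to process_corpus.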
diff --git a/WMD_matching.py b/WMD.py
index 69ea10e..dd43cd5 100644
--- a/WMD_matching.py
+++ b/WMD.py
@@ -6,9 +6,8 @@ import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
 
-from Wasserstein_Distance import (WassersteinMatcher,
-                                  clean_corpus_using_embeddings_vocabulary,
-                                  load_embeddings)
+from Wasserstein_Distance import (WassersteinMatcher, WassersteinRetriever,
+                                  load_embeddings, process_corpus)
 
 
 def main(args):
@@ -26,13 +25,21 @@ def main(args):
     target_defs_filename = args.target_defs
 
     batch = args.batch
-    mode = args.mode
-    runfor = list()
+    input_mode = args.mode
+    input_paradigm = args.paradigm
 
-    if mode == "all":
-        runfor.extend(["wmd", "snk"])
+    run_method = list()
+    run_paradigm = list()
+
+    if input_paradigm == "all":
+        run_paradigm.extend(["matching", "retrieval"])
+    else:
+        run_paradigm.append(input_paradigm)
+
+    if input_mode == "all":
+        run_method.extend(["wmd", "snk"])
     else:
-        runfor.append(mode)
+        run_method.append(input_mode)
 
     defs_source = [
         line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8")
@@ -41,11 +48,11 @@ def main(args):
         line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8")
     ]
 
-    clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
+    clean_src_corpus, clean_src_vectors, src_keys = process_corpus(
         set(vectors_source.keys()), defs_source, vectors_source, source_lang
     )
 
-    clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
+    clean_target_corpus, clean_target_vectors, target_keys = process_corpus(
         set(vectors_target.keys()), defs_target, vectors_target, target_lang
     )
 
@@ -62,7 +69,8 @@ def main(args):
 
     if not batch:
         print(
-            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
+            f"{source_lang} - {target_lang} : "
+            + f"document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
         )
 
     del vectors_source, vectors_target, defs_source, defs_target
@@ -90,35 +98,44 @@ def main(args):
     X_train_idf = vect.transform(clean_src_corpus)
     X_test_idf = vect.transform(clean_target_corpus)
 
-    for metric in runfor:
-        if not batch:
-            print(f"{metric}: {source_lang} - {target_lang}")
-
-        clf = WassersteinMatcher(
-            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
-        )
-        clf.fit(X_train_idf[:instances], np.ones(instances))
-        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
-
-        if not batch:
-            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
-        else:
-            fields = [
-                f"{source_lang}",
-                f"{target_lang}",
-                f"{instances}",
-                f"{p_at_one}",
-                f"{percentage}",
-            ]
-            with open(f"{metric}_matching_results.csv", "a") as f:
-                writer = csv.writer(f)
-                writer.writerow(fields)
+    for paradigm in run_paradigm:
+        WassersteinDriver = None
+        if paradigm == "matching":
+            WassersteinDriver = WassersteinMatcher
+        else:
+            WassersteinDriver = WassersteinRetriever
+
+        for metric in run_method:
+            if not batch:
+                print(f"{metric}: {source_lang} - {target_lang}")
+
+            clf = WassersteinDriver(
+                W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
+            )
+            clf.fit(X_train_idf[:instances], np.ones(instances))
+            p_at_one, percentage = clf.align(
+                X_test_idf[:instances], n_neighbors=instances
+            )
+
+            if not batch:
+                print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
+            else:
+                fields = [
+                    f"{source_lang}",
+                    f"{target_lang}",
+                    f"{instances}",
+                    f"{p_at_one}",
+                    f"{percentage}",
+                ]
+                with open(f"{metric}_{paradigm}_results.csv", "a") as f:
+                    writer = csv.writer(f)
+                    writer.writerow(fields)
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(
-        description="matching using wmd and wasserstein distance"
+        description="align dictionaries using wmd and wasserstein distance"
     )
     parser.add_argument("source_lang", help="source language short name")
     parser.add_argument("target_lang", help="target language short name")
@@ -130,7 +147,8 @@ if __name__ == "__main__":
130 "-b", 147 "-b",
131 "--batch", 148 "--batch",
132 action="store_true", 149 action="store_true",
133 help="running in batch (store results in csv) or running a single instance (output the results)", 150 help="running in batch (store results in csv) or"
151 + "running a single instance (output the results)",
134 ) 152 )
135 parser.add_argument( 153 parser.add_argument(
136 "mode", 154 "mode",
@@ -139,6 +157,12 @@ if __name__ == "__main__":
139 help="which methods to run", 157 help="which methods to run",
140 ) 158 )
141 parser.add_argument( 159 parser.add_argument(
160 "paradigm",
161 choices=["all", "retrieval", "matching"],
162 default="all",
163 help="which paradigms to align with",
164 )
165 parser.add_argument(
142 "-n", 166 "-n",
143 "--instances", 167 "--instances",
144 help="number of instances in each language to retrieve", 168 help="number of instances in each language to retrieve",
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
deleted file mode 100644
index cb72079..0000000
--- a/WMD_retrieval.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import argparse
-import csv
-import random
-
-import numpy as np
-from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
-from sklearn.preprocessing import normalize
-
-from Wasserstein_Distance import (WassersteinRetriever,
-                                  clean_corpus_using_embeddings_vocabulary,
-                                  load_embeddings)
-
-
-def main(args):
-
-    np.seterr(divide="ignore")  # POT has issues with divide by zero errors
-    source_lang = args.source_lang
-    target_lang = args.target_lang
-
-    source_vectors_filename = args.source_vector
-    target_vectors_filename = args.target_vector
-    vectors_source = load_embeddings(source_vectors_filename)
-    vectors_target = load_embeddings(target_vectors_filename)
-
-    source_defs_filename = args.source_defs
-    target_defs_filename = args.target_defs
-
-    batch = args.batch
-    mode = args.mode
-    runfor = list()
-
-    if mode == "all":
-        runfor.extend(["wmd", "snk"])
-    else:
-        runfor.append(mode)
-
-    defs_source = [
-        line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8")
-    ]
-    defs_target = [
-        line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8")
-    ]
-
-    clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_source.keys()), defs_source, vectors_source, source_lang
-    )
-
-    clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_target.keys()), defs_target, vectors_target, target_lang
-    )
-
-    take = args.instances
-
-    common_keys = set(src_keys).intersection(set(target_keys))
-    take = min(len(common_keys), take)  # you can't sample more than length
-    experiment_keys = random.sample(common_keys, take)
-
-    instances = len(experiment_keys)
-
-    clean_src_corpus = list(clean_src_corpus[experiment_keys])
-    clean_target_corpus = list(clean_target_corpus[experiment_keys])
-
-    if not batch:
-        print(
-            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
-        )
-
-    del vectors_source, vectors_target, defs_source, defs_target
-
-    vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
-    common = [
-        word
-        for word in vec.get_feature_names()
-        if word in clean_src_vectors or word in clean_target_vectors
-    ]
-    W_common = []
-    for w in common:
-        if w in clean_src_vectors:
-            W_common.append(np.array(clean_src_vectors[w]))
-        else:
-            W_common.append(np.array(clean_target_vectors[w]))
-
-    if not batch:
-        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")
-
-    W_common = np.array(W_common)
-    W_common = normalize(W_common)
-    vect = TfidfVectorizer(vocabulary=common, dtype=np.double, norm=None)
-    vect.fit(clean_src_corpus + clean_target_corpus)
-    X_train_idf = vect.transform(clean_src_corpus)
-    X_test_idf = vect.transform(clean_target_corpus)
-
-    for metric in runfor:
-        if not batch:
-            print(f"{metric}: {source_lang} - {target_lang}")
-
-        clf = WassersteinRetriever(
-            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
-        )
-        clf.fit(X_train_idf[:instances], np.ones(instances))
-        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
-
-        if not batch:
-            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
-        else:
-            fields = [
-                f"{source_lang}",
-                f"{target_lang}",
-                f"{instances}",
-                f"{p_at_one}",
-                f"{percentage}",
-            ]
-            with open(f"{metric}_retrieval_result.csv", "a") as f:
-                writer = csv.writer(f)
-                writer.writerow(fields)
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="run retrieval using wmd or snk")
-    parser.add_argument("source_lang", help="source language short name")
-    parser.add_argument("target_lang", help="target language short name")
-    parser.add_argument("source_vector", help="path of the source vector")
-    parser.add_argument("target_vector", help="path of the target vector")
-    parser.add_argument("source_defs", help="path of the source definitions")
-    parser.add_argument("target_defs", help="path of the target definitions")
-    parser.add_argument(
-        "-b",
-        "--batch",
-        action="store_true",
-        help="running in batch (store results in csv) or running a single instance (output the results)",
-    )
-    parser.add_argument(
-        "mode",
-        choices=["all", "wmd", "snk"],
-        default="all",
-        help="which methods to run",
-    )
-    parser.add_argument(
-        "-n",
-        "--instances",
-        help="number of instances in each language to retrieve",
-        default=1000,
-        type=int,
-    )
-
-    args = parser.parse_args()
-
-    main(args)
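Nothing from the deleted script is lost: it was essentially parallel to WMD_matching.py, differing only in the driver class, the argparse description, and the results file name, and the merged WMD.py now selects the driver in its paradigm loop. An illustrative, behavior-equivalent alternative to that if/else dispatch (a sketch of the same idea as a lookup table, not what the commit does):

    # Sketch: table-based driver selection, equivalent to the if/else in WMD.py.
    drivers = {"matching": WassersteinMatcher, "retrieval": WassersteinRetriever}
    for paradigm in run_paradigm:
        WassersteinDriver = drivers[paradigm]

One incidental rename to note: the retrieval results file changes from {metric}_retrieval_result.csv to {metric}_retrieval_results.csv.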
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 60991b9..cca2fac 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -225,9 +225,7 @@ def load_embeddings(path, dimension=300):
     return vectors
 
 
-def clean_corpus_using_embeddings_vocabulary(
-    embeddings_dictionary, corpus, vectors, language
-):
+def process_corpus(embeddings_dictionary, corpus, vectors, language):
     """
     Cleans corpus using the dictionary of embeddings.
     Any word without an associated embedding in the dictionary is ignored.
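The rename does not change the helper's signature or return values. A minimal usage sketch mirroring the call sites in WMD.py above (the path and variable names are illustrative):

    # Sketch: call pattern of the renamed helper, as used in WMD.py.
    vectors = load_embeddings("vectors/en.vec")
    clean_corpus, clean_vectors, keys = process_corpus(
        set(vectors.keys()), definitions, vectors, "en"
    )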