Diffstat (limited to 'WMD.py')
-rw-r--r--  WMD.py  175
1 file changed, 175 insertions(+), 0 deletions(-)
diff --git a/WMD.py b/WMD.py
new file mode 100644
index 0000000..dd43cd5
--- /dev/null
+++ b/WMD.py
@@ -0,0 +1,175 @@
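"""Align bilingual dictionary definitions with Word Mover's Distance.

Given word embeddings and a definition file for a source and a target
language, aligns the definitions via optimal transport (exact WMD or its
Sinkhorn approximation), under a matching or a retrieval paradigm.
"""
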
import argparse
import csv
import random

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.preprocessing import normalize

from Wasserstein_Distance import (WassersteinMatcher, WassersteinRetriever,
                                  load_embeddings, process_corpus)


def main(args):

    np.seterr(divide="ignore")  # POT has issues with divide-by-zero errors
    source_lang = args.source_lang
    target_lang = args.target_lang

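    # load_embeddings (from Wasserstein_Distance) is assumed to return a
    # dict-like mapping from each token to its embedding vector.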
    source_vectors_filename = args.source_vector
    target_vectors_filename = args.target_vector
    vectors_source = load_embeddings(source_vectors_filename)
    vectors_target = load_embeddings(target_vectors_filename)

    source_defs_filename = args.source_defs
    target_defs_filename = args.target_defs

    batch = args.batch
    input_mode = args.mode
    input_paradigm = args.paradigm

    run_method = list()
    run_paradigm = list()

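    # "all" expands to every available option; otherwise only the single
    # requested paradigm/method is run.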
    if input_paradigm == "all":
        run_paradigm.extend(["matching", "retrieval"])
    else:
        run_paradigm.append(input_paradigm)

    if input_mode == "all":
        run_method.extend(["wmd", "snk"])
    else:
        run_method.append(input_mode)

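    # Read the dictionary definitions, one definition per line.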
    with open(source_defs_filename, encoding="utf8") as f:
        defs_source = [line.rstrip("\n") for line in f]
    with open(target_defs_filename, encoding="utf8") as f:
        defs_target = [line.rstrip("\n") for line in f]

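    # process_corpus (from Wasserstein_Distance) presumably tokenizes and
    # cleans each definition, drops tokens without embeddings, and returns
    # the cleaned corpus, the surviving vectors, and the keys of the
    # retained entries.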
    clean_src_corpus, clean_src_vectors, src_keys = process_corpus(
        set(vectors_source.keys()), defs_source, vectors_source, source_lang
    )

    clean_target_corpus, clean_target_vectors, target_keys = process_corpus(
        set(vectors_target.keys()), defs_target, vectors_target, target_lang
    )

    take = args.instances

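    # Sample only entries present in both languages so the source and
    # target corpora stay aligned pairwise.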
    common_keys = set(src_keys).intersection(set(target_keys))
    take = min(len(common_keys), take)  # can't sample more than the population
    # sorted() because random.sample() no longer accepts a set (Python 3.11+)
    experiment_keys = random.sample(sorted(common_keys), take)

    instances = len(experiment_keys)

    clean_src_corpus = list(clean_src_corpus[experiment_keys])
    clean_target_corpus = list(clean_target_corpus[experiment_keys])

    if not batch:
        print(
            f"{source_lang} - {target_lang} "
            + f"document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
        )

    # The raw embeddings and definitions are no longer needed.
    del vectors_source, vectors_target, defs_source, defs_target

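    # Build the shared vocabulary: every token from either corpus that has
    # an embedding in at least one language (source embeddings win ties).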
    vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
    common = [
        word
        # get_feature_names() was removed in scikit-learn 1.2
        for word in vec.get_feature_names_out()
        if word in clean_src_vectors or word in clean_target_vectors
    ]
    W_common = []
    for w in common:
        if w in clean_src_vectors:
            W_common.append(np.array(clean_src_vectors[w]))
        else:
            W_common.append(np.array(clean_target_vectors[w]))

    if not batch:
        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")

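    # L2-normalize the embeddings; raw (unnormalized) TF-IDF weights then
    # serve as the word masses for the transport problem.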
    W_common = np.array(W_common)
    W_common = normalize(W_common)
    vect = TfidfVectorizer(vocabulary=common, dtype=np.double, norm=None)
    vect.fit(clean_src_corpus + clean_target_corpus)
    X_train_idf = vect.transform(clean_src_corpus)
    X_test_idf = vect.transform(clean_target_corpus)

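    # "matching" is assumed to solve a one-to-one assignment between the two
    # corpora; "retrieval" ranks target candidates for each source entry.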
    for paradigm in run_paradigm:
        if paradigm == "matching":
            WassersteinDriver = WassersteinMatcher
        else:
            WassersteinDriver = WassersteinRetriever

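        # sinkhorn=True is assumed to replace the exact optimal-transport
        # solve with the entropy-regularized Sinkhorn approximation ("snk").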
        for metric in run_method:
            if not batch:
                print(f"{metric}: {source_lang} - {target_lang}")

            clf = WassersteinDriver(
                W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
            )
            clf.fit(X_train_idf[:instances], np.ones(instances))
            p_at_one, percentage = clf.align(
                X_test_idf[:instances], n_neighbors=instances
            )

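            # p_at_one is assumed to be precision@1 and percentage the share
            # of correctly aligned pairs, as computed by align().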
            if not batch:
                print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
            else:
                fields = [
                    f"{source_lang}",
                    f"{target_lang}",
                    f"{instances}",
                    f"{p_at_one}",
                    f"{percentage}",
                ]
                # newline="" avoids blank rows in the CSV on Windows
                with open(f"{metric}_{paradigm}_results.csv", "a", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow(fields)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="align dictionaries using wmd and wasserstein distance"
    )
    parser.add_argument("source_lang", help="source language short name")
    parser.add_argument("target_lang", help="target language short name")
    parser.add_argument("source_vector", help="path of the source vector")
    parser.add_argument("target_vector", help="path of the target vector")
    parser.add_argument("source_defs", help="path of the source definitions")
    parser.add_argument("target_defs", help="path of the target definitions")
    parser.add_argument(
        "-b",
        "--batch",
        action="store_true",
        help="running in batch (store results in csv) or "
        + "running a single instance (output the results)",
    )
    parser.add_argument(
        "mode",
        choices=["all", "wmd", "snk"],
        nargs="?",  # optional so that default="all" can take effect
        default="all",
        help="which methods to run",
    )
    parser.add_argument(
        "paradigm",
        choices=["all", "retrieval", "matching"],
        nargs="?",  # optional so that default="all" can take effect
        default="all",
        help="which paradigms to align with",
    )
    parser.add_argument(
        "-n",
        "--instances",
        help="number of instances in each language to retrieve",
        default=1000,
        type=int,
    )

    args = parser.parse_args()

    main(args)
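
# Example invocations (file names are hypothetical):
#   python WMD.py en it en.vec it.vec en.defs it.defs wmd matching -n 500
# Batch mode, running every method and paradigm and appending to the CSVs:
#   python WMD.py en it en.vec it.vec en.defs it.defs -b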