author     Yigit Sever  2019-09-25 14:21:44 +0300
committer  Yigit Sever  2019-09-25 14:21:44 +0300
commit     c74318070ad85d5d7943e96d343aa961db305316 (patch)
tree       2669ec6ba06b4080bcd310581bd216a88387d2bc /WMD_retrieval.py
parent     49c6f58e51e12af691f7a1322137c64f46043b15 (diff)
Merge WMD/SNK matching and retrieval
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--  WMD_retrieval.py  149
1 file changed, 0 insertions, 149 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
deleted file mode 100644
index cb72079..0000000
--- a/WMD_retrieval.py
+++ /dev/null
@@ -1,149 +0,0 @@
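"""Cross-lingual retrieval over dictionary definitions.

Matches source-language definitions against target-language definitions using
Word Mover's Distance (wmd) or its Sinkhorn variant (snk) and reports P@1.
"""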
import argparse
import csv
import random

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.preprocessing import normalize

from Wasserstein_Distance import (WassersteinRetriever,
                                  clean_corpus_using_embeddings_vocabulary,
                                  load_embeddings)


def main(args):

    np.seterr(divide="ignore")  # POT has issues with divide by zero errors
    source_lang = args.source_lang
    target_lang = args.target_lang

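    # load_embeddings returns a word -> vector map; the source and target
    # vectors are assumed to live in one shared cross-lingual space.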
    source_vectors_filename = args.source_vector
    target_vectors_filename = args.target_vector
    vectors_source = load_embeddings(source_vectors_filename)
    vectors_target = load_embeddings(target_vectors_filename)

    source_defs_filename = args.source_defs
    target_defs_filename = args.target_defs

    batch = args.batch
    mode = args.mode
    runfor = list()

    if mode == "all":
        runfor.extend(["wmd", "snk"])
    else:
        runfor.append(mode)

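    # Definition corpora are plain text, one dictionary definition per line.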
    with open(source_defs_filename, encoding="utf8") as f:
        defs_source = [line.rstrip("\n") for line in f]
    with open(target_defs_filename, encoding="utf8") as f:
        defs_target = [line.rstrip("\n") for line in f]

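    # Keep only tokens covered by the embedding vocabulary, so every
    # surviving word has a vector.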
    clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
        set(vectors_source.keys()), defs_source, vectors_source, source_lang
    )

    clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
        set(vectors_target.keys()), defs_target, vectors_target, target_lang
    )

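    # Sample the same definition indices on both sides so that the i-th source
    # definition remains paired with the i-th target definition.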
    take = args.instances

    common_keys = set(src_keys).intersection(set(target_keys))
    take = min(len(common_keys), take)  # you can't sample more than length
    experiment_keys = random.sample(list(common_keys), take)

    instances = len(experiment_keys)

    clean_src_corpus = list(clean_src_corpus[experiment_keys])
    clean_target_corpus = list(clean_target_corpus[experiment_keys])

    if not batch:
        print(
            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
        )

    del vectors_source, vectors_target, defs_source, defs_target

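    # Shared vocabulary: words from either corpus that have an embedding in at
    # least one language; their vectors are stacked into a single matrix below.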
    vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
    common = [
        word
        for word in vec.get_feature_names()
        if word in clean_src_vectors or word in clean_target_vectors
    ]
    W_common = []
    for w in common:
        if w in clean_src_vectors:
            W_common.append(np.array(clean_src_vectors[w]))
        else:
            W_common.append(np.array(clean_target_vectors[w]))

    if not batch:
        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")

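    # L2-normalize the embedding rows, then represent each definition as raw
    # tf-idf weights over the shared vocabulary (norm=None disables sklearn's
    # per-document normalization).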
    W_common = np.array(W_common)
    W_common = normalize(W_common)
    vect = TfidfVectorizer(vocabulary=common, dtype=np.double, norm=None)
    vect.fit(clean_src_corpus + clean_target_corpus)
    X_train_idf = vect.transform(clean_src_corpus)
    X_test_idf = vect.transform(clean_target_corpus)

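    # One retrieval run per metric: exact WMD, or the entropy-regularized
    # Sinkhorn approximation when metric == "snk". The np.ones passed to fit()
    # serves as placeholder labels; retrieval only uses the documents.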
    for metric in runfor:
        if not batch:
            print(f"{metric}: {source_lang} - {target_lang}")

        clf = WassersteinRetriever(
            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
        )
        clf.fit(X_train_idf[:instances], np.ones(instances))
        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)

        if not batch:
            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
        else:
            fields = [
                f"{source_lang}",
                f"{target_lang}",
                f"{instances}",
                f"{p_at_one}",
                f"{percentage}",
            ]
            with open(f"{metric}_retrieval_result.csv", "a") as f:
                writer = csv.writer(f)
                writer.writerow(fields)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="run retrieval using wmd or snk")
    parser.add_argument("source_lang", help="source language short name")
    parser.add_argument("target_lang", help="target language short name")
    parser.add_argument("source_vector", help="path of the source vector")
    parser.add_argument("target_vector", help="path of the target vector")
    parser.add_argument("source_defs", help="path of the source definitions")
    parser.add_argument("target_defs", help="path of the target definitions")
    parser.add_argument(
        "-b",
        "--batch",
        action="store_true",
        help="batch mode: append results to a csv file instead of printing them",
    )
    parser.add_argument(
        "mode",
        choices=["all", "wmd", "snk"],
        nargs="?",
        default="all",
        help="which methods to run",
    )
    parser.add_argument(
        "-n",
        "--instances",
        help="number of instances in each language to retrieve",
        default=1000,
        type=int,
    )

    args = parser.parse_args()

    main(args)
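# Example invocation (file names are hypothetical):
#   python WMD_retrieval.py en it en.vec it.vec en.def it.def wmd -n 500 -b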