-rw-r--r--  WMD_matching.py          |  97
-rw-r--r--  WMD_retrieval.py         |  97
-rw-r--r--  Wasserstein_Distance.py  | 136
3 files changed, 163 insertions(+), 167 deletions(-)
diff --git a/WMD_matching.py b/WMD_matching.py
index 2755d15..69ea10e 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -13,7 +13,7 @@ from Wasserstein_Distance import (WassersteinMatcher,
 
 def main(args):
 
-    np.seterr(divide='ignore')  # POT has issues with divide by zero errors
+    np.seterr(divide="ignore")  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -29,32 +29,24 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if mode == 'all':
-        runfor.extend(['wmd', 'snk'])
+    if mode == "all":
+        runfor.extend(["wmd", "snk"])
     else:
         runfor.append(mode)
 
     defs_source = [
-        line.rstrip('\n')
-        for line in open(source_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8")
     ]
     defs_target = [
-        line.rstrip('\n')
-        for line in open(target_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8")
     ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_source.keys()),
-        defs_source,
-        vectors_source,
-        source_lang,
+        set(vectors_source.keys()), defs_source, vectors_source, source_lang
     )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_target.keys()),
-        defs_target,
-        vectors_target,
-        target_lang,
+        set(vectors_target.keys()), defs_target, vectors_target, target_lang
     )
 
     take = args.instances
@@ -70,14 +62,15 @@ def main(args):
 
     if not batch:
         print(
-            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
         )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
     common = [
-        word for word in vec.get_feature_names()
+        word
+        for word in vec.get_feature_names()
         if word in clean_src_vectors or word in clean_target_vectors
     ]
     W_common = []
@@ -88,9 +81,7 @@ def main(args):
             W_common.append(np.array(clean_target_vectors[w]))
 
     if not batch:
-        print(
-            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
-        )
+        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -101,24 +92,25 @@ def main(args):
 
     for metric in runfor:
         if not batch:
-            print(f'{metric}: {source_lang} - {target_lang}')
+            print(f"{metric}: {source_lang} - {target_lang}")
 
-        clf = WassersteinMatcher(W_embed=W_common,
-                                 n_neighbors=5,
-                                 n_jobs=14,
-                                 sinkhorn=(metric == 'snk'))
+        clf = WassersteinMatcher(
+            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
+        )
         clf.fit(X_train_idf[:instances], np.ones(instances))
-        p_at_one, percentage = clf.align(X_test_idf[:instances],
-                                         n_neighbors=instances)
+        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
 
         if not batch:
-            print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
+            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
         else:
             fields = [
-                f'{source_lang}', f'{target_lang}', f'{instances}',
-                f'{p_at_one}', f'{percentage}'
+                f"{source_lang}",
+                f"{target_lang}",
+                f"{instances}",
+                f"{p_at_one}",
+                f"{percentage}",
             ]
-            with open(f'{metric}_matching_results.csv', 'a') as f:
+            with open(f"{metric}_matching_results.csv", "a") as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
 
@@ -126,30 +118,33 @@ def main(args):
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(
-        description='matching using wmd and wasserstein distance')
-    parser.add_argument('source_lang', help='source language short name')
-    parser.add_argument('target_lang', help='target language short name')
-    parser.add_argument('source_vector', help='path of the source vector')
-    parser.add_argument('target_vector', help='path of the target vector')
-    parser.add_argument('source_defs', help='path of the source definitions')
-    parser.add_argument('target_defs', help='path of the target definitions')
+        description="matching using wmd and wasserstein distance"
+    )
+    parser.add_argument("source_lang", help="source language short name")
+    parser.add_argument("target_lang", help="target language short name")
+    parser.add_argument("source_vector", help="path of the source vector")
+    parser.add_argument("target_vector", help="path of the target vector")
+    parser.add_argument("source_defs", help="path of the source definitions")
+    parser.add_argument("target_defs", help="path of the target definitions")
     parser.add_argument(
-        '-b',
-        '--batch',
-        action='store_true',
-        help=
-        'running in batch (store results in csv) or running a single instance (output the results)'
+        "-b",
+        "--batch",
+        action="store_true",
+        help="running in batch (store results in csv) or running a single instance (output the results)",
     )
-    parser.add_argument('mode',
-                        choices=['all', 'wmd', 'snk'],
-                        default='all',
-                        help='which methods to run')
     parser.add_argument(
-        '-n',
-        '--instances',
-        help='number of instances in each language to retrieve',
+        "mode",
+        choices=["all", "wmd", "snk"],
+        default="all",
+        help="which methods to run",
+    )
+    parser.add_argument(
+        "-n",
+        "--instances",
+        help="number of instances in each language to retrieve",
         default=1000,
-        type=int)
+        type=int,
+    )
 
     args = parser.parse_args()
 
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index 02f35be..cb72079 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -13,7 +13,7 @@ from Wasserstein_Distance import (WassersteinRetriever,
 
 def main(args):
 
-    np.seterr(divide='ignore')  # POT has issues with divide by zero errors
+    np.seterr(divide="ignore")  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -29,32 +29,24 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if mode == 'all':
-        runfor.extend(['wmd', 'snk'])
+    if mode == "all":
+        runfor.extend(["wmd", "snk"])
     else:
         runfor.append(mode)
 
     defs_source = [
-        line.rstrip('\n')
-        for line in open(source_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(source_defs_filename, encoding="utf8")
     ]
     defs_target = [
-        line.rstrip('\n')
-        for line in open(target_defs_filename, encoding='utf8')
+        line.rstrip("\n") for line in open(target_defs_filename, encoding="utf8")
     ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_source.keys()),
-        defs_source,
-        vectors_source,
-        source_lang,
+        set(vectors_source.keys()), defs_source, vectors_source, source_lang
    )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
-        set(vectors_target.keys()),
-        defs_target,
-        vectors_target,
-        target_lang,
+        set(vectors_target.keys()), defs_target, vectors_target, target_lang
     )
 
     take = args.instances
@@ -70,14 +62,15 @@ def main(args):
 
     if not batch:
         print(
-            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+            f"{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}"
         )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
     common = [
-        word for word in vec.get_feature_names()
+        word
+        for word in vec.get_feature_names()
         if word in clean_src_vectors or word in clean_target_vectors
     ]
     W_common = []
@@ -88,9 +81,7 @@ def main(args):
             W_common.append(np.array(clean_target_vectors[w]))
 
     if not batch:
-        print(
-            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
-        )
+        print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}")
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -101,55 +92,57 @@ def main(args):
 
     for metric in runfor:
         if not batch:
-            print(f'{metric}: {source_lang} - {target_lang}')
+            print(f"{metric}: {source_lang} - {target_lang}")
 
-        clf = WassersteinRetriever(W_embed=W_common,
-                                   n_neighbors=5,
-                                   n_jobs=14,
-                                   sinkhorn=(metric == 'snk'))
+        clf = WassersteinRetriever(
+            W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk")
+        )
         clf.fit(X_train_idf[:instances], np.ones(instances))
-        p_at_one, percentage = clf.align(X_test_idf[:instances],
-                                         n_neighbors=instances)
+        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
 
         if not batch:
-            print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
+            print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%")
         else:
             fields = [
-                f'{source_lang}', f'{target_lang}', f'{instances}',
-                f'{p_at_one}', f'{percentage}'
+                f"{source_lang}",
+                f"{target_lang}",
+                f"{instances}",
+                f"{p_at_one}",
+                f"{percentage}",
             ]
-            with open(f'{metric}_retrieval_result.csv', 'a') as f:
+            with open(f"{metric}_retrieval_result.csv", "a") as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
 
 
 if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(
-        description='run retrieval using wmd or snk')
-    parser.add_argument('source_lang', help='source language short name')
-    parser.add_argument('target_lang', help='target language short name')
-    parser.add_argument('source_vector', help='path of the source vector')
-    parser.add_argument('target_vector', help='path of the target vector')
-    parser.add_argument('source_defs', help='path of the source definitions')
-    parser.add_argument('target_defs', help='path of the target definitions')
+    parser = argparse.ArgumentParser(description="run retrieval using wmd or snk")
+    parser.add_argument("source_lang", help="source language short name")
+    parser.add_argument("target_lang", help="target language short name")
+    parser.add_argument("source_vector", help="path of the source vector")
+    parser.add_argument("target_vector", help="path of the target vector")
+    parser.add_argument("source_defs", help="path of the source definitions")
+    parser.add_argument("target_defs", help="path of the target definitions")
     parser.add_argument(
-        '-b',
-        '--batch',
-        action='store_true',
-        help=
-        'running in batch (store results in csv) or running a single instance (output the results)'
+        "-b",
+        "--batch",
+        action="store_true",
+        help="running in batch (store results in csv) or running a single instance (output the results)",
     )
-    parser.add_argument('mode',
-                        choices=['all', 'wmd', 'snk'],
-                        default='all',
-                        help='which methods to run')
     parser.add_argument(
-        '-n',
-        '--instances',
-        help='number of instances in each language to retrieve',
+        "mode",
+        choices=["all", "wmd", "snk"],
+        default="all",
+        help="which methods to run",
+    )
+    parser.add_argument(
+        "-n",
+        "--instances",
+        help="number of instances in each language to retrieve",
         default=1000,
-        type=int)
+        type=int,
+    )
 
     args = parser.parse_args()
 
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 78bf9cf..60991b9 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -11,17 +11,20 @@ from sklearn.utils import check_array
 
 class WassersteinMatcher(KNeighborsClassifier):
     """
-    Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric.
-    Source and target distributions are l_1 normalized before computing the Wasserstein distance.
-    Wasserstein is parametrized by the distances between the individual points of the distributions.
+    Source and target distributions are l_1 normalized before computing the Wasserstein
+    distance. Wasserstein is parametrized by the distances between the individual
+    points of the distributions.
     """
-    def __init__(self,
-                 W_embed,
-                 n_neighbors=1,
-                 n_jobs=1,
-                 verbose=False,
-                 sinkhorn=False,
-                 sinkhorn_reg=0.1):
+
+    def __init__(
+        self,
+        W_embed,
+        n_neighbors=1,
+        n_jobs=1,
+        verbose=False,
+        sinkhorn=False,
+        sinkhorn_reg=0.1,
+    ):
         """
         Initialization of the class.
         Arguments
@@ -33,10 +36,12 @@ class WassersteinMatcher(KNeighborsClassifier):
         self.sinkhorn_reg = sinkhorn_reg
         self.W_embed = W_embed
         self.verbose = verbose
-        super(WassersteinMatcher, self).__init__(n_neighbors=n_neighbors,
-                                                 n_jobs=n_jobs,
-                                                 metric='precomputed',
-                                                 algorithm='brute')
+        super(WassersteinMatcher, self).__init__(
+            n_neighbors=n_neighbors,
+            n_jobs=n_jobs,
+            metric="precomputed",
+            algorithm="brute",
+        )
 
     def _wmd(self, i, row, X_train):
         union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -51,7 +56,7 @@ class WassersteinMatcher(KNeighborsClassifier):
                 W_dist,
                 self.sinkhorn_reg,
                 numItermax=50,
-                method='sinkhorn_stabilized',
+                method="sinkhorn_stabilized",
             )[0]
         else:
             return ot.emd2(bow_i, bow_j, W_dist)
@@ -66,27 +71,27 @@ class WassersteinMatcher(KNeighborsClassifier):
 
         if X_train is None:
             X_train = self._fit_X
-        pool = Pool(nodes=self.n_jobs
-                    )  # Parallelization of the calculation of the distances
+        pool = Pool(
+            nodes=self.n_jobs
+        )  # Parallelization of the calculation of the distances
         dist = pool.map(self._wmd_row, X_test)
         return np.array(dist)
 
     def fit(self, X, y):  # X_train_idf
-        X = check_array(X, accept_sparse='csr',
-                        copy=True)  # check if array is sparse
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)  # check if array is sparse
+        X = normalize(X, norm="l1", copy=False)
         return super(WassersteinMatcher, self).fit(X, y)
 
     def predict(self, X):
-        X = check_array(X, accept_sparse='csr', copy=True)
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)
+        X = normalize(X, norm="l1", copy=False)
         dist = self._pairwise_wmd(X)
         dist = dist * 1000  # for lapjv, small floating point numbers are evil
         return super(WassersteinMatcher, self).predict(dist)
 
     def kneighbors(self, X, n_neighbors=1):
-        X = check_array(X, accept_sparse='csr', copy=True)
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)
+        X = normalize(X, norm="l1", copy=False)
         dist = self._pairwise_wmd(X)
         dist = dist * 1000  # for lapjv, small floating point numbers are evil
         return lapjv(dist)
@@ -102,19 +107,24 @@ class WassersteinMatcher(KNeighborsClassifier):
         percentage = p_at_one / n_neighbors * 100
         return p_at_one, percentage
 
+
 class WassersteinRetriever(KNeighborsClassifier):
     """
-    Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric.
-    Source and target distributions are l_1 normalized before computing the Wasserstein distance.
-    Wasserstein is parametrized by the distances between the individual points of the distributions.
+    Implements a nearest neighbors classifier for input distributions using
+    the Wasserstein distance as metric. Source and target distributions
+    are l_1 normalized before computing the Wasserstein distance. Wasserstein is
+    parametrized by the distances between the individual points of the distributions.
     """
-    def __init__(self,
-                 W_embed,
-                 n_neighbors=1,
-                 n_jobs=1,
-                 verbose=False,
-                 sinkhorn=False,
-                 sinkhorn_reg=0.1):
+
+    def __init__(
+        self,
+        W_embed,
+        n_neighbors=1,
+        n_jobs=1,
+        verbose=False,
+        sinkhorn=False,
+        sinkhorn_reg=0.1,
+    ):
         """
         Initialization of the class.
         Arguments
@@ -126,10 +136,12 @@ class WassersteinRetriever(KNeighborsClassifier):
         self.sinkhorn_reg = sinkhorn_reg
         self.W_embed = W_embed
         self.verbose = verbose
-        super(WassersteinRetriever, self).__init__(n_neighbors=n_neighbors,
-                                                   n_jobs=n_jobs,
-                                                   metric='precomputed',
-                                                   algorithm='brute')
+        super(WassersteinRetriever, self).__init__(
+            n_neighbors=n_neighbors,
+            n_jobs=n_jobs,
+            metric="precomputed",
+            algorithm="brute",
+        )
 
     def _wmd(self, i, row, X_train):
         union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -144,7 +156,7 @@ class WassersteinRetriever(KNeighborsClassifier):
                 W_dist,
                 self.sinkhorn_reg,
                 numItermax=50,
-                method='sinkhorn_stabilized',
+                method="sinkhorn_stabilized",
             )[0]
         else:
             return ot.emd2(bow_i, bow_j, W_dist)
@@ -164,19 +176,19 @@ class WassersteinRetriever(KNeighborsClassifier):
         return np.array(dist)
 
     def fit(self, X, y):
-        X = check_array(X, accept_sparse='csr', copy=True)
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)
+        X = normalize(X, norm="l1", copy=False)
         return super(WassersteinRetriever, self).fit(X, y)
 
     def predict(self, X):
-        X = check_array(X, accept_sparse='csr', copy=True)
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)
+        X = normalize(X, norm="l1", copy=False)
         dist = self._pairwise_wmd(X)
         return super(WassersteinRetriever, self).predict(dist)
 
     def kneighbors(self, X, n_neighbors=1):
-        X = check_array(X, accept_sparse='csr', copy=True)
-        X = normalize(X, norm='l1', copy=False)
+        X = check_array(X, accept_sparse="csr", copy=True)
+        X = normalize(X, norm="l1", copy=False)
         dist = self._pairwise_wmd(X)
         return super(WassersteinRetriever, self).kneighbors(dist, n_neighbors)
 
@@ -199,9 +211,9 @@ def load_embeddings(path, dimension=300):
     The first line may or may not include the word count and dimension
     """
     vectors = {}
-    with open(path, mode='r', encoding='utf8') as fp:
-        first_line = fp.readline().rstrip('\n')
-        if first_line.count(' ') == 1:
+    with open(path, mode="r", encoding="utf8") as fp:
+        first_line = fp.readline().rstrip("\n")
+        if first_line.count(" ") == 1:
             # includes the "word_count dimension" information
             (_, dimension) = map(int, first_line.split())
         else:
@@ -209,22 +221,19 @@ def load_embeddings(path, dimension=300):
         fp.seek(0)
         for line in fp:
             elems = line.split()
-            vectors[" ".join(elems[:-dimension])] = " ".join(
-                elems[-dimension:])
+            vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
     return vectors
 
 
 def clean_corpus_using_embeddings_vocabulary(
-        embeddings_dictionary,
-        corpus,
-        vectors,
-        language,
+    embeddings_dictionary, corpus, vectors, language
 ):
-    '''
+    """
     Cleans corpus using the dictionary of embeddings.
     Any word without an associated embedding in the dictionary is ignored.
-    Adds '__target-language' and '__source-language' at the end of the words according to their language.
-    '''
+    Adds '__target-language' and '__source-language' at the end
+    of the words according to their language.
+    """
     clean_corpus, clean_vectors, keys = [], {}, []
     words_we_want = set(embeddings_dictionary)
     tokenize = MosesTokenizer(language)
@@ -233,19 +242,18 @@ def clean_corpus_using_embeddings_vocabulary(
         words = tokenize(doc)
         for word in words:
             if word in words_we_want:
-                clean_doc.append(word + '__%s' % language)
-                clean_vectors[word + '__%s' % language] = np.array(
-                    vectors[word].split()).astype(np.float)
+                clean_doc.append(word + "__%s" % language)
+                clean_vectors[word + "__%s" % language] = np.array(
+                    vectors[word].split()
+                ).astype(np.float)
         if len(clean_doc) > 3 and len(clean_doc) < 25:
             keys.append(key)
-            clean_corpus.append(' '.join(clean_doc))
+            clean_corpus.append(" ".join(clean_doc))
     tokenize.close()
     return np.array(clean_corpus), clean_vectors, keys
 
 
-def mrr_precision_at_k(golden, preds, k_list=[
-    1,
-]):
+def mrr_precision_at_k(golden, preds, k_list=[1]):
     """
     Calculates Mean Reciprocal Error and Hits@1 == Precision@1
     """