about summary refs log tree commit diff stats
path: root/Wasserstein_Distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'Wasserstein_Distance.py')
-rw-r--r-- Wasserstein_Distance.py 136
1 files changed, 72 insertions, 64 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 78bf9cf..60991b9 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -11,17 +11,20 @@ from sklearn.utils import check_array
11 11
12class WassersteinMatcher(KNeighborsClassifier): 12class WassersteinMatcher(KNeighborsClassifier):
13 """ 13 """
14 Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. 14 Source and target distributions are l_1 normalized before computing the Wasserstein
15 Source and target distributions are l_1 normalized before computing the Wasserstein distance. 15 distance. Wasserstein is parametrized by the distances between the individual
16 Wasserstein is parametrized by the distances between the individual points of the distributions. 16 points of the distributions.
17 """ 17 """
18 def __init__(self, 18
19 W_embed, 19 def __init__(
20 n_neighbors=1, 20 self,
21 n_jobs=1, 21 W_embed,
22 verbose=False, 22 n_neighbors=1,
23 sinkhorn=False, 23 n_jobs=1,
24 sinkhorn_reg=0.1): 24 verbose=False,
25 sinkhorn=False,
26 sinkhorn_reg=0.1,
27 ):
25 """ 28 """
26 Initialization of the class. 29 Initialization of the class.
27 Arguments 30 Arguments
@@ -33,10 +36,12 @@ class WassersteinMatcher(KNeighborsClassifier):
33 self.sinkhorn_reg = sinkhorn_reg 36 self.sinkhorn_reg = sinkhorn_reg
34 self.W_embed = W_embed 37 self.W_embed = W_embed
35 self.verbose = verbose 38 self.verbose = verbose
36 super(WassersteinMatcher, self).__init__(n_neighbors=n_neighbors, 39 super(WassersteinMatcher, self).__init__(
37 n_jobs=n_jobs, 40 n_neighbors=n_neighbors,
38 metric='precomputed', 41 n_jobs=n_jobs,
39 algorithm='brute') 42 metric="precomputed",
43 algorithm="brute",
44 )
40 45
41 def _wmd(self, i, row, X_train): 46 def _wmd(self, i, row, X_train):
42 union_idx = np.union1d(X_train[i].indices, row.indices) 47 union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -51,7 +56,7 @@ class WassersteinMatcher(KNeighborsClassifier):
51 W_dist, 56 W_dist,
52 self.sinkhorn_reg, 57 self.sinkhorn_reg,
53 numItermax=50, 58 numItermax=50,
54 method='sinkhorn_stabilized', 59 method="sinkhorn_stabilized",
55 )[0] 60 )[0]
56 else: 61 else:
57 return ot.emd2(bow_i, bow_j, W_dist) 62 return ot.emd2(bow_i, bow_j, W_dist)
@@ -66,27 +71,27 @@ class WassersteinMatcher(KNeighborsClassifier):
66 71
67 if X_train is None: 72 if X_train is None:
68 X_train = self._fit_X 73 X_train = self._fit_X
69 pool = Pool(nodes=self.n_jobs 74 pool = Pool(
70 ) # Parallelization of the calculation of the distances 75 nodes=self.n_jobs
76 ) # Parallelization of the calculation of the distances
71 dist = pool.map(self._wmd_row, X_test) 77 dist = pool.map(self._wmd_row, X_test)
72 return np.array(dist) 78 return np.array(dist)
73 79
74 def fit(self, X, y): # X_train_idf 80 def fit(self, X, y): # X_train_idf
75 X = check_array(X, accept_sparse='csr', 81 X = check_array(X, accept_sparse="csr", copy=True) # check if array is sparse
76 copy=True) # check if array is sparse 82 X = normalize(X, norm="l1", copy=False)
77 X = normalize(X, norm='l1', copy=False)
78 return super(WassersteinMatcher, self).fit(X, y) 83 return super(WassersteinMatcher, self).fit(X, y)
79 84
80 def predict(self, X): 85 def predict(self, X):
81 X = check_array(X, accept_sparse='csr', copy=True) 86 X = check_array(X, accept_sparse="csr", copy=True)
82 X = normalize(X, norm='l1', copy=False) 87 X = normalize(X, norm="l1", copy=False)
83 dist = self._pairwise_wmd(X) 88 dist = self._pairwise_wmd(X)
84 dist = dist * 1000 # for lapjv, small floating point numbers are evil 89 dist = dist * 1000 # for lapjv, small floating point numbers are evil
85 return super(WassersteinMatcher, self).predict(dist) 90 return super(WassersteinMatcher, self).predict(dist)
86 91
87 def kneighbors(self, X, n_neighbors=1): 92 def kneighbors(self, X, n_neighbors=1):
88 X = check_array(X, accept_sparse='csr', copy=True) 93 X = check_array(X, accept_sparse="csr", copy=True)
89 X = normalize(X, norm='l1', copy=False) 94 X = normalize(X, norm="l1", copy=False)
90 dist = self._pairwise_wmd(X) 95 dist = self._pairwise_wmd(X)
91 dist = dist * 1000 # for lapjv, small floating point numbers are evil 96 dist = dist * 1000 # for lapjv, small floating point numbers are evil
92 return lapjv(dist) 97 return lapjv(dist)
@@ -102,19 +107,24 @@ class WassersteinMatcher(KNeighborsClassifier):
102 percentage = p_at_one / n_neighbors * 100 107 percentage = p_at_one / n_neighbors * 100
103 return p_at_one, percentage 108 return p_at_one, percentage
104 109
110
105class WassersteinRetriever(KNeighborsClassifier): 111class WassersteinRetriever(KNeighborsClassifier):
106 """ 112 """
107 Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. 113 Implements a nearest neighbors classifier for input distributions using
108 Source and target distributions are l_1 normalized before computing the Wasserstein distance. 114 the Wasserstein distance as metric. Source and target distributions
109 Wasserstein is parametrized by the distances between the individual points of the distributions. 115 are l_1 normalized before computing the Wasserstein distance. Wasserstein is
116 parametrized by the distances between the individual points of the distributions.
110 """ 117 """
111 def __init__(self, 118
112 W_embed, 119 def __init__(
113 n_neighbors=1, 120 self,
114 n_jobs=1, 121 W_embed,
115 verbose=False, 122 n_neighbors=1,
116 sinkhorn=False, 123 n_jobs=1,
117 sinkhorn_reg=0.1): 124 verbose=False,
125 sinkhorn=False,
126 sinkhorn_reg=0.1,
127 ):
118 """ 128 """
119 Initialization of the class. 129 Initialization of the class.
120 Arguments 130 Arguments
@@ -126,10 +136,12 @@ class WassersteinRetriever(KNeighborsClassifier):
126 self.sinkhorn_reg = sinkhorn_reg 136 self.sinkhorn_reg = sinkhorn_reg
127 self.W_embed = W_embed 137 self.W_embed = W_embed
128 self.verbose = verbose 138 self.verbose = verbose
129 super(WassersteinRetriever, self).__init__(n_neighbors=n_neighbors, 139 super(WassersteinRetriever, self).__init__(
130 n_jobs=n_jobs, 140 n_neighbors=n_neighbors,
131 metric='precomputed', 141 n_jobs=n_jobs,
132 algorithm='brute') 142 metric="precomputed",
143 algorithm="brute",
144 )
133 145
134 def _wmd(self, i, row, X_train): 146 def _wmd(self, i, row, X_train):
135 union_idx = np.union1d(X_train[i].indices, row.indices) 147 union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -144,7 +156,7 @@ class WassersteinRetriever(KNeighborsClassifier):
144 W_dist, 156 W_dist,
145 self.sinkhorn_reg, 157 self.sinkhorn_reg,
146 numItermax=50, 158 numItermax=50,
147 method='sinkhorn_stabilized', 159 method="sinkhorn_stabilized",
148 )[0] 160 )[0]
149 else: 161 else:
150 return ot.emd2(bow_i, bow_j, W_dist) 162 return ot.emd2(bow_i, bow_j, W_dist)
@@ -164,19 +176,19 @@ class WassersteinRetriever(KNeighborsClassifier):
164 return np.array(dist) 176 return np.array(dist)
165 177
166 def fit(self, X, y): 178 def fit(self, X, y):
167 X = check_array(X, accept_sparse='csr', copy=True) 179 X = check_array(X, accept_sparse="csr", copy=True)
168 X = normalize(X, norm='l1', copy=False) 180 X = normalize(X, norm="l1", copy=False)
169 return super(WassersteinRetriever, self).fit(X, y) 181 return super(WassersteinRetriever, self).fit(X, y)
170 182
171 def predict(self, X): 183 def predict(self, X):
172 X = check_array(X, accept_sparse='csr', copy=True) 184 X = check_array(X, accept_sparse="csr", copy=True)
173 X = normalize(X, norm='l1', copy=False) 185 X = normalize(X, norm="l1", copy=False)
174 dist = self._pairwise_wmd(X) 186 dist = self._pairwise_wmd(X)
175 return super(WassersteinRetriever, self).predict(dist) 187 return super(WassersteinRetriever, self).predict(dist)
176 188
177 def kneighbors(self, X, n_neighbors=1): 189 def kneighbors(self, X, n_neighbors=1):
178 X = check_array(X, accept_sparse='csr', copy=True) 190 X = check_array(X, accept_sparse="csr", copy=True)
179 X = normalize(X, norm='l1', copy=False) 191 X = normalize(X, norm="l1", copy=False)
180 dist = self._pairwise_wmd(X) 192 dist = self._pairwise_wmd(X)
181 return super(WassersteinRetriever, self).kneighbors(dist, n_neighbors) 193 return super(WassersteinRetriever, self).kneighbors(dist, n_neighbors)
182 194
@@ -199,9 +211,9 @@ def load_embeddings(path, dimension=300):
199 The first line may or may not include the word count and dimension 211 The first line may or may not include the word count and dimension
200 """ 212 """
201 vectors = {} 213 vectors = {}
202 with open(path, mode='r', encoding='utf8') as fp: 214 with open(path, mode="r", encoding="utf8") as fp:
203 first_line = fp.readline().rstrip('\n') 215 first_line = fp.readline().rstrip("\n")
204 if first_line.count(' ') == 1: 216 if first_line.count(" ") == 1:
205 # includes the "word_count dimension" information 217 # includes the "word_count dimension" information
206 (_, dimension) = map(int, first_line.split()) 218 (_, dimension) = map(int, first_line.split())
207 else: 219 else:
@@ -209,22 +221,19 @@ def load_embeddings(path, dimension=300):
209 fp.seek(0) 221 fp.seek(0)
210 for line in fp: 222 for line in fp:
211 elems = line.split() 223 elems = line.split()
212 vectors[" ".join(elems[:-dimension])] = " ".join( 224 vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
213 elems[-dimension:])
214 return vectors 225 return vectors
215 226
216 227
217def clean_corpus_using_embeddings_vocabulary( 228def clean_corpus_using_embeddings_vocabulary(
218 embeddings_dictionary, 229 embeddings_dictionary, corpus, vectors, language
219 corpus,
220 vectors,
221 language,
222): 230):
223 ''' 231 """
224 Cleans corpus using the dictionary of embeddings. 232 Cleans corpus using the dictionary of embeddings.
225 Any word without an associated embedding in the dictionary is ignored. 233 Any word without an associated embedding in the dictionary is ignored.
226 Adds '__target-language' and '__source-language' at the end of the words according to their language. 234 Adds '__target-language' and '__source-language' at the end
227 ''' 235 of the words according to their language.
236 """
228 clean_corpus, clean_vectors, keys = [], {}, [] 237 clean_corpus, clean_vectors, keys = [], {}, []
229 words_we_want = set(embeddings_dictionary) 238 words_we_want = set(embeddings_dictionary)
230 tokenize = MosesTokenizer(language) 239 tokenize = MosesTokenizer(language)
@@ -233,19 +242,18 @@ def clean_corpus_using_embeddings_vocabulary(
233 words = tokenize(doc) 242 words = tokenize(doc)
234 for word in words: 243 for word in words:
235 if word in words_we_want: 244 if word in words_we_want:
236 clean_doc.append(word + '__%s' % language) 245 clean_doc.append(word + "__%s" % language)
237 clean_vectors[word + '__%s' % language] = np.array( 246 clean_vectors[word + "__%s" % language] = np.array(
238 vectors[word].split()).astype(np.float) 247 vectors[word].split()
248 ).astype(np.float)
239 if len(clean_doc) > 3 and len(clean_doc) < 25: 249 if len(clean_doc) > 3 and len(clean_doc) < 25:
240 keys.append(key) 250 keys.append(key)
241 clean_corpus.append(' '.join(clean_doc)) 251 clean_corpus.append(" ".join(clean_doc))
242 tokenize.close() 252 tokenize.close()
243 return np.array(clean_corpus), clean_vectors, keys 253 return np.array(clean_corpus), clean_vectors, keys
244 254
245 255
246def mrr_precision_at_k(golden, preds, k_list=[ 256def mrr_precision_at_k(golden, preds, k_list=[1]):
247 1,
248]):
249 """ 257 """
250 Calculates Mean Reciprocal Error and Hits@1 == Precision@1 258 Calculates Mean Reciprocal Error and Hits@1 == Precision@1
251 """ 259 """