diff options
Diffstat (limited to 'Wasserstein_Distance.py')
-rw-r--r-- | Wasserstein_Distance.py | 57 |
1 files changed, 33 insertions, 24 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py index 161c13c..78bf9cf 100644 --- a/Wasserstein_Distance.py +++ b/Wasserstein_Distance.py | |||
@@ -1,16 +1,15 @@ | |||
1 | import numpy as np | 1 | import numpy as np |
2 | from sklearn.metrics import euclidean_distances | ||
3 | from sklearn.neighbors import KNeighborsClassifier | ||
4 | from sklearn.preprocessing import normalize | ||
5 | from sklearn.utils import check_array | ||
6 | |||
7 | import ot | 2 | import ot |
8 | from lapjv import lapjv | 3 | from lapjv import lapjv |
9 | from mosestokenizer import MosesTokenizer | 4 | from mosestokenizer import MosesTokenizer |
10 | from pathos.multiprocessing import ProcessingPool as Pool | 5 | from pathos.multiprocessing import ProcessingPool as Pool |
6 | from sklearn.metrics import euclidean_distances | ||
7 | from sklearn.neighbors import KNeighborsClassifier | ||
8 | from sklearn.preprocessing import normalize | ||
9 | from sklearn.utils import check_array | ||
11 | 10 | ||
12 | 11 | ||
13 | class Wasserstein_Matcher(KNeighborsClassifier): | 12 | class WassersteinMatcher(KNeighborsClassifier): |
14 | """ | 13 | """ |
15 | Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. | 14 | Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. |
16 | Source and target distributions are l_1 normalized before computing the Wasserstein distance. | 15 | Source and target distributions are l_1 normalized before computing the Wasserstein distance. |
@@ -34,10 +33,10 @@ class Wasserstein_Matcher(KNeighborsClassifier): | |||
34 | self.sinkhorn_reg = sinkhorn_reg | 33 | self.sinkhorn_reg = sinkhorn_reg |
35 | self.W_embed = W_embed | 34 | self.W_embed = W_embed |
36 | self.verbose = verbose | 35 | self.verbose = verbose |
37 | super(Wasserstein_Matcher, self).__init__(n_neighbors=n_neighbors, | 36 | super(WassersteinMatcher, self).__init__(n_neighbors=n_neighbors, |
38 | n_jobs=n_jobs, | 37 | n_jobs=n_jobs, |
39 | metric='precomputed', | 38 | metric='precomputed', |
40 | algorithm='brute') | 39 | algorithm='brute') |
41 | 40 | ||
42 | def _wmd(self, i, row, X_train): | 41 | def _wmd(self, i, row, X_train): |
43 | union_idx = np.union1d(X_train[i].indices, row.indices) | 42 | union_idx = np.union1d(X_train[i].indices, row.indices) |
@@ -76,25 +75,34 @@ class Wasserstein_Matcher(KNeighborsClassifier): | |||
76 | X = check_array(X, accept_sparse='csr', | 75 | X = check_array(X, accept_sparse='csr', |
77 | copy=True) # check if array is sparse | 76 | copy=True) # check if array is sparse |
78 | X = normalize(X, norm='l1', copy=False) | 77 | X = normalize(X, norm='l1', copy=False) |
79 | return super(Wasserstein_Matcher, self).fit( | 78 | return super(WassersteinMatcher, self).fit(X, y) |
80 | X, y) # X_train_idf, np_ones(document collection size) | ||
81 | 79 | ||
82 | def predict(self, X): | 80 | def predict(self, X): |
83 | X = check_array(X, accept_sparse='csr', copy=True) | 81 | X = check_array(X, accept_sparse='csr', copy=True) |
84 | X = normalize(X, norm='l1', copy=False) | 82 | X = normalize(X, norm='l1', copy=False) |
85 | dist = self._pairwise_wmd(X) | 83 | dist = self._pairwise_wmd(X) |
86 | dist = dist * 1000 # for lapjv, small floating point numbers are evil | 84 | dist = dist * 1000 # for lapjv, small floating point numbers are evil |
87 | return super(Wasserstein_Matcher, self).predict(dist) | 85 | return super(WassersteinMatcher, self).predict(dist) |
88 | 86 | ||
89 | def kneighbors(self, X, n_neighbors=1): # X : X_train_idf | 87 | def kneighbors(self, X, n_neighbors=1): |
90 | X = check_array(X, accept_sparse='csr', copy=True) | 88 | X = check_array(X, accept_sparse='csr', copy=True) |
91 | X = normalize(X, norm='l1', copy=False) | 89 | X = normalize(X, norm='l1', copy=False) |
92 | dist = self._pairwise_wmd(X) | 90 | dist = self._pairwise_wmd(X) |
93 | dist = dist * 1000 # for lapjv, small floating point numbers are evil | 91 | dist = dist * 1000 # for lapjv, small floating point numbers are evil |
94 | return lapjv(dist) # and here is the matching part | 92 | return lapjv(dist) |
95 | 93 | ||
94 | def align(self, X, n_neighbors=1): | ||
95 | """ Wrapper function over kneighbors to return | ||
96 | precision at one and percentage values | ||
96 | 97 | ||
97 | class Wasserstein_Retriever(KNeighborsClassifier): | 98 | """ |
99 | row_ind, col_ind, _ = self.kneighbors(X, n_neighbors) | ||
100 | result = zip(row_ind, col_ind) | ||
101 | p_at_one = len([x for x, y in result if x == y]) | ||
102 | percentage = p_at_one / n_neighbors * 100 | ||
103 | return p_at_one, percentage | ||
104 | |||
105 | class WassersteinRetriever(KNeighborsClassifier): | ||
98 | """ | 106 | """ |
99 | Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. | 107 | Implements a nearest neighbors classifier for input distributions using the Wasserstein distance as metric. |
100 | Source and target distributions are l_1 normalized before computing the Wasserstein distance. | 108 | Source and target distributions are l_1 normalized before computing the Wasserstein distance. |
@@ -118,7 +126,7 @@ class Wasserstein_Retriever(KNeighborsClassifier): | |||
118 | self.sinkhorn_reg = sinkhorn_reg | 126 | self.sinkhorn_reg = sinkhorn_reg |
119 | self.W_embed = W_embed | 127 | self.W_embed = W_embed |
120 | self.verbose = verbose | 128 | self.verbose = verbose |
121 | super(Wasserstein_Retriever, self).__init__(n_neighbors=n_neighbors, | 129 | super(WassersteinRetriever, self).__init__(n_neighbors=n_neighbors, |
122 | n_jobs=n_jobs, | 130 | n_jobs=n_jobs, |
123 | metric='precomputed', | 131 | metric='precomputed', |
124 | algorithm='brute') | 132 | algorithm='brute') |
@@ -158,23 +166,22 @@ class Wasserstein_Retriever(KNeighborsClassifier): | |||
158 | def fit(self, X, y): | 166 | def fit(self, X, y): |
159 | X = check_array(X, accept_sparse='csr', copy=True) | 167 | X = check_array(X, accept_sparse='csr', copy=True) |
160 | X = normalize(X, norm='l1', copy=False) | 168 | X = normalize(X, norm='l1', copy=False) |
161 | return super(Wasserstein_Retriever, self).fit(X, y) | 169 | return super(WassersteinRetriever, self).fit(X, y) |
162 | 170 | ||
163 | def predict(self, X): | 171 | def predict(self, X): |
164 | X = check_array(X, accept_sparse='csr', copy=True) | 172 | X = check_array(X, accept_sparse='csr', copy=True) |
165 | X = normalize(X, norm='l1', copy=False) | 173 | X = normalize(X, norm='l1', copy=False) |
166 | dist = self._pairwise_wmd(X) | 174 | dist = self._pairwise_wmd(X) |
167 | return super(Wasserstein_Retriever, self).predict(dist) | 175 | return super(WassersteinRetriever, self).predict(dist) |
168 | 176 | ||
169 | def kneighbors(self, X, n_neighbors=1): | 177 | def kneighbors(self, X, n_neighbors=1): |
170 | X = check_array(X, accept_sparse='csr', copy=True) | 178 | X = check_array(X, accept_sparse='csr', copy=True) |
171 | X = normalize(X, norm='l1', copy=False) | 179 | X = normalize(X, norm='l1', copy=False) |
172 | dist = self._pairwise_wmd(X) | 180 | dist = self._pairwise_wmd(X) |
173 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) | 181 | return super(WassersteinRetriever, self).kneighbors(dist, n_neighbors) |
174 | 182 | ||
175 | def align(self, X, n_neighbors=1): | 183 | def align(self, X, n_neighbors=1): |
176 | """ | 184 | """ Wrapper function over kneighbors to return |
177 | Wrapper function over kneighbors to return | ||
178 | precision at one and percentage values | 185 | precision at one and percentage values |
179 | 186 | ||
180 | """ | 187 | """ |
@@ -196,7 +203,7 @@ def load_embeddings(path, dimension=300): | |||
196 | first_line = fp.readline().rstrip('\n') | 203 | first_line = fp.readline().rstrip('\n') |
197 | if first_line.count(' ') == 1: | 204 | if first_line.count(' ') == 1: |
198 | # includes the "word_count dimension" information | 205 | # includes the "word_count dimension" information |
199 | (word_count, dimension) = map(int, first_line.split()) | 206 | (_, dimension) = map(int, first_line.split()) |
200 | else: | 207 | else: |
201 | # assume the file only contains vectors | 208 | # assume the file only contains vectors |
202 | fp.seek(0) | 209 | fp.seek(0) |
@@ -236,7 +243,9 @@ def clean_corpus_using_embeddings_vocabulary( | |||
236 | return np.array(clean_corpus), clean_vectors, keys | 243 | return np.array(clean_corpus), clean_vectors, keys |
237 | 244 | ||
238 | 245 | ||
239 | def mrr_precision_at_k(golden, preds, k_list=[1,]): | 246 | def mrr_precision_at_k(golden, preds, k_list=[ |
247 | 1, | ||
248 | ]): | ||
240 | """ | 249 | """ |
241 | Calculates Mean Reciprocal Error and Hits@1 == Precision@1 | 250 | Calculates Mean Reciprocal Error and Hits@1 == Precision@1 |
242 | """ | 251 | """ |