diff options
Diffstat (limited to 'Wasserstein_Distance.py')
-rw-r--r-- | Wasserstein_Distance.py | 65 |
1 file changed, 65 insertions, 0 deletions
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py index d2a6408..d8d08b8 100644 --- a/Wasserstein_Distance.py +++ b/Wasserstein_Distance.py | |||
@@ -138,3 +138,68 @@ class Wasserstein_Retriever(KNeighborsClassifier): | |||
138 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) | 138 | return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors) |
139 | 139 | ||
140 | 140 | ||
def load_embeddings(path, dimension=300):
    """
    Load embeddings from a word2vec formatted file.

    Each line holds a word followed by `dimension` space-separated floats.
    An optional first line may instead carry "word_count dimension"; when
    present, the dimension stated there overrides the argument.

    Returns a dict mapping each word to its raw space-separated value string.
    """
    embeddings = {}
    with open(path, mode='r', encoding='utf8') as handle:
        header = handle.readline().rstrip('\n')
        if header.count(' ') == 1:
            # header line: "word_count dimension" (word count unused here)
            _, dimension = (int(part) for part in header.split())
        else:
            # no header: rewind so the first vector line is not skipped
            handle.seek(0)
        for row in handle:
            fields = row.split()
            # words may themselves contain spaces: everything before the
            # last `dimension` fields is the word
            word = " ".join(fields[:-dimension])
            embeddings[word] = " ".join(fields[-dimension:])
    return embeddings
161 | |||
162 | |||
def clean_corpus_using_embeddings_vocabulary(
        embeddings_dictionary,
        corpus,
        vectors,
        language,
):
    '''
    Cleans corpus using the dictionary of embeddings.
    Any word without an associated embedding in the dictionary is ignored.
    Appends a '__<language>' suffix to every kept word so source- and
    target-language tokens stay distinguishable when corpora are merged.

    Parameters
    ----------
    embeddings_dictionary : iterable of str
        Words that have an embedding; used as the filtering vocabulary.
    corpus : iterable of str
        Documents to clean, one string per document.
    vectors : mapping of str -> str
        Word -> space-separated float string (word2vec style values).
    language : str
        Language tag appended to every kept token.

    Returns
    -------
    tuple
        (np.ndarray of cleaned documents, dict of tagged word -> float
        vector, list of indices of the documents that were kept).
        Only documents with 4 to 24 surviving tokens are kept.
    '''
    clean_corpus, clean_vectors, keys = [], {}, []
    words_we_want = set(embeddings_dictionary)
    tokenize = MosesTokenizer(language)
    for key, doc in enumerate(corpus):
        clean_doc = []
        words = tokenize(doc)
        for word in words:
            if word in words_we_want:
                tagged = word + '__%s' % language
                clean_doc.append(tagged)
                # np.float was removed in NumPy 1.20; use the builtin float dtype
                clean_vectors[tagged] = np.array(vectors[word].split(), dtype=float)
        # keep only documents of moderate length (4..24 tokens)
        if len(clean_doc) > 3 and len(clean_doc) < 25:
            keys.append(key)
            clean_corpus.append(' '.join(clean_doc))
    tokenize.close()
    return np.array(clean_corpus), clean_vectors, keys
189 | |||
190 | |||
def mrr_precision_at_k(golden, preds, k_list=(1,)):
    """
    Calculates Mean Reciprocal Rank and Hits@1 == Precision@1.

    Parameters
    ----------
    golden : sequence
        Ground-truth item for each query.
    preds : sequence of array-like
        Ranked predictions per query; index 0 is the top-ranked item.
    k_list : sequence of int, optional
        Cut-offs for precision@k. Default (1,); immutable to avoid the
        mutable-default-argument pitfall.

    Returns
    -------
    tuple of float
        Mean reciprocal rank over all queries, and precision@k for the
        FIRST cut-off in `k_list` (kept for backward compatibility).
    """
    my_score = 0
    precision_at = np.zeros(len(k_list))
    for key, elem in enumerate(golden):
        if elem in preds[key]:
            # 0-based rank of the correct answer in the prediction list
            location = np.where(preds[key] == elem)[0][0]
            my_score += 1 / (1 + location)
            for k_index, k_value in enumerate(k_list):
                if location < k_value:
                    precision_at[k_index] += 1
    return my_score / len(golden), (precision_at / len(golden))[0]
205 | |||