 WMD_matching.py         | 64 +-
 Wasserstein_Distance.py | 65 +
 2 files changed, 66 insertions(+), 63 deletions(-)
diff --git a/WMD_matching.py b/WMD_matching.py
index 38dbff4..8581ffe 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -6,69 +6,7 @@ import random
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
 from Wasserstein_Distance import Wasserstein_Matcher
-
-def load_embeddings(path, dimension=300):
-    """
-    Loads the embeddings from a word2vec-formatted file.
-    word2vec format is one line per word and its associated embedding
-    (dimension floating-point numbers) separated by spaces.
-    The first line may or may not include the word count and dimension.
-    """
-    vectors = {}
-    with open(path, mode='r', encoding='utf8') as fp:
-        first_line = fp.readline().rstrip('\n')
-        if first_line.count(' ') == 1:
-            # the first line carries the "word_count dimension" header
-            (word_count, dimension) = map(int, first_line.split())
-        else:
-            # assume the file only contains vectors
-            fp.seek(0)
-        for line in fp:
-            elems = line.split()
-            vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
-    return vectors
-
-def clean_corpus_using_embeddings_vocabulary(
-        embeddings_dictionary,
-        corpus,
-        vectors,
-        language,
-        ):
-    '''
-    Cleans the corpus using the dictionary of embeddings.
-    Any word without an associated embedding in the dictionary is ignored.
-    Appends '__' plus the language code to each kept word to mark whether it is source or target language.
-    '''
-    clean_corpus, clean_vectors, keys = [], {}, []
-    words_we_want = set(embeddings_dictionary)
-    tokenize = MosesTokenizer(language)
-    for key, doc in enumerate(corpus):
-        clean_doc = []
-        words = tokenize(doc)
-        for word in words:
-            if word in words_we_want:
-                clean_doc.append(word + '__%s' % language)
-                clean_vectors[word + '__%s' % language] = np.array(vectors[word].split()).astype(float)
-        if len(clean_doc) > 3 and len(clean_doc) < 25:
-            keys.append(key)
-            clean_corpus.append(' '.join(clean_doc))
-    tokenize.close()
-    return np.array(clean_corpus), clean_vectors, keys
-
-def mrr_precision_at_k(golden, preds, k_list=[1,]):
-    """
-    Calculates Mean Reciprocal Rank (MRR) and Hits@1 == Precision@1.
-    """
-    my_score = 0
-    precision_at = np.zeros(len(k_list))
-    for key, elem in enumerate(golden):
-        if elem in preds[key]:
-            location = np.where(preds[key] == elem)[0][0]
-            my_score += 1 / (1 + location)
-            for k_index, k_value in enumerate(k_list):
-                if location < k_value:
-                    precision_at[k_index] += 1
-    return my_score / len(golden), (precision_at / len(golden))[0]
+from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary
 
 def main(args):
 
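With the helpers moved out, WMD_matching.py consumes them through the import added above. A minimal sketch of the intended call pattern, assuming a word2vec-formatted embedding file; the path and the sentence below are placeholders, not files from this repository:

    from Wasserstein_Distance import (
        load_embeddings,
        clean_corpus_using_embeddings_vocabulary,
    )

    # Placeholder embedding file in word2vec text format.
    vectors = load_embeddings('wiki.en.vec', dimension=300)

    # Placeholder corpus: words without an embedding are dropped, and only
    # documents with 4-24 surviving tokens are kept (see the helper above).
    corpus = ['the quick brown fox jumps over the lazy dog']
    clean_corpus, clean_vectors, keys = clean_corpus_using_embeddings_vocabulary(
        set(vectors),  # the vocabulary that has embeddings
        corpus,
        vectors,
        'en',          # appended to every kept word as '__en'
    )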
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index d2a6408..d8d08b8 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -138,3 +138,68 @@ class Wasserstein_Retriever(KNeighborsClassifier):
         return super(Wasserstein_Retriever, self).kneighbors(dist, n_neighbors)
 
 
+def load_embeddings(path, dimension=300):
+    """
+    Loads the embeddings from a word2vec-formatted file.
+    word2vec format is one line per word and its associated embedding
+    (dimension floating-point numbers) separated by spaces.
+    The first line may or may not include the word count and dimension.
+    """
+    vectors = {}
+    with open(path, mode='r', encoding='utf8') as fp:
+        first_line = fp.readline().rstrip('\n')
+        if first_line.count(' ') == 1:
+            # the first line carries the "word_count dimension" header
+            (word_count, dimension) = map(int, first_line.split())
+        else:
+            # assume the file only contains vectors
+            fp.seek(0)
+        for line in fp:
+            elems = line.split()
+            vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
+    return vectors
+
+
+def clean_corpus_using_embeddings_vocabulary(
+        embeddings_dictionary,
+        corpus,
+        vectors,
+        language,
+        ):
+    '''
+    Cleans the corpus using the dictionary of embeddings.
+    Any word without an associated embedding in the dictionary is ignored.
+    Appends '__' plus the language code to each kept word to mark whether it is source or target language.
+    '''
+    clean_corpus, clean_vectors, keys = [], {}, []
+    words_we_want = set(embeddings_dictionary)
+    tokenize = MosesTokenizer(language)
+    for key, doc in enumerate(corpus):
+        clean_doc = []
+        words = tokenize(doc)
+        for word in words:
+            if word in words_we_want:
+                clean_doc.append(word + '__%s' % language)
+                clean_vectors[word + '__%s' % language] = np.array(vectors[word].split()).astype(float)
+        if len(clean_doc) > 3 and len(clean_doc) < 25:
+            keys.append(key)
+            clean_corpus.append(' '.join(clean_doc))
+    tokenize.close()
+    return np.array(clean_corpus), clean_vectors, keys
+
+
+def mrr_precision_at_k(golden, preds, k_list=[1,]):
+    """
+    Calculates Mean Reciprocal Rank (MRR) and Hits@1 == Precision@1.
+    """
+    my_score = 0
+    precision_at = np.zeros(len(k_list))
+    for key, elem in enumerate(golden):
+        if elem in preds[key]:
+            location = np.where(preds[key] == elem)[0][0]
+            my_score += 1 / (1 + location)
+            for k_index, k_value in enumerate(k_list):
+                if location < k_value:
+                    precision_at[k_index] += 1
+    return my_score / len(golden), (precision_at / len(golden))[0]
+
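For reference, load_embeddings handles both word2vec layouts: it sniffs the optional "word_count dimension" header by counting the spaces on the first line and rewinds when the header is absent. A toy illustration; the file name and its contents are invented:

    from Wasserstein_Distance import load_embeddings

    # Hypothetical 'toy.vec' (dimension 3, header present):
    #
    #   2 3
    #   cat 0.1 0.2 0.3
    #   dog 0.4 0.5 0.6
    vectors = load_embeddings('toy.vec', dimension=3)
    print(vectors['cat'])  # '0.1 0.2 0.3' -- values stay a space-joined string
    # clean_corpus_using_embeddings_vocabulary later calls vectors[word].split()
    # to turn that string into a float array.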
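Likewise, mrr_precision_at_k can be checked by hand on toy rankings; the expected numbers below follow directly from its definition (mean reciprocal rank of the gold item, plus Hits@1):

    import numpy as np
    from Wasserstein_Distance import mrr_precision_at_k

    golden = [2, 0, 5]
    preds = [
        np.array([2, 7, 9]),  # gold id 2 at rank 1 -> reciprocal rank 1.0, Hits@1
        np.array([3, 0, 8]),  # gold id 0 at rank 2 -> reciprocal rank 0.5
        np.array([1, 4, 6]),  # gold id 5 absent    -> contributes 0
    ]

    mrr, hits_at_1 = mrr_precision_at_k(golden, preds)
    print(mrr)        # (1.0 + 0.5 + 0.0) / 3 = 0.5
    print(hits_at_1)  # 1/3: only the first query ranks its gold id first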