diff options
Diffstat (limited to 'WMD_matching.py')
| -rw-r--r-- | WMD_matching.py | 64 |
1 file changed, 1 insertion, 63 deletions
diff --git a/WMD_matching.py b/WMD_matching.py index 38dbff4..8581ffe 100644 --- a/WMD_matching.py +++ b/WMD_matching.py | |||
| @@ -6,69 +6,7 @@ import random | |||
| 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer |
| 7 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
| 8 | from Wasserstein_Distance import Wasserstein_Matcher | 8 | from Wasserstein_Distance import Wasserstein_Matcher |
| 9 | 9 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary | |
| 10 | def load_embeddings(path, dimension=300): | ||
| 11 | """ | ||
| 12 | Loads the embeddings from a word2vec formatted file. | ||
| 13 | word2vec format is one line per word and it's associated embedding | ||
| 14 | (dimension x floating numbers) separated by spaces | ||
| 15 | The first line may or may not include the word count and dimension | ||
| 16 | """ | ||
| 17 | vectors = {} | ||
| 18 | with open(path, mode='r', encoding='utf8') as fp: | ||
| 19 | first_line = fp.readline().rstrip('\n') | ||
| 20 | if first_line.count(' ') == 1: | ||
| 21 | # includes the "word_count dimension" information | ||
| 22 | (word_count, dimension) = map(int, first_line.split()) | ||
| 23 | else: | ||
| 24 | # assume the file only contains vectors | ||
| 25 | fp.seek(0) | ||
| 26 | for line in fp: | ||
| 27 | elems = line.split() | ||
| 28 | vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:]) | ||
| 29 | return vectors | ||
| 30 | |||
| 31 | def clean_corpus_using_embeddings_vocabulary( | ||
| 32 | embeddings_dictionary, | ||
| 33 | corpus, | ||
| 34 | vectors, | ||
| 35 | language, | ||
| 36 | ): | ||
| 37 | ''' | ||
| 38 | Cleans corpus using the dictionary of embeddings. | ||
| 39 | Any word without an associated embedding in the dictionary is ignored. | ||
| 40 | Adds '__target-language' and '__source-language' at the end of the words according to their language. | ||
| 41 | ''' | ||
| 42 | clean_corpus, clean_vectors, keys = [], {}, [] | ||
| 43 | words_we_want = set(embeddings_dictionary) | ||
| 44 | tokenize = MosesTokenizer(language) | ||
| 45 | for key, doc in enumerate(corpus): | ||
| 46 | clean_doc = [] | ||
| 47 | words = tokenize(doc) | ||
| 48 | for word in words: | ||
| 49 | if word in words_we_want: | ||
| 50 | clean_doc.append(word + '__%s' % language) | ||
| 51 | clean_vectors[word + '__%s' % language] = np.array(vectors[word].split()).astype(np.float) | ||
| 52 | if len(clean_doc) > 3 and len(clean_doc) < 25: | ||
| 53 | keys.append(key) | ||
| 54 | clean_corpus.append(' '.join(clean_doc)) | ||
| 55 | tokenize.close() | ||
| 56 | return np.array(clean_corpus), clean_vectors, keys | ||
| 57 | |||
| 58 | def mrr_precision_at_k(golden, preds, k_list=[1,]): | ||
| 59 | """ | ||
| 60 | Calculates Mean Reciprocal Error and Hits@1 == Precision@1 | ||
| 61 | """ | ||
| 62 | my_score = 0 | ||
| 63 | precision_at = np.zeros(len(k_list)) | ||
| 64 | for key, elem in enumerate(golden): | ||
| 65 | if elem in preds[key]: | ||
| 66 | location = np.where(preds[key]==elem)[0][0] | ||
| 67 | my_score += 1/(1+ location) | ||
| 68 | for k_index, k_value in enumerate(k_list): | ||
| 69 | if location < k_value: | ||
| 70 | precision_at[k_index] += 1 | ||
| 71 | return my_score/len(golden), (precision_at/len(golden))[0] | ||
| 72 | 10 | ||
| 73 | def main(args): | 11 | def main(args): |
| 74 | 12 | ||
