| author | Yigit Sever | 2019-09-21 15:58:19 +0300 |
|---|---|---|
| committer | Yigit Sever | 2019-09-21 15:58:19 +0300 |
| commit | 778f3936edf3104660d23a88fe23da46c42709a4 (patch) | |
| tree | 9a90e327d3f2b4e5c9cf5855932cf26fdd74d139 /WMD_matching.py | |
| parent | 4cb6986480def9b0c91fb46e276839c60f96aa49 (diff) | |
Move functions to centralize
Diffstat (limited to 'WMD_matching.py')
| -rw-r--r-- | WMD_matching.py | 64 |

1 file changed, 1 insertion, 63 deletions
diff --git a/WMD_matching.py b/WMD_matching.py
index 38dbff4..8581ffe 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -6,69 +6,7 @@ import random
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
 from Wasserstein_Distance import Wasserstein_Matcher
-
-def load_embeddings(path, dimension=300):
-    """
-    Loads the embeddings from a word2vec formatted file.
-    word2vec format is one line per word and it's associated embedding
-    (dimension x floating numbers) separated by spaces
-    The first line may or may not include the word count and dimension
-    """
-    vectors = {}
-    with open(path, mode='r', encoding='utf8') as fp:
-        first_line = fp.readline().rstrip('\n')
-        if first_line.count(' ') == 1:
-            # includes the "word_count dimension" information
-            (word_count, dimension) = map(int, first_line.split())
-        else:
-            # assume the file only contains vectors
-            fp.seek(0)
-        for line in fp:
-            elems = line.split()
-            vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
-    return vectors
-
-def clean_corpus_using_embeddings_vocabulary(
-        embeddings_dictionary,
-        corpus,
-        vectors,
-        language,
-        ):
-    '''
-    Cleans corpus using the dictionary of embeddings.
-    Any word without an associated embedding in the dictionary is ignored.
-    Adds '__target-language' and '__source-language' at the end of the words according to their language.
-    '''
-    clean_corpus, clean_vectors, keys = [], {}, []
-    words_we_want = set(embeddings_dictionary)
-    tokenize = MosesTokenizer(language)
-    for key, doc in enumerate(corpus):
-        clean_doc = []
-        words = tokenize(doc)
-        for word in words:
-            if word in words_we_want:
-                clean_doc.append(word + '__%s' % language)
-                clean_vectors[word + '__%s' % language] = np.array(vectors[word].split()).astype(np.float)
-        if len(clean_doc) > 3 and len(clean_doc) < 25:
-            keys.append(key)
-            clean_corpus.append(' '.join(clean_doc))
-    tokenize.close()
-    return np.array(clean_corpus), clean_vectors, keys
-
-def mrr_precision_at_k(golden, preds, k_list=[1,]):
-    """
-    Calculates Mean Reciprocal Error and Hits@1 == Precision@1
-    """
-    my_score = 0
-    precision_at = np.zeros(len(k_list))
-    for key, elem in enumerate(golden):
-        if elem in preds[key]:
-            location = np.where(preds[key]==elem)[0][0]
-            my_score += 1/(1+ location)
-        for k_index, k_value in enumerate(k_list):
-            if location < k_value:
-                precision_at[k_index] += 1
-    return my_score/len(golden), (precision_at/len(golden))[0]
+from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary
 
 def main(args):
 
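For illustration, a minimal sketch of the word2vec text format that the relocated `load_embeddings` parses: an optional "word_count dimension" header line, then one word per line followed by its space-separated float components. The toy file below is written to a temporary path purely for this example and is not part of the repository.

```python
import tempfile

from Wasserstein_Distance import load_embeddings

# Build a toy word2vec-format file: header "2 3" (2 words, 3 dimensions),
# then one word and its vector per line.
with tempfile.NamedTemporaryFile('w', suffix='.vec', delete=False,
                                 encoding='utf8') as fp:
    fp.write('2 3\n')
    fp.write('cat 0.1 0.2 0.3\n')
    fp.write('dog 0.4 0.5 0.6\n')
    path = fp.name

vectors = load_embeddings(path, dimension=3)
print(vectors['cat'])  # '0.1 0.2 0.3' -- the vector is kept as a string
```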
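Likewise, a hedged sketch of driving the corpus-cleaning helper through the centralized import. The tiny inline embeddings dictionary and one-sentence corpus are placeholders (a real run would load a full `.vec` file via `load_embeddings`), and note the helper's body uses `np.float`, which was removed in NumPy 1.24, so it needs an older NumPy.

```python
from Wasserstein_Distance import clean_corpus_using_embeddings_vocabulary

# Toy 3-dimensional "embeddings": word -> space-separated floats, the same
# shape that load_embeddings produces.
vectors_en = {
    'cat': '0.1 0.2 0.3',
    'dog': '0.4 0.5 0.6',
    'the': '0.7 0.8 0.9',
    'sat': '0.2 0.1 0.0',
    'mat': '0.3 0.3 0.3',
}

corpus_en = ['the cat sat on the mat']  # 'on' has no vector and is dropped

# Surviving words are tagged '__en'; only documents whose cleaned length is
# strictly between 3 and 25 tokens are kept.
clean_corpus, clean_vectors, keys = clean_corpus_using_embeddings_vocabulary(
    set(vectors_en.keys()),  # vocabulary used for filtering
    corpus_en,
    vectors_en,
    'en',
)
print(clean_corpus)  # ['the__en cat__en sat__en the__en mat__en']
print(keys)          # [0] -- document 0 passed the length filter
```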
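Finally, a tiny worked example of the `mrr_precision_at_k` contract, on the assumption that it was moved into Wasserstein_Distance.py along with the other two helpers (the import added in this hunk names only `load_embeddings` and `clean_corpus_using_embeddings_vocabulary`). Despite its docstring saying "Mean Reciprocal Error", the function computes Mean Reciprocal Rank plus Precision@1: `golden[i]` is the correct candidate index for query `i`, and `preds[i]` is that query's ranked prediction array.

```python
import numpy as np

from Wasserstein_Distance import mrr_precision_at_k  # assumed new location

golden = [0, 1, 2]
preds = np.array([
    [0, 2, 1],  # correct item at rank 1 -> reciprocal rank 1
    [2, 1, 0],  # correct item at rank 2 -> reciprocal rank 1/2
    [1, 0, 2],  # correct item at rank 3 -> reciprocal rank 1/3
])

mrr, precision_at_1 = mrr_precision_at_k(golden, preds)
print(mrr)             # (1 + 1/2 + 1/3) / 3 ~= 0.611
print(precision_at_1)  # only the first query hits at rank 1 -> 1/3
```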