diff options
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r-- | WMD_retrieval.py | 112 |
1 file changed, 73 insertions, 39 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index f32372f..3328023 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -1,15 +1,19 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import numpy as np | 2 | import csv |
3 | import random | 3 | import random |
4 | |||
5 | import numpy as np | ||
4 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer | 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer |
5 | from sklearn.preprocessing import normalize | 7 | from sklearn.preprocessing import normalize |
6 | from Wasserstein_Distance import Wasserstein_Retriever | 8 | |
7 | from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k | 9 | from Wasserstein_Distance import (Wasserstein_Retriever, |
8 | import csv | 10 | clean_corpus_using_embeddings_vocabulary, |
11 | load_embeddings) | ||
12 | |||
9 | 13 | ||
10 | def main(args): | 14 | def main(args): |
11 | 15 | ||
12 | np.seterr(divide='ignore') # POT has issues with divide by zero errors | 16 | np.seterr(divide='ignore') # POT has issues with divide by zero errors |
13 | source_lang = args.source_lang | 17 | source_lang = args.source_lang |
14 | target_lang = args.target_lang | 18 | target_lang = args.target_lang |
15 | 19 | ||
@@ -25,32 +29,38 @@ def main(args): | |||
25 | mode = args.mode | 29 | mode = args.mode |
26 | runfor = list() | 30 | runfor = list() |
27 | 31 | ||
28 | if (mode == 'all'): | 32 | if mode == 'all': |
29 | runfor.extend(['wmd','snk']) | 33 | runfor.extend(['wmd', 'snk']) |
30 | else: | 34 | else: |
31 | runfor.append(mode) | 35 | runfor.append(mode) |
32 | 36 | ||
33 | defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')] | 37 | defs_source = [ |
34 | defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')] | 38 | line.rstrip('\n') |
39 | for line in open(source_defs_filename, encoding='utf8') | ||
40 | ] | ||
41 | defs_target = [ | ||
42 | line.rstrip('\n') | ||
43 | for line in open(target_defs_filename, encoding='utf8') | ||
44 | ] | ||
35 | 45 | ||
36 | clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary( | 46 | clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary( |
37 | set(vectors_source.keys()), | 47 | set(vectors_source.keys()), |
38 | defs_source, | 48 | defs_source, |
39 | vectors_source, | 49 | vectors_source, |
40 | source_lang, | 50 | source_lang, |
41 | ) | 51 | ) |
42 | 52 | ||
43 | clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary( | 53 | clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary( |
44 | set(vectors_target.keys()), | 54 | set(vectors_target.keys()), |
45 | defs_target, | 55 | defs_target, |
46 | vectors_target, | 56 | vectors_target, |
47 | target_lang, | 57 | target_lang, |
48 | ) | 58 | ) |
49 | 59 | ||
50 | take = args.instances | 60 | take = args.instances |
51 | 61 | ||
52 | common_keys = set(src_keys).intersection(set(target_keys)) | 62 | common_keys = set(src_keys).intersection(set(target_keys)) |
53 | take = min(len(common_keys), take) # you can't sample more than length | 63 | take = min(len(common_keys), take) # you can't sample more than length |
54 | experiment_keys = random.sample(common_keys, take) | 64 | experiment_keys = random.sample(common_keys, take) |
55 | 65 | ||
56 | instances = len(experiment_keys) | 66 | instances = len(experiment_keys) |
@@ -58,13 +68,18 @@ def main(args): | |||
58 | clean_src_corpus = list(clean_src_corpus[experiment_keys]) | 68 | clean_src_corpus = list(clean_src_corpus[experiment_keys]) |
59 | clean_target_corpus = list(clean_target_corpus[experiment_keys]) | 69 | clean_target_corpus = list(clean_target_corpus[experiment_keys]) |
60 | 70 | ||
61 | if (not batch): | 71 | if not batch: |
62 | print(f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}') | 72 | print( |
73 | f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}' | ||
74 | ) | ||
63 | 75 | ||
64 | del vectors_source, vectors_target, defs_source, defs_target | 76 | del vectors_source, vectors_target, defs_source, defs_target |
65 | 77 | ||
66 | vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus) | 78 | vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus) |
67 | common = [word for word in vec.get_feature_names() if word in clean_src_vectors or word in clean_target_vectors] | 79 | common = [ |
80 | word for word in vec.get_feature_names() | ||
81 | if word in clean_src_vectors or word in clean_target_vectors | ||
82 | ] | ||
68 | W_common = [] | 83 | W_common = [] |
69 | for w in common: | 84 | for w in common: |
70 | if w in clean_src_vectors: | 85 | if w in clean_src_vectors: |
@@ -72,8 +87,10 @@ def main(args): | |||
72 | else: | 87 | else: |
73 | W_common.append(np.array(clean_target_vectors[w])) | 88 | W_common.append(np.array(clean_target_vectors[w])) |
74 | 89 | ||
75 | if (not batch): | 90 | if not batch: |
76 | print(f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}') | 91 | print( |
92 | f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}' | ||
93 | ) | ||
77 | 94 | ||
78 | W_common = np.array(W_common) | 95 | W_common = np.array(W_common) |
79 | W_common = normalize(W_common) | 96 | W_common = normalize(W_common) |
@@ -82,26 +99,28 @@ def main(args): | |||
82 | X_train_idf = vect.transform(clean_src_corpus) | 99 | X_train_idf = vect.transform(clean_src_corpus) |
83 | X_test_idf = vect.transform(clean_target_corpus) | 100 | X_test_idf = vect.transform(clean_target_corpus) |
84 | 101 | ||
85 | vect_tf = CountVectorizer(vocabulary=common, dtype=np.double) | ||
86 | vect_tf.fit(clean_src_corpus + clean_target_corpus) | ||
87 | X_train_tf = vect_tf.transform(clean_src_corpus) | ||
88 | X_test_tf = vect_tf.transform(clean_target_corpus) | ||
89 | |||
90 | for metric in runfor: | 102 | for metric in runfor: |
91 | if (not batch): | 103 | if not batch: |
92 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') | 104 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') |
93 | 105 | ||
94 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 106 | clf = Wasserstein_Retriever(W_embed=W_common, |
107 | n_neighbors=5, | ||
108 | n_jobs=14, | ||
109 | sinkhorn=(metric == 'snk')) | ||
95 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 110 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
96 | # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 111 | # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
97 | # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) | 112 | # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds) |
98 | # percentage = p_at_one * 100 | 113 | # percentage = p_at_one * 100 |
99 | p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances) | 114 | p_at_one, percentage = clf.align(X_test_idf[:instances], |
115 | n_neighbors=instances) | ||
100 | 116 | ||
101 | if (not batch): | 117 | if not batch: |
102 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') | 118 | print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%') |
103 | else: | 119 | else: |
104 | fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}'] | 120 | fields = [ |
121 | f'{source_lang}', f'{target_lang}', f'{instances}', | ||
122 | f'{p_at_one}', f'{percentage}' | ||
123 | ] | ||
105 | with open(f'{metric}_retrieval_result.csv', 'a') as f: | 124 | with open(f'{metric}_retrieval_result.csv', 'a') as f: |
106 | writer = csv.writer(f) | 125 | writer = csv.writer(f) |
107 | writer.writerow(fields) | 126 | writer.writerow(fields) |
@@ -109,16 +128,31 @@ def main(args): | |||
109 | 128 | ||
110 | if __name__ == "__main__": | 129 | if __name__ == "__main__": |
111 | 130 | ||
112 | parser = argparse.ArgumentParser(description='run retrieval using wmd or snk') | 131 | parser = argparse.ArgumentParser( |
132 | description='run retrieval using wmd or snk') | ||
113 | parser.add_argument('source_lang', help='source language short name') | 133 | parser.add_argument('source_lang', help='source language short name') |
114 | parser.add_argument('target_lang', help='target language short name') | 134 | parser.add_argument('target_lang', help='target language short name') |
115 | parser.add_argument('source_vector', help='path of the source vector') | 135 | parser.add_argument('source_vector', help='path of the source vector') |
116 | parser.add_argument('target_vector', help='path of the target vector') | 136 | parser.add_argument('target_vector', help='path of the target vector') |
117 | parser.add_argument('source_defs', help='path of the source definitions') | 137 | parser.add_argument('source_defs', help='path of the source definitions') |
118 | parser.add_argument('target_defs', help='path of the target definitions') | 138 | parser.add_argument('target_defs', help='path of the target definitions') |
119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | 139 | parser.add_argument( |
120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | 140 | '-b', |
121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) | 141 | '--batch', |
142 | action='store_true', | ||
143 | help= | ||
144 | 'running in batch (store results in csv) or running a single instance (output the results)' | ||
145 | ) | ||
146 | parser.add_argument('mode', | ||
147 | choices=['all', 'wmd', 'snk'], | ||
148 | default='all', | ||
149 | help='which methods to run') | ||
150 | parser.add_argument( | ||
151 | '-n', | ||
152 | '--instances', | ||
153 | help='number of instances in each language to retrieve', | ||
154 | default=1000, | ||
155 | type=int) | ||
122 | 156 | ||
123 | args = parser.parse_args() | 157 | args = parser.parse_args() |
124 | 158 | ||