path: root/WMD_retrieval.py
author    Yigit Sever  2019-09-22 01:33:24 +0300
committer Yigit Sever  2019-09-22 01:33:24 +0300
commit    2936635892e17031c37facfd2115e8cfd6633222 (patch)
tree      e284cfb01c2c4a84a9f94cfd528bbc7a57b5d19f /WMD_retrieval.py
parent    3a924c24d167a4411b19d1038c59639f06f2ba6b (diff)
Introduce linter, stylize
Diffstat (limited to 'WMD_retrieval.py')
-rw-r--r--  WMD_retrieval.py  112
1 file changed, 73 insertions(+), 39 deletions(-)
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index f32372f..3328023 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -1,15 +1,19 @@
 import argparse
-import numpy as np
+import csv
 import random
+
+import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
-from Wasserstein_Distance import Wasserstein_Retriever
-from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k
-import csv
+
+from Wasserstein_Distance import (Wasserstein_Retriever,
+                                  clean_corpus_using_embeddings_vocabulary,
+                                  load_embeddings)
+
 
 def main(args):
 
     np.seterr(divide='ignore') # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -25,32 +29,38 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if (mode == 'all'):
-        runfor.extend(['wmd','snk'])
+    if mode == 'all':
+        runfor.extend(['wmd', 'snk'])
     else:
         runfor.append(mode)
 
-    defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')]
-    defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')]
+    defs_source = [
+        line.rstrip('\n')
+        for line in open(source_defs_filename, encoding='utf8')
+    ]
+    defs_target = [
+        line.rstrip('\n')
+        for line in open(target_defs_filename, encoding='utf8')
+    ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_source.keys()),
         defs_source,
         vectors_source,
         source_lang,
     )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_target.keys()),
         defs_target,
         vectors_target,
         target_lang,
     )
 
     take = args.instances
 
     common_keys = set(src_keys).intersection(set(target_keys))
     take = min(len(common_keys), take) # you can't sample more than length
     experiment_keys = random.sample(common_keys, take)
 
     instances = len(experiment_keys)
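
Review note: random.sample(common_keys, take) passes a set as the population. Sampling from a set was deprecated in Python 3.9 and raises a TypeError from Python 3.11 onward. A minimal forward-compatible sketch; the seed is a hypothetical addition, not in the original:

    random.seed(42)  # hypothetical seed; with the sorted population this makes the sample reproducible
    experiment_keys = random.sample(sorted(common_keys), take)
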
@@ -58,13 +68,18 @@ def main(args):
     clean_src_corpus = list(clean_src_corpus[experiment_keys])
     clean_target_corpus = list(clean_target_corpus[experiment_keys])
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+        )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
-    common = [word for word in vec.get_feature_names() if word in clean_src_vectors or word in clean_target_vectors]
+    common = [
+        word for word in vec.get_feature_names()
+        if word in clean_src_vectors or word in clean_target_vectors
+    ]
     W_common = []
     for w in common:
         if w in clean_src_vectors:
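
Review note: vec.get_feature_names() works on the scikit-learn versions this repo targets, but it was deprecated in scikit-learn 1.0 and removed in 1.2 in favour of get_feature_names_out(). A version-tolerant sketch of the same filter, assuming the fitted CountVectorizer above:

    # prefer the new accessor when it exists, fall back on older releases
    feature_names = (vec.get_feature_names_out()
                     if hasattr(vec, 'get_feature_names_out')
                     else vec.get_feature_names())
    common = [
        word for word in feature_names
        if word in clean_src_vectors or word in clean_target_vectors
    ]
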
@@ -72,8 +87,10 @@ def main(args):
         else:
             W_common.append(np.array(clean_target_vectors[w]))
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
+        )
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -82,26 +99,28 @@ def main(args):
     X_train_idf = vect.transform(clean_src_corpus)
     X_test_idf = vect.transform(clean_target_corpus)
 
-    vect_tf = CountVectorizer(vocabulary=common, dtype=np.double)
-    vect_tf.fit(clean_src_corpus + clean_target_corpus)
-    X_train_tf = vect_tf.transform(clean_src_corpus)
-    X_test_tf = vect_tf.transform(clean_target_corpus)
-
     for metric in runfor:
-        if (not batch):
+        if not batch:
             print(f'{metric} - tfidf: {source_lang} - {target_lang}')
 
-        clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
+        clf = Wasserstein_Retriever(W_embed=W_common,
+                                    n_neighbors=5,
+                                    n_jobs=14,
+                                    sinkhorn=(metric == 'snk'))
         clf.fit(X_train_idf[:instances], np.ones(instances))
         # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
         # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
         # percentage = p_at_one * 100
-        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
+        p_at_one, percentage = clf.align(X_test_idf[:instances],
+                                         n_neighbors=instances)
 
-        if (not batch):
+        if not batch:
             print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
         else:
-            fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}']
+            fields = [
+                f'{source_lang}', f'{target_lang}', f'{instances}',
+                f'{p_at_one}', f'{percentage}'
+            ]
             with open(f'{metric}_retrieval_result.csv', 'a') as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
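
Review note: the csv module documentation recommends opening files handed to csv.writer with newline='' so no spurious blank rows appear on Windows. A hedged sketch of the batch write with an assumed one-time header (the column names are illustrative, mirroring the fields list, and it needs an os import at the top):

    results_path = f'{metric}_retrieval_result.csv'
    write_header = not os.path.isfile(results_path)  # assumed: header once per file
    with open(results_path, 'a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['source_lang', 'target_lang', 'instances',
                             'p_at_one', 'percentage'])  # illustrative names
        writer.writerow(fields)
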
@@ -109,16 +128,31 @@ def main(args):
 
 if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(description='run retrieval using wmd or snk')
+    parser = argparse.ArgumentParser(
+        description='run retrieval using wmd or snk')
     parser.add_argument('source_lang', help='source language short name')
     parser.add_argument('target_lang', help='target language short name')
     parser.add_argument('source_vector', help='path of the source vector')
     parser.add_argument('target_vector', help='path of the target vector')
     parser.add_argument('source_defs', help='path of the source definitions')
     parser.add_argument('target_defs', help='path of the target definitions')
-    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
-    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
-    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
+    parser.add_argument(
+        '-b',
+        '--batch',
+        action='store_true',
+        help=
+        'running in batch (store results in csv) or running a single instance (output the results)'
+    )
+    parser.add_argument('mode',
+                        choices=['all', 'wmd', 'snk'],
+                        default='all',
+                        help='which methods to run')
+    parser.add_argument(
+        '-n',
+        '--instances',
+        help='number of instances in each language to retrieve',
+        default=1000,
+        type=int)
 
     args = parser.parse_args()
 
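
Review note: because mode is a positional argument declared without nargs='?', argparse ignores default='all' and treats mode as required. Making the default effective needs the nargs flag as well; a minimal sketch:

    parser.add_argument('mode',
                        nargs='?',
                        choices=['all', 'wmd', 'snk'],
                        default='all',
                        help='which methods to run')

An assumed invocation with hypothetical file paths, following the positional order above:

    python WMD_retrieval.py en tr en.vec tr.vec en.defs tr.defs all -b -n 1000
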