-rw-r--r--  WMD_matching.py          114
-rw-r--r--  WMD_retrieval.py         112
-rw-r--r--  Wasserstein_Distance.py  109
3 files changed, 219 insertions, 116 deletions
diff --git a/WMD_matching.py b/WMD_matching.py
index 59b64f9..ea496b8 100644
--- a/WMD_matching.py
+++ b/WMD_matching.py
@@ -1,15 +1,19 @@
 import argparse
-import numpy as np
+import csv
 import random
+
+import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
-from Wasserstein_Distance import Wasserstein_Matcher
-from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary
-import csv
+
+from Wasserstein_Distance import (Wasserstein_Matcher,
+                                  clean_corpus_using_embeddings_vocabulary,
+                                  load_embeddings)
+
 
 def main(args):
 
     np.seterr(divide='ignore')  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -25,32 +29,38 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if (mode == 'all'):
-        runfor.extend(['wmd','snk'])
+    if mode == 'all':
+        runfor.extend(['wmd', 'snk'])
     else:
         runfor.append(mode)
 
-    defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')]
-    defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')]
+    defs_source = [
+        line.rstrip('\n')
+        for line in open(source_defs_filename, encoding='utf8')
+    ]
+    defs_target = [
+        line.rstrip('\n')
+        for line in open(target_defs_filename, encoding='utf8')
+    ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_source.keys()),
         defs_source,
         vectors_source,
         source_lang,
     )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_target.keys()),
         defs_target,
         vectors_target,
         target_lang,
     )
 
     take = args.instances
 
     common_keys = set(src_keys).intersection(set(target_keys))
     take = min(len(common_keys), take)  # you can't sample more than length
     experiment_keys = random.sample(common_keys, take)
 
     instances = len(experiment_keys)
@@ -58,13 +68,18 @@ def main(args):
     clean_src_corpus = list(clean_src_corpus[experiment_keys])
     clean_target_corpus = list(clean_target_corpus[experiment_keys])
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+        )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
-    common = [word for word in vec.get_feature_names() if word in clean_src_vectors or word in clean_target_vectors]
+    common = [
+        word for word in vec.get_feature_names()
+        if word in clean_src_vectors or word in clean_target_vectors
+    ]
     W_common = []
     for w in common:
         if w in clean_src_vectors:
@@ -72,8 +87,10 @@ def main(args):
         else:
             W_common.append(np.array(clean_target_vectors[w]))
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
+        )
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -82,26 +99,28 @@ def main(args):
     X_train_idf = vect.transform(clean_src_corpus)
     X_test_idf = vect.transform(clean_target_corpus)
 
-    vect_tf = CountVectorizer(vocabulary=common, dtype=np.double)
-    vect_tf.fit(clean_src_corpus + clean_target_corpus)
-    X_train_tf = vect_tf.transform(clean_src_corpus)
-    X_test_tf = vect_tf.transform(clean_target_corpus)
-
     for metric in runfor:
-        if (not batch):
+        if not batch:
             print(f'{metric}: {source_lang} - {target_lang}')
 
-        clf = Wasserstein_Matcher(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
+        clf = Wasserstein_Matcher(W_embed=W_common,
+                                  n_neighbors=5,
+                                  n_jobs=14,
+                                  sinkhorn=(metric == 'snk'))
         clf.fit(X_train_idf[:instances], np.ones(instances))
-        row_ind, col_ind, a = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
+        row_ind, col_ind, _ = clf.kneighbors(X_test_idf[:instances],
+                                             n_neighbors=instances)
         result = zip(row_ind, col_ind)
-        p_at_one = len([x for x,y in result if x == y])
+        p_at_one = len([x for x, y in result if x == y])
         percentage = p_at_one / instances * 100
 
-        if (not batch):
+        if not batch:
             print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
         else:
-            fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}']
+            fields = [
+                f'{source_lang}', f'{target_lang}', f'{instances}',
+                f'{p_at_one}', f'{percentage}'
+            ]
             with open(f'{metric}_matching_results.csv', 'a') as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
@@ -109,16 +128,31 @@ def main(args):
 
 if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(description='matching using wmd and wasserstein distance')
+    parser = argparse.ArgumentParser(
+        description='matching using wmd and wasserstein distance')
     parser.add_argument('source_lang', help='source language short name')
     parser.add_argument('target_lang', help='target language short name')
     parser.add_argument('source_vector', help='path of the source vector')
     parser.add_argument('target_vector', help='path of the target vector')
     parser.add_argument('source_defs', help='path of the source definitions')
     parser.add_argument('target_defs', help='path of the target definitions')
-    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
-    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
-    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
+    parser.add_argument(
+        '-b',
+        '--batch',
+        action='store_true',
+        help=
+        'running in batch (store results in csv) or running a single instance (output the results)'
+    )
+    parser.add_argument('mode',
+                        choices=['all', 'wmd', 'snk'],
+                        default='all',
+                        help='which methods to run')
+    parser.add_argument(
+        '-n',
+        '--instances',
+        help='number of instances in each language to retrieve',
+        default=1000,
+        type=int)
 
     args = parser.parse_args()
 
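(Aside: a minimal sketch of the call pattern WMD_matching.py ends up with after this reformat. The toy sizes and random matrices below are invented for illustration; the real script builds X_train_idf/X_test_idf with TfidfVectorizer over the cleaned corpora and W_common from the shared embedding vocabulary.)

# --- toy sketch, not part of the diff ---
# Assumes the repo's dependencies (POT, lapjv, pathos, scikit-learn) are installed.
import numpy as np
from scipy.sparse import csr_matrix

from Wasserstein_Distance import Wasserstein_Matcher

W_common = np.random.rand(3, 300)                # one embedding row per vocabulary word
X_train_idf = csr_matrix(np.random.rand(3, 3))   # 3 source documents over that vocabulary
X_test_idf = csr_matrix(np.random.rand(3, 3))    # 3 target documents

clf = Wasserstein_Matcher(W_embed=W_common,
                          n_neighbors=3,
                          n_jobs=1,
                          sinkhorn=False)
clf.fit(X_train_idf, np.ones(3))                 # rows are l1-normalized inside fit()
row_ind, col_ind, _ = clf.kneighbors(X_test_idf, n_neighbors=3)
print(list(zip(row_ind, col_ind)))               # matched document index pairs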
diff --git a/WMD_retrieval.py b/WMD_retrieval.py
index f32372f..3328023 100644
--- a/WMD_retrieval.py
+++ b/WMD_retrieval.py
@@ -1,15 +1,19 @@
 import argparse
-import numpy as np
+import csv
 import random
+
+import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
 from sklearn.preprocessing import normalize
-from Wasserstein_Distance import Wasserstein_Retriever
-from Wasserstein_Distance import load_embeddings, clean_corpus_using_embeddings_vocabulary, mrr_precision_at_k
-import csv
+
+from Wasserstein_Distance import (Wasserstein_Retriever,
+                                  clean_corpus_using_embeddings_vocabulary,
+                                  load_embeddings)
+
 
 def main(args):
 
     np.seterr(divide='ignore')  # POT has issues with divide by zero errors
     source_lang = args.source_lang
     target_lang = args.target_lang
 
@@ -25,32 +29,38 @@ def main(args):
     mode = args.mode
     runfor = list()
 
-    if (mode == 'all'):
-        runfor.extend(['wmd','snk'])
+    if mode == 'all':
+        runfor.extend(['wmd', 'snk'])
     else:
         runfor.append(mode)
 
-    defs_source = [line.rstrip('\n') for line in open(source_defs_filename, encoding='utf8')]
-    defs_target = [line.rstrip('\n') for line in open(target_defs_filename, encoding='utf8')]
+    defs_source = [
+        line.rstrip('\n')
+        for line in open(source_defs_filename, encoding='utf8')
+    ]
+    defs_target = [
+        line.rstrip('\n')
+        for line in open(target_defs_filename, encoding='utf8')
+    ]
 
     clean_src_corpus, clean_src_vectors, src_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_source.keys()),
         defs_source,
         vectors_source,
         source_lang,
     )
 
     clean_target_corpus, clean_target_vectors, target_keys = clean_corpus_using_embeddings_vocabulary(
         set(vectors_target.keys()),
         defs_target,
         vectors_target,
         target_lang,
     )
 
     take = args.instances
 
     common_keys = set(src_keys).intersection(set(target_keys))
     take = min(len(common_keys), take)  # you can't sample more than length
     experiment_keys = random.sample(common_keys, take)
 
     instances = len(experiment_keys)
@@ -58,13 +68,18 @@ def main(args):
     clean_src_corpus = list(clean_src_corpus[experiment_keys])
     clean_target_corpus = list(clean_target_corpus[experiment_keys])
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang} : document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}'
+        )
 
     del vectors_source, vectors_target, defs_source, defs_target
 
     vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus)
-    common = [word for word in vec.get_feature_names() if word in clean_src_vectors or word in clean_target_vectors]
+    common = [
+        word for word in vec.get_feature_names()
+        if word in clean_src_vectors or word in clean_target_vectors
+    ]
     W_common = []
     for w in common:
         if w in clean_src_vectors:
@@ -72,8 +87,10 @@ def main(args):
         else:
             W_common.append(np.array(clean_target_vectors[w]))
 
-    if (not batch):
-        print(f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}')
+    if not batch:
+        print(
+            f'{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}'
+        )
 
     W_common = np.array(W_common)
     W_common = normalize(W_common)
@@ -82,26 +99,28 @@ def main(args):
     X_train_idf = vect.transform(clean_src_corpus)
     X_test_idf = vect.transform(clean_target_corpus)
 
-    vect_tf = CountVectorizer(vocabulary=common, dtype=np.double)
-    vect_tf.fit(clean_src_corpus + clean_target_corpus)
-    X_train_tf = vect_tf.transform(clean_src_corpus)
-    X_test_tf = vect_tf.transform(clean_target_corpus)
-
     for metric in runfor:
-        if (not batch):
+        if not batch:
             print(f'{metric} - tfidf: {source_lang} - {target_lang}')
 
-        clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk'))
+        clf = Wasserstein_Retriever(W_embed=W_common,
+                                    n_neighbors=5,
+                                    n_jobs=14,
+                                    sinkhorn=(metric == 'snk'))
         clf.fit(X_train_idf[:instances], np.ones(instances))
         # dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances)
         # mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
         # percentage = p_at_one * 100
-        p_at_one, percentage = clf.align(X_test_idf[:instances], n_neighbors=instances)
+        p_at_one, percentage = clf.align(X_test_idf[:instances],
+                                         n_neighbors=instances)
 
-        if (not batch):
+        if not batch:
             print(f'P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%')
         else:
-            fields = [f'{source_lang}', f'{target_lang}', f'{instances}', f'{p_at_one}', f'{percentage}']
+            fields = [
+                f'{source_lang}', f'{target_lang}', f'{instances}',
+                f'{p_at_one}', f'{percentage}'
+            ]
             with open(f'{metric}_retrieval_result.csv', 'a') as f:
                 writer = csv.writer(f)
                 writer.writerow(fields)
@@ -109,16 +128,31 @@ def main(args):
 
 if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(description='run retrieval using wmd or snk')
+    parser = argparse.ArgumentParser(
+        description='run retrieval using wmd or snk')
     parser.add_argument('source_lang', help='source language short name')
     parser.add_argument('target_lang', help='target language short name')
     parser.add_argument('source_vector', help='path of the source vector')
     parser.add_argument('target_vector', help='path of the target vector')
     parser.add_argument('source_defs', help='path of the source definitions')
     parser.add_argument('target_defs', help='path of the target definitions')
-    parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)')
-    parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run')
-    parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int)
+    parser.add_argument(
+        '-b',
+        '--batch',
+        action='store_true',
+        help=
+        'running in batch (store results in csv) or running a single instance (output the results)'
+    )
+    parser.add_argument('mode',
+                        choices=['all', 'wmd', 'snk'],
+                        default='all',
+                        help='which methods to run')
+    parser.add_argument(
+        '-n',
+        '--instances',
+        help='number of instances in each language to retrieve',
+        default=1000,
+        type=int)
 
     args = parser.parse_args()
 
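(Aside: WMD_retrieval.py now scores itself through clf.align, which wraps mrr_precision_at_k from Wasserstein_Distance.py, cleaned up at the end of the next diff. A toy check of what that metric returns, with invented golden/preds values:)

# --- toy sketch, not part of the diff ---
import numpy as np

from Wasserstein_Distance import mrr_precision_at_k

# golden[i] is the correct target index for query i; preds[i] is the ranked candidate list.
golden = [0, 1, 2]
preds = np.array([
    [0, 2, 1],  # correct item ranked 1st -> reciprocal rank 1
    [2, 1, 0],  # correct item ranked 2nd -> reciprocal rank 1/2
    [0, 1, 2],  # correct item ranked 3rd -> reciprocal rank 1/3
])
mrr, p_at_one = mrr_precision_at_k(golden, preds)
print(mrr)       # (1 + 1/2 + 1/3) / 3 = 0.611...
print(p_at_one)  # 0.333... : one query in three has the right answer at rank 1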
diff --git a/Wasserstein_Distance.py b/Wasserstein_Distance.py
index 08439d2..161c13c 100644
--- a/Wasserstein_Distance.py
+++ b/Wasserstein_Distance.py
@@ -1,15 +1,14 @@
-import ot
-from sklearn.preprocessing import normalize
-from lapjv import lapjv
-from sklearn.neighbors import KNeighborsClassifier
+import numpy as np
 from sklearn.metrics import euclidean_distances
-from sklearn.externals.joblib import Parallel, delayed
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.preprocessing import normalize
 from sklearn.utils import check_array
-from sklearn.metrics.scorer import check_scoring
-from pathos.multiprocessing import ProcessingPool as Pool
-from sklearn.metrics import euclidean_distances
-import numpy as np
+
+import ot
+from lapjv import lapjv
 from mosestokenizer import MosesTokenizer
+from pathos.multiprocessing import ProcessingPool as Pool
+
 
 class Wasserstein_Matcher(KNeighborsClassifier):
     """
@@ -17,7 +16,13 @@ class Wasserstein_Matcher(KNeighborsClassifier):
     Source and target distributions are l_1 normalized before computing the Wasserstein distance.
     Wasserstein is parametrized by the distances between the individual points of the distributions.
     """
-    def __init__(self, W_embed, n_neighbors=1, n_jobs=1, verbose=False, sinkhorn= False, sinkhorn_reg=0.1):
+    def __init__(self,
+                 W_embed,
+                 n_neighbors=1,
+                 n_jobs=1,
+                 verbose=False,
+                 sinkhorn=False,
+                 sinkhorn_reg=0.1):
         """
         Initialization of the class.
         Arguments
@@ -29,7 +34,10 @@ class Wasserstein_Matcher(KNeighborsClassifier):
         self.sinkhorn_reg = sinkhorn_reg
         self.W_embed = W_embed
         self.verbose = verbose
-        super(Wasserstein_Matcher, self).__init__(n_neighbors=n_neighbors, n_jobs=n_jobs, metric='precomputed', algorithm='brute')
+        super(Wasserstein_Matcher, self).__init__(n_neighbors=n_neighbors,
+                                                  n_jobs=n_jobs,
+                                                  metric='precomputed',
+                                                  algorithm='brute')
 
     def _wmd(self, i, row, X_train):
         union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -38,9 +46,16 @@ class Wasserstein_Matcher(KNeighborsClassifier):
         bow_i = X_train[i, union_idx].A.ravel()
         bow_j = row[:, union_idx].A.ravel()
         if self.sinkhorn:
-            return ot.sinkhorn2(bow_i, bow_j, W_dist, self.sinkhorn_reg, numItermax=50, method='sinkhorn_stabilized',)[0]
+            return ot.sinkhorn2(
+                bow_i,
+                bow_j,
+                W_dist,
+                self.sinkhorn_reg,
+                numItermax=50,
+                method='sinkhorn_stabilized',
+            )[0]
         else:
             return ot.emd2(bow_i, bow_j, W_dist)
 
     def _wmd_row(self, row):
         X_train = self._fit_X
@@ -52,28 +67,31 @@ class Wasserstein_Matcher(KNeighborsClassifier):
 
         if X_train is None:
             X_train = self._fit_X
-        pool = Pool(nodes=self.n_jobs) # Parallelization of the calculation of the distances
-        dist = pool.map(self._wmd_row, X_test)
+        pool = Pool(nodes=self.n_jobs
+                    )  # Parallelization of the calculation of the distances
+        dist = pool.map(self._wmd_row, X_test)
         return np.array(dist)
 
     def fit(self, X, y):  # X_train_idf
-        X = check_array(X, accept_sparse='csr', copy=True) # check if array is sparse
+        X = check_array(X, accept_sparse='csr',
+                        copy=True)  # check if array is sparse
         X = normalize(X, norm='l1', copy=False)
-        return super(Wasserstein_Matcher, self).fit(X, y) # X_train_idf, np_ones(document collection size)
+        return super(Wasserstein_Matcher, self).fit(
+            X, y)  # X_train_idf, np_ones(document collection size)
 
     def predict(self, X):
         X = check_array(X, accept_sparse='csr', copy=True)
         X = normalize(X, norm='l1', copy=False)
         dist = self._pairwise_wmd(X)
         dist = dist * 1000  # for lapjv, small floating point numbers are evil
         return super(Wasserstein_Matcher, self).predict(dist)
 
     def kneighbors(self, X, n_neighbors=1):  # X : X_train_idf
         X = check_array(X, accept_sparse='csr', copy=True)
         X = normalize(X, norm='l1', copy=False)
         dist = self._pairwise_wmd(X)
         dist = dist * 1000  # for lapjv, small floating point numbers are evil
         return lapjv(dist)  # and here is the matching part
 
 
 class Wasserstein_Retriever(KNeighborsClassifier):
@@ -82,7 +100,13 @@ class Wasserstein_Retriever(KNeighborsClassifier):
     Source and target distributions are l_1 normalized before computing the Wasserstein distance.
     Wasserstein is parametrized by the distances between the individual points of the distributions.
     """
-    def __init__(self, W_embed, n_neighbors=1, n_jobs=1, verbose=False, sinkhorn= False, sinkhorn_reg=0.1):
+    def __init__(self,
+                 W_embed,
+                 n_neighbors=1,
+                 n_jobs=1,
+                 verbose=False,
+                 sinkhorn=False,
+                 sinkhorn_reg=0.1):
         """
         Initialization of the class.
         Arguments
@@ -94,7 +118,10 @@ class Wasserstein_Retriever(KNeighborsClassifier):
         self.sinkhorn_reg = sinkhorn_reg
         self.W_embed = W_embed
         self.verbose = verbose
-        super(Wasserstein_Retriever, self).__init__(n_neighbors=n_neighbors, n_jobs=n_jobs, metric='precomputed', algorithm='brute')
+        super(Wasserstein_Retriever, self).__init__(n_neighbors=n_neighbors,
+                                                    n_jobs=n_jobs,
+                                                    metric='precomputed',
+                                                    algorithm='brute')
 
     def _wmd(self, i, row, X_train):
         union_idx = np.union1d(X_train[i].indices, row.indices)
@@ -103,9 +130,16 @@ class Wasserstein_Retriever(KNeighborsClassifier):
         bow_i = X_train[i, union_idx].A.ravel()
         bow_j = row[:, union_idx].A.ravel()
         if self.sinkhorn:
-            return ot.sinkhorn2(bow_i, bow_j, W_dist, self.sinkhorn_reg, numItermax=50, method='sinkhorn_stabilized',)[0]
+            return ot.sinkhorn2(
+                bow_i,
+                bow_j,
+                W_dist,
+                self.sinkhorn_reg,
+                numItermax=50,
+                method='sinkhorn_stabilized',
+            )[0]
         else:
             return ot.emd2(bow_i, bow_j, W_dist)
 
     def _wmd_row(self, row):
         X_train = self._fit_X
@@ -117,8 +151,8 @@ class Wasserstein_Retriever(KNeighborsClassifier):
 
         if X_train is None:
             X_train = self._fit_X
-        pool = Pool(nodes=self.n_jobs) # Parallelization of the calculation of the distances
+        pool = Pool(nodes=self.n_jobs)
         dist = pool.map(self._wmd_row, X_test)
         return np.array(dist)
 
     def fit(self, X, y):
@@ -144,8 +178,8 @@ class Wasserstein_Retriever(KNeighborsClassifier):
         precision at one and percentage values
 
         """
-        dist, preds = self.kneighbors(X, n_neighbors)
-        mrr, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
+        _, preds = self.kneighbors(X, n_neighbors)
+        _, p_at_one = mrr_precision_at_k(list(range(len(preds))), preds)
         percentage = p_at_one * 100
         return (p_at_one, percentage)
 
@@ -168,7 +202,8 @@ def load_embeddings(path, dimension=300):
         fp.seek(0)
         for line in fp:
             elems = line.split()
-            vectors[" ".join(elems[:-dimension])] = " ".join(elems[-dimension:])
+            vectors[" ".join(elems[:-dimension])] = " ".join(
+                elems[-dimension:])
     return vectors
 
 
@@ -177,7 +212,7 @@ def clean_corpus_using_embeddings_vocabulary(
         corpus,
         vectors,
         language,
-        ):
+):
     '''
     Cleans corpus using the dictionary of embeddings.
     Any word without an associated embedding in the dictionary is ignored.
@@ -192,7 +227,8 @@ def clean_corpus_using_embeddings_vocabulary(
         for word in words:
             if word in words_we_want:
                 clean_doc.append(word + '__%s' % language)
-                clean_vectors[word + '__%s' % language] = np.array(vectors[word].split()).astype(np.float)
+                clean_vectors[word + '__%s' % language] = np.array(
+                    vectors[word].split()).astype(np.float)
         if len(clean_doc) > 3 and len(clean_doc) < 25:
             keys.append(key)
             clean_corpus.append(' '.join(clean_doc))
@@ -208,10 +244,9 @@ def mrr_precision_at_k(golden, preds, k_list=[1,]):
     precision_at = np.zeros(len(k_list))
     for key, elem in enumerate(golden):
         if elem in preds[key]:
-            location = np.where(preds[key]==elem)[0][0]
-            my_score += 1/(1+ location)
+            location = np.where(preds[key] == elem)[0][0]
+            my_score += 1 / (1 + location)
             for k_index, k_value in enumerate(k_list):
                 if location < k_value:
                     precision_at[k_index] += 1
-    return my_score/len(golden), (precision_at/len(golden))[0]
-
+    return my_score / len(golden), (precision_at / len(golden))[0]
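(Aside: the quantity _wmd computes for a single document pair can be reproduced directly with POT. A minimal sketch with invented word distributions; in the classes above, bow_i and bow_j come from l1-normalized tf-idf rows restricted to the union of the two documents' nonzero indices, and the ground cost from the corresponding embedding rows.)

# --- toy sketch, not part of the diff ---
import numpy as np
import ot
from sklearn.metrics import euclidean_distances

W_embed = np.random.rand(3, 300)       # embeddings of the 3 words involved
W_dist = euclidean_distances(W_embed)  # ground cost between words
bow_i = np.array([0.5, 0.3, 0.2])      # word distribution of document i
bow_j = np.array([0.1, 0.3, 0.6])      # word distribution of document j

wmd = ot.emd2(bow_i, bow_j, W_dist)    # exact Wasserstein distance ('wmd' mode)
snk = ot.sinkhorn2(bow_i, bow_j, W_dist, 0.1,
                   numItermax=50, method='sinkhorn_stabilized')  # 'snk' mode
print(wmd, snk)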