aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--sentence_embedding.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/sentence_embedding.py b/sentence_embedding.py
index 0cd5361..d312de2 100644
--- a/sentence_embedding.py
+++ b/sentence_embedding.py
@@ -12,12 +12,12 @@ from Wasserstein_Distance import load_embeddings, process_corpus
12 12
13def main(args): 13def main(args):
14 14
15 run_method = list() 15 run_paradigm = list()
16 16
17 if input_paradigm == "all": 17 if args.paradigm == "all":
18 run_paradigm.extend("matching", "retrieval") 18 run_paradigm.extend("matching", "retrieval")
19 else: 19 else:
20 run_paradigm.append(input_paradigm) 20 run_paradigm.append(args.paradigm)
21 21
22 source_lang = args.source_lang 22 source_lang = args.source_lang
23 target_lang = args.target_lang 23 target_lang = args.target_lang
@@ -96,7 +96,7 @@ def main(args):
96 cost_matrix = np.matmul(S_emb_source, S_emb_target_transpose) 96 cost_matrix = np.matmul(S_emb_source, S_emb_target_transpose)
97 97
98 for paradigm in run_paradigm: 98 for paradigm in run_paradigm:
99 if paradigm == 'matching': 99 if paradigm == "matching":
100 100
101 cost_matrix = cost_matrix * -1000 101 cost_matrix = cost_matrix * -1000
102 row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) 102 row_ind, col_ind, a = lapjv(cost_matrix, verbose=False)
@@ -121,9 +121,11 @@ def main(args):
121 writer = csv.writer(f) 121 writer = csv.writer(f)
122 writer.writerow(fields) 122 writer.writerow(fields)
123 123
124 if paradigm == 'retrieval': 124 if paradigm == "retrieval":
125 125
126 hit_at_one = len([x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y]) 126 hit_at_one = len(
127 [x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y]
128 )
127 percentage = hit_at_one / instances * 100 129 percentage = hit_at_one / instances * 100
128 130
129 if not batch: 131 if not batch: