From a3d0df44e1530803716abe4d2b66327eefdaf703 Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Wed, 25 Sep 2019 16:10:30 +0300 Subject: Fix argument typo --- sentence_embedding.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sentence_embedding.py b/sentence_embedding.py index 0cd5361..d312de2 100644 --- a/sentence_embedding.py +++ b/sentence_embedding.py @@ -12,12 +12,12 @@ from Wasserstein_Distance import load_embeddings, process_corpus def main(args): - run_method = list() + run_paradigm = list() - if input_paradigm == "all": + if args.paradigm == "all": run_paradigm.extend("matching", "retrieval") else: - run_paradigm.append(input_paradigm) + run_paradigm.append(args.paradigm) source_lang = args.source_lang target_lang = args.target_lang @@ -96,7 +96,7 @@ def main(args): cost_matrix = np.matmul(S_emb_source, S_emb_target_transpose) for paradigm in run_paradigm: - if paradigm == 'matching': + if paradigm == "matching": cost_matrix = cost_matrix * -1000 row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) @@ -121,9 +121,11 @@ def main(args): writer = csv.writer(f) writer.writerow(fields) - if paradigm == 'retrieval': + if paradigm == "retrieval": - hit_at_one = len([x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y]) + hit_at_one = len( + [x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y] + ) percentage = hit_at_one / instances * 100 if not batch: -- cgit v1.2.3-70-g09d2