From 442a1895fe567502ec5fec20a62083ea090f38cc Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Wed, 25 Sep 2019 17:49:34 +0300 Subject: Bugfix --- sentence_embedding.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'sentence_embedding.py') diff --git a/sentence_embedding.py b/sentence_embedding.py index d312de2..2ac6720 100644 --- a/sentence_embedding.py +++ b/sentence_embedding.py @@ -15,7 +15,7 @@ def main(args): run_paradigm = list() if args.paradigm == "all": - run_paradigm.extend("matching", "retrieval") + run_paradigm.extend(("matching", "retrieval")) else: run_paradigm.append(args.paradigm) @@ -98,8 +98,8 @@ def main(args): for paradigm in run_paradigm: if paradigm == "matching": - cost_matrix = cost_matrix * -1000 - row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) + matching_cost_matrix = cost_matrix * -1000 + row_ind, col_ind, a = lapjv(matching_cost_matrix, verbose=False) result = zip(row_ind, col_ind) hit_at_one = len([x for x, y in result if x == y]) @@ -117,9 +117,9 @@ def main(args): f"{percentage}", ] - with open("semb_matcing_results.csv", "a") as f: - writer = csv.writer(f) - writer.writerow(fields) + with open("semb_matcing_results.csv", "a") as f: + writer = csv.writer(f) + writer.writerow(fields) if paradigm == "retrieval": -- cgit v1.2.3-70-g09d2