aboutsummaryrefslogtreecommitdiffstats
path: root/sentence_embedding.py
diff options
context:
space:
mode:
Diffstat (limited to 'sentence_embedding.py')
-rw-r--r--sentence_embedding.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/sentence_embedding.py b/sentence_embedding.py
index d312de2..2ac6720 100644
--- a/sentence_embedding.py
+++ b/sentence_embedding.py
@@ -15,7 +15,7 @@ def main(args):
15 run_paradigm = list() 15 run_paradigm = list()
16 16
17 if args.paradigm == "all": 17 if args.paradigm == "all":
18 run_paradigm.extend("matching", "retrieval") 18 run_paradigm.extend(("matching", "retrieval"))
19 else: 19 else:
20 run_paradigm.append(args.paradigm) 20 run_paradigm.append(args.paradigm)
21 21
@@ -98,8 +98,8 @@ def main(args):
98 for paradigm in run_paradigm: 98 for paradigm in run_paradigm:
99 if paradigm == "matching": 99 if paradigm == "matching":
100 100
101 cost_matrix = cost_matrix * -1000 101 matching_cost_matrix = cost_matrix * -1000
102 row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) 102 row_ind, col_ind, a = lapjv(matching_cost_matrix, verbose=False)
103 103
104 result = zip(row_ind, col_ind) 104 result = zip(row_ind, col_ind)
105 hit_at_one = len([x for x, y in result if x == y]) 105 hit_at_one = len([x for x, y in result if x == y])
@@ -117,9 +117,9 @@ def main(args):
117 f"{percentage}", 117 f"{percentage}",
118 ] 118 ]
119 119
120 with open("semb_matcing_results.csv", "a") as f: 120 with open("semb_matcing_results.csv", "a") as f:
121 writer = csv.writer(f) 121 writer = csv.writer(f)
122 writer.writerow(fields) 122 writer.writerow(fields)
123 123
124 if paradigm == "retrieval": 124 if paradigm == "retrieval":
125 125