diff options
| author | Yigit Sever | 2019-09-25 17:49:34 +0300 |
|---|---|---|
| committer | Yigit Sever | 2019-09-25 17:49:34 +0300 |
| commit | 442a1895fe567502ec5fec20a62083ea090f38cc (patch) | |
| tree | 5f449090f221ee26ccbed2f184641b6cc7c6bf66 | |
| parent | 6ca2b8a7fd444d6f2197e1659b357af9e0fc2c64 (diff) | |
| download | Evaluating-Dictionary-Alignment-442a1895fe567502ec5fec20a62083ea090f38cc.tar.gz Evaluating-Dictionary-Alignment-442a1895fe567502ec5fec20a62083ea090f38cc.tar.bz2 Evaluating-Dictionary-Alignment-442a1895fe567502ec5fec20a62083ea090f38cc.zip | |
Bugfix
| -rw-r--r-- | sentence_embedding.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/sentence_embedding.py b/sentence_embedding.py index d312de2..2ac6720 100644 --- a/sentence_embedding.py +++ b/sentence_embedding.py | |||
| @@ -15,7 +15,7 @@ def main(args): | |||
| 15 | run_paradigm = list() | 15 | run_paradigm = list() |
| 16 | 16 | ||
| 17 | if args.paradigm == "all": | 17 | if args.paradigm == "all": |
| 18 | run_paradigm.extend("matching", "retrieval") | 18 | run_paradigm.extend(("matching", "retrieval")) |
| 19 | else: | 19 | else: |
| 20 | run_paradigm.append(args.paradigm) | 20 | run_paradigm.append(args.paradigm) |
| 21 | 21 | ||
| @@ -98,8 +98,8 @@ def main(args): | |||
| 98 | for paradigm in run_paradigm: | 98 | for paradigm in run_paradigm: |
| 99 | if paradigm == "matching": | 99 | if paradigm == "matching": |
| 100 | 100 | ||
| 101 | cost_matrix = cost_matrix * -1000 | 101 | matching_cost_matrix = cost_matrix * -1000 |
| 102 | row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) | 102 | row_ind, col_ind, a = lapjv(matching_cost_matrix, verbose=False) |
| 103 | 103 | ||
| 104 | result = zip(row_ind, col_ind) | 104 | result = zip(row_ind, col_ind) |
| 105 | hit_at_one = len([x for x, y in result if x == y]) | 105 | hit_at_one = len([x for x, y in result if x == y]) |
| @@ -117,9 +117,9 @@ def main(args): | |||
| 117 | f"{percentage}", | 117 | f"{percentage}", |
| 118 | ] | 118 | ] |
| 119 | 119 | ||
| 120 | with open("semb_matcing_results.csv", "a") as f: | 120 | with open("semb_matcing_results.csv", "a") as f: |
| 121 | writer = csv.writer(f) | 121 | writer = csv.writer(f) |
| 122 | writer.writerow(fields) | 122 | writer.writerow(fields) |
| 123 | 123 | ||
| 124 | if paradigm == "retrieval": | 124 | if paradigm == "retrieval": |
| 125 | 125 | ||
