diff options
| -rw-r--r-- | sentence_embedding.py | 14 |
1 file changed, 8 insertions, 6 deletions
diff --git a/sentence_embedding.py b/sentence_embedding.py index 0cd5361..d312de2 100644 --- a/sentence_embedding.py +++ b/sentence_embedding.py | |||
| @@ -12,12 +12,12 @@ from Wasserstein_Distance import load_embeddings, process_corpus | |||
| 12 | 12 | ||
| 13 | def main(args): | 13 | def main(args): |
| 14 | 14 | ||
| 15 | run_method = list() | 15 | run_paradigm = list() |
| 16 | 16 | ||
| 17 | if input_paradigm == "all": | 17 | if args.paradigm == "all": |
| 18 | run_paradigm.extend("matching", "retrieval") | 18 | run_paradigm.extend("matching", "retrieval") |
| 19 | else: | 19 | else: |
| 20 | run_paradigm.append(input_paradigm) | 20 | run_paradigm.append(args.paradigm) |
| 21 | 21 | ||
| 22 | source_lang = args.source_lang | 22 | source_lang = args.source_lang |
| 23 | target_lang = args.target_lang | 23 | target_lang = args.target_lang |
| @@ -96,7 +96,7 @@ def main(args): | |||
| 96 | cost_matrix = np.matmul(S_emb_source, S_emb_target_transpose) | 96 | cost_matrix = np.matmul(S_emb_source, S_emb_target_transpose) |
| 97 | 97 | ||
| 98 | for paradigm in run_paradigm: | 98 | for paradigm in run_paradigm: |
| 99 | if paradigm == 'matching': | 99 | if paradigm == "matching": |
| 100 | 100 | ||
| 101 | cost_matrix = cost_matrix * -1000 | 101 | cost_matrix = cost_matrix * -1000 |
| 102 | row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) | 102 | row_ind, col_ind, a = lapjv(cost_matrix, verbose=False) |
| @@ -121,9 +121,11 @@ def main(args): | |||
| 121 | writer = csv.writer(f) | 121 | writer = csv.writer(f) |
| 122 | writer.writerow(fields) | 122 | writer.writerow(fields) |
| 123 | 123 | ||
| 124 | if paradigm == 'retrieval': | 124 | if paradigm == "retrieval": |
| 125 | 125 | ||
| 126 | hit_at_one = len([x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y]) | 126 | hit_at_one = len( |
| 127 | [x for x, y in enumerate(cost_matrix.argmax(axis=1)) if x == y] | ||
| 128 | ) | ||
| 127 | percentage = hit_at_one / instances * 100 | 129 | percentage = hit_at_one / instances * 100 |
| 128 | 130 | ||
| 129 | if not batch: | 131 | if not batch: |
