diff options
-rw-r--r-- | WMD.py | 16 |
1 file changed, 7 insertions, 9 deletions
@@ -67,12 +67,6 @@ def main(args): | |||
67 | clean_src_corpus = list(clean_src_corpus[experiment_keys]) | 67 | clean_src_corpus = list(clean_src_corpus[experiment_keys]) |
68 | clean_target_corpus = list(clean_target_corpus[experiment_keys]) | 68 | clean_target_corpus = list(clean_target_corpus[experiment_keys]) |
69 | 69 | ||
70 | if not batch: | ||
71 | print( | ||
72 | f"{source_lang} - {target_lang} " | ||
73 | + f" document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}" | ||
74 | ) | ||
75 | |||
76 | del vectors_source, vectors_target, defs_source, defs_target | 70 | del vectors_source, vectors_target, defs_source, defs_target |
77 | 71 | ||
78 | vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus) | 72 | vec = CountVectorizer().fit(clean_src_corpus + clean_target_corpus) |
@@ -89,7 +83,11 @@ def main(args): | |||
89 | W_common.append(np.array(clean_target_vectors[w])) | 83 | W_common.append(np.array(clean_target_vectors[w])) |
90 | 84 | ||
91 | if not batch: | 85 | if not batch: |
92 | print(f"{source_lang} - {target_lang}: the vocabulary size is {len(W_common)}") | 86 | print( |
87 | f"{source_lang} - {target_lang}\n" | ||
88 | + f" document sizes: {len(clean_src_corpus)}, {len(clean_target_corpus)}\n" | ||
89 | + f" vocabulary size: {len(W_common)}" | ||
90 | ) | ||
93 | 91 | ||
94 | W_common = np.array(W_common) | 92 | W_common = np.array(W_common) |
95 | W_common = normalize(W_common) | 93 | W_common = normalize(W_common) |
@@ -107,7 +105,7 @@ def main(args): | |||
107 | 105 | ||
108 | for metric in run_method: | 106 | for metric in run_method: |
109 | if not batch: | 107 | if not batch: |
110 | print(f"{metric}: {source_lang} - {target_lang}") | 108 | print(f"{paradigm} - {metric} on {source_lang} - {target_lang}") |
111 | 109 | ||
112 | clf = WassersteinDriver( | 110 | clf = WassersteinDriver( |
113 | W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk") | 111 | W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == "snk") |
@@ -118,7 +116,7 @@ def main(args): | |||
118 | ) | 116 | ) |
119 | 117 | ||
120 | if not batch: | 118 | if not batch: |
121 | print(f"P @ 1: {p_at_one}\ninstances: {instances}\n{percentage}%") | 119 | print(f"P @ 1: {p_at_one}\n{percentage}% {instances} definitions\n") |
122 | else: | 120 | else: |
123 | fields = [ | 121 | fields = [ |
124 | f"{source_lang}", | 122 | f"{source_lang}", |