diff options
author | Yigit Sever | 2019-09-21 16:55:19 +0300 |
---|---|---|
committer | Yigit Sever | 2019-09-21 16:55:19 +0300 |
commit | 73df12b6787304e6982e4543d63700beabc88085 (patch) | |
tree | 2d222938890e73bfa97d4972deb266311a3496a1 | |
parent | 1558c5e3ea8e34286ad587a8473f5121d5e8d289 (diff) | |
download | Evaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.tar.gz Evaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.tar.bz2 Evaluating-Dictionary-Alignment-73df12b6787304e6982e4543d63700beabc88085.zip |
Fix typo
-rw-r--r-- | WMD_retrieval.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/WMD_retrieval.py b/WMD_retrieval.py index b49ba7d..32f3b5d 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py | |||
@@ -92,7 +92,7 @@ def main(args): | |||
92 | if (not batch): | 92 | if (not batch): |
93 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') | 93 | print(f'{metric} - tfidf: {source_lang} - {target_lang}') |
94 | 94 | ||
95 | clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) | 95 | clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) |
96 | clf.fit(X_train_idf[:instances], np.ones(instances)) | 96 | clf.fit(X_train_idf[:instances], np.ones(instances)) |
97 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) | 97 | dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) |
98 | mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) | 98 | mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) |
@@ -118,7 +118,7 @@ if __name__ == "__main__": | |||
118 | parser.add_argument('target_defs', help='path of the target definitions') | 118 | parser.add_argument('target_defs', help='path of the target definitions') |
119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') | 119 | parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') |
120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') | 120 | parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') |
121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=2000, type=int) | 121 | parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) |
122 | 122 | ||
123 | args = parser.parse_args() | 123 | args = parser.parse_args() |
124 | 124 | ||