From 73df12b6787304e6982e4543d63700beabc88085 Mon Sep 17 00:00:00 2001 From: Yigit Sever Date: Sat, 21 Sep 2019 16:55:19 +0300 Subject: Fix typo --- WMD_retrieval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/WMD_retrieval.py b/WMD_retrieval.py index b49ba7d..32f3b5d 100644 --- a/WMD_retrieval.py +++ b/WMD_retrieval.py @@ -92,7 +92,7 @@ def main(args): if (not batch): print(f'{metric} - tfidf: {source_lang} - {target_lang}') - clf = WassersteinDistances(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) + clf = Wasserstein_Retriever(W_embed=W_common, n_neighbors=5, n_jobs=14, sinkhorn=(metric == 'snk')) clf.fit(X_train_idf[:instances], np.ones(instances)) dist, preds = clf.kneighbors(X_test_idf[:instances], n_neighbors=instances) mrr, p_at_1 = mrr_precision_at_k(list(range(len(preds))), preds) @@ -118,7 +118,7 @@ if __name__ == "__main__": parser.add_argument('target_defs', help='path of the target definitions') parser.add_argument('-b', '--batch', action='store_true', help='running in batch (store results in csv) or running a single instance (output the results)') parser.add_argument('mode', choices=['all', 'wmd', 'snk'], default='all', help='which methods to run') - parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=2000, type=int) + parser.add_argument('-n', '--instances', help='number of instances in each language to retrieve', default=1000, type=int) args = parser.parse_args() -- cgit v1.2.3-70-g09d2