diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/tsv_creator.py | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/scripts/tsv_creator.py b/scripts/tsv_creator.py new file mode 100644 index 0000000..7d587c5 --- /dev/null +++ b/scripts/tsv_creator.py | |||
@@ -0,0 +1,101 @@ | |||
1 | import argparse | ||
2 | |||
# Command-line interface: two language short names, the two definition files
# to pair up, and an optional count of trailing pairs to hold out.
parser = argparse.ArgumentParser(
    description='Create .tsv file from two wordnet definitions')
for arg_name, arg_help in (
    ('source_lang', 'source language short name'),
    ('target_lang', 'target language short name'),
    ('source_defs', 'path of the source definitions'),
    ('target_defs', 'path of the target definitions'),
):
    parser.add_argument(arg_name, help=arg_help)
parser.add_argument('-n', '--set_aside', type=int, help='set aside to validate on')

args = parser.parse_args()

source_lang, target_lang = args.source_lang, args.target_lang

from DataHelper.Loader import load_def_from_file as load_def

# Load both definition collections from the paths given on the command line.
source_defs_filename, target_defs_filename = args.source_defs, args.target_defs
defs_source = load_def(source_defs_filename)
defs_target = load_def(target_defs_filename)
21 | |||
22 | import numpy as np | ||
23 | from re import sub | ||
24 | from mosestokenizer import * | ||
25 | |||
def clean_corpus_suffix(corpus, language):
    '''
    Tokenize every definition and tag each token with its language.

    Apostrophes are removed and every other non-word character is replaced
    by a space before tokenizing. Each resulting token gets the suffix
    '__<language>' appended (e.g. 'dog__en'), and the tokens of one
    definition are joined back together with single spaces.

    corpus   -- iterable of definition strings
    language -- language short name; passed to MosesTokenizer and used as
                the suffix

    Returns a list with one cleaned string per input definition, in the
    same order.
    '''
    clean_corpus = []
    tokenize = MosesTokenizer(language)
    try:
        for definition in corpus:
            definition = sub(r"'", '', definition)
            definition = sub(r"[^\w]", ' ', definition)
            words = tokenize(definition)
            clean_corpus.append(
                ' '.join(f'{word}__{language}' for word in words))
    finally:
        # The tokenizer must be closed explicitly (see the mosestokenizer
        # usage notes); do so even if a definition fails to tokenize.
        tokenize.close()
    return clean_corpus
41 | |||
clean_source_corpus = clean_corpus_suffix(defs_source, source_lang)
clean_target_corpus = clean_corpus_suffix(defs_target, target_lang)

# Definitions are paired by position, so both sides must have the same length.
# Raised explicitly rather than asserted so the check survives `python -O`.
if len(clean_source_corpus) != len(clean_target_corpus):
    raise ValueError(
        f'source has {len(clean_source_corpus)} definitions but target has '
        f'{len(clean_target_corpus)}; they must be aligned one-to-one')

# -n/--set_aside has no default, so it is None when omitted: hold out nothing.
set_aside = args.set_aside or 0
if not 0 <= set_aside <= len(clean_source_corpus):
    raise ValueError(
        f'set_aside must be between 0 and {len(clean_source_corpus)}, '
        f'got {set_aside}')

# Split on an explicit index. Slicing with -set_aside breaks for 0:
# [-0:] is the whole list and [:-0] is empty.
split_at = len(clean_source_corpus) - set_aside

source_predict = clean_source_corpus[split_at:]
target_predict = clean_target_corpus[split_at:]
labels_predict = [1] * set_aside  # placeholder, won't be used, we can use 1 because they're correct

clean_source_corpus = clean_source_corpus[:split_at]
clean_target_corpus = clean_target_corpus[:split_at]

size = len(clean_source_corpus)
57 | |||
58 | import math | ||
59 | import random | ||
60 | |||
def create_pos_neg_samples(length):
    """Choose which positions keep their own target and which get another.

    Returns ``(indices, labels)``, two lists of ``length`` items.
    ``indices[i]`` is the position of the target to pair with source ``i``;
    ``labels[i]`` is 1 when ``indices[i] == i`` (a true pair) and 0 otherwise.

    ``length // 2`` randomly chosen positions become negatives, so exactly
    ``ceil(length / 2)`` positions stay aligned. A negative position is never
    mapped to itself, so a 0 label always means a different index.

    With two or more negatives, ``indices`` is a permutation of
    ``range(length)``: the negatives trade targets among themselves. With
    exactly one negative (length 2 or 3) there is nothing to trade with, so
    it reuses the target of another position.
    """
    indices = list(range(length))
    labels = [1] * length

    neg_count = length // 2
    if neg_count == 0:
        # Lengths 0 and 1 leave no room for a mismatched pair.
        return indices, labels

    neg_points = random.sample(range(length), neg_count)

    if neg_count == 1:
        point = neg_points[0]
        indices[point] = random.choice(
            [other for other in range(length) if other != point])
    else:
        # neg_points is already in random order and its items are distinct,
        # so handing each point the index of its successor (a rotation by
        # one) can never map a point to itself.
        rotated = neg_points[1:] + neg_points[:1]
        for point, index in zip(neg_points, rotated):
            indices[point] = index

    for point in neg_points:
        labels[point] = 0

    return indices, labels
77 | |||
# Draw a positive/negative assignment and keep it only if every label is
# truthful: label 1 means the pair is aligned, label 0 means the texts differ.
# Comparing texts (not indices) also catches a "negative" that happens to land
# on an identical duplicate definition.
# The number of rolls is bounded so an unsatisfiable corpus (too small, or too
# many duplicate definitions) fails loudly instead of looping forever.
MAX_ROLLS = 1000

for _ in range(MAX_ROLLS):
    indices, labels = create_pos_neg_samples(size)
    shuffled_target = [clean_target_corpus[index] for index in indices]
    mislabelled = sum(
        (clean == shuf) != (label == 1)
        for clean, shuf, label in zip(clean_target_corpus, shuffled_target, labels)
    )
    if not mislabelled:
        break
    print(f'rolling again: {mislabelled} mislabelled pairs')
else:
    raise RuntimeError(
        f'could not draw a consistent positive/negative split for {size} '
        f'definitions in {MAX_ROLLS} attempts')

# Sanity checks: one target and one label per source definition.
assert len(clean_source_corpus) == len(shuffled_target) == size
assert len(labels) == len(clean_source_corpus) == len(shuffled_target)
92 | |||
93 | import csv | ||
94 | |||
# Write the header, then the labelled training pairs, then the held-out pairs
# (whose labels are placeholder 1s) into a single tab-separated file.
output_path = f'/home/syigit/tsv_data/{source_lang}_{target_lang}_1000_data.tsv'
header = [f'{source_lang} definition', f'{target_lang} definition', 'is same']

with open(output_path, 'w', encoding='utf8', newline='') as out_file:
    writer = csv.writer(out_file, delimiter='\t', lineterminator='\n')
    writer.writerow(header)
    writer.writerows(zip(clean_source_corpus, shuffled_target, labels))
    writer.writerows(zip(source_predict, target_predict, labels_predict))