aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--learn_and_predict.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/learn_and_predict.py b/learn_and_predict.py
index 36c56f2..961b190 100644
--- a/learn_and_predict.py
+++ b/learn_and_predict.py
@@ -1,14 +1,14 @@
1import argparse 1import argparse
2import csv 2import csv
3 3
4import numpy as np
5
6import keras 4import keras
7import keras.backend as K 5import keras.backend as K
8from Helpers import Data, Get_Embedding 6import numpy as np
9from keras.layers import LSTM, Embedding, Input, Lambda, concatenate 7from keras.layers import LSTM, Embedding, Input, Lambda, concatenate
10from keras.models import Model 8from keras.models import Model
11 9
10from Helpers import Data, Get_Embedding
11
12 12
13def get_learning_rate(epoch=None, model=None): 13def get_learning_rate(epoch=None, model=None):
14 return np.round(float(K.get_value(model.optimizer.lr)), 5) 14 return np.round(float(K.get_value(model.optimizer.lr)), 5)
@@ -151,40 +151,47 @@ if __name__ == "__main__":
151 parser = argparse.ArgumentParser() 151 parser = argparse.ArgumentParser()
152 152
153 parser.add_argument( 153 parser.add_argument(
154 "-sl", "--source_lang", type=str, help="Source language.", default="english" 154 "-sl", "--source_lang", type=str, help="Source language.", required=True
155 )
156 parser.add_argument(
157 "-tl", "--target_lang", type=str, help="Target language.", required=True
155 ) 158 )
156 parser.add_argument( 159 parser.add_argument(
157 "-tl", "--target_lang", type=str, help="Target language.", default="italian" 160 "-df", "--data_file", type=str, help="Path to dataset.", required=True
158 ) 161 )
159 parser.add_argument("-df", "--data_file", type=str, help="Path to dataset.")
160 parser.add_argument( 162 parser.add_argument(
161 "-es", 163 "-es",
162 "--source_emb_file", 164 "--source_emb_file",
163 type=str, 165 type=str,
164 help="Path to Source (English) Embedding File.", 166 help="Path to source embedding file.",
167 required=True,
165 ) 168 )
166 parser.add_argument( 169 parser.add_argument(
167 "-et", "--target_emb_file", type=str, help="Path to Target Embedding File." 170 "-et",
171 "--target_emb_file",
172 type=str,
173 help="Path to target embedding file.",
174 required=True,
168 ) 175 )
169 parser.add_argument( 176 parser.add_argument(
170 "-l", 177 "-l",
171 "--max_len", 178 "--max_len",
172 type=int, 179 type=int,
173 help="Maximum number of words in a sentence.", 180 help="Maximum number of words in a sentence.",
174 default=20, 181 default=25,
175 ) 182 )
176 parser.add_argument( 183 parser.add_argument(
177 "-z", 184 "-z",
178 "--hidden_size", 185 "--hidden_size",
179 type=int, 186 type=int,
180 help="Number of Units in LSTM layer.", 187 help="Number of units in LSTM layer.",
181 default=50, 188 default=50,
182 ) 189 )
183 parser.add_argument( 190 parser.add_argument(
184 "-b", 191 "-b",
185 "--batch", 192 "--batch",
186 action="store_true", 193 action="store_true",
187 help="running in batch (store results to csv) or" 194 help="running in batch (store results to csv) or "
188 + "running in a single instance (output the results)", 195 + "running in a single instance (output the results)",
189 ) 196 )
190 parser.add_argument( 197 parser.add_argument(