Model training
Model training is implemented in the fit(...) method, which takes the following parameters:
train_X: array_like, shape (n_samples, n_features). Training data.
train_Y: array_like, shape (n_samples, n_classes). Training labels.
val_X: array_like, shape (N, n_features), optional (default = None). Validation data.
val_Y: array_like, shape (N, n_classes), optional (default = None). Validation labels.
graph: tf.Graph, optional (default = None). TensorFlow Graph object.
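The parameter list fixes a shape contract between the arrays: features and labels must agree on the number of samples, and validation arrays must match the training feature and class counts. The following is a minimal NumPy sketch of checking that contract before training. check_fit_inputs is a hypothetical helper, not part of the original class, and requiring val_X and val_Y to be passed together is an assumption.

```python
import numpy as np


def check_fit_inputs(train_X, train_Y, val_X=None, val_Y=None):
    """Hypothetical helper: verify arrays match the shapes documented for fit().

    Returns (n_features, n_classes) on success, raises ValueError otherwise.
    """
    train_X = np.asarray(train_X)
    train_Y = np.asarray(train_Y)
    if train_X.ndim != 2:
        raise ValueError("train_X must have shape (n_samples, n_features)")
    if train_Y.ndim != 2:
        raise ValueError("train_Y must be one-hot encoded, shape (n_samples, n_classes)")
    if train_X.shape[0] != train_Y.shape[0]:
        raise ValueError("train_X and train_Y must have the same number of samples")
    n_features, n_classes = train_X.shape[1], train_Y.shape[1]

    # Assumption: validation data is optional but must be supplied as a pair.
    if (val_X is None) != (val_Y is None):
        raise ValueError("pass both val_X and val_Y, or neither")
    if val_X is not None:
        val_X = np.asarray(val_X)
        val_Y = np.asarray(val_Y)
        if val_X.ndim != 2 or val_X.shape[1] != n_features:
            raise ValueError("val_X must have shape (N, n_features)")
        if val_Y.ndim != 2 or val_Y.shape != (val_X.shape[0], n_classes):
            raise ValueError("val_Y must have shape (N, n_classes)")
    return n_features, n_classes


# Usage: 4 samples, 3 features, 2 one-hot classes.
X = np.zeros((4, 3))
Y = np.eye(2)[[0, 1, 1, 0]]
print(check_fit_inputs(X, Y))  # (3, 2)
```

The returned pair corresponds to the two values fit(...) passes to build_model: the number of input features and the number of classes.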
Next, we look at the implementation of the fit(...) method, which trains the model and saves it to the path specified by model_path.
def fit(self, train_X, train_Y, val_X=None, val_Y=None, graph=None):
    # Labels must be 2-D (one-hot encoded): shape (n_samples, n_classes)
    if len(train_Y.shape) != 1:
        num_classes = train_Y.shape[1]
    else:
        raise Exception("Please convert the labels with one-hot encoding.")
    # Use the caller's graph if one is given, otherwise the model's own graph
    g = graph if graph is not None else self.tf_graph
    with g.as_default():
        # Build model
        self.build_model(train_X.shape[1], num_classes)
        with...
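fit(...) rejects 1-D label arrays with the message "Please convert the labels with one-hot encoding." The following is a minimal NumPy sketch of that conversion, assuming integer class labels in the range 0 to n_classes - 1. to_one_hot is a hypothetical helper name, not part of the original code.

```python
import numpy as np


def to_one_hot(labels, n_classes=None):
    """Hypothetical helper: convert 1-D integer class labels to one-hot rows.

    The result has shape (n_samples, n_classes), the layout fit() expects
    for train_Y and val_Y.
    """
    labels = np.asarray(labels, dtype=int)
    if labels.ndim != 1:
        raise ValueError("labels must be a 1-D array of class indices")
    if n_classes is None:
        # Assumption: classes are numbered 0..max(labels) with no gaps.
        n_classes = int(labels.max()) + 1
    one_hot = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
    # Place a single 1.0 in each row, at the column given by that row's label.
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


y = to_one_hot([0, 2, 1])
print(y.shape)  # (3, 3)
print(y[1])     # [0. 0. 1.]
```

Passing n_classes explicitly is safer when a training split might not contain every class; otherwise the inferred width could differ between train_Y and val_Y.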