sudo-paras-shah's picture
Remove async limitations
e916c8e
import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal
import tensorflow as tf
import keras
class LossHistory(keras.callbacks.Callback):
    """Callback that records per-epoch losses to text files and a PNG plot.

    On construction a directory ``<log_dir>/loss_<timestamp>`` is created.
    After every epoch the ``loss`` and ``val_loss`` entries of ``logs`` are
    appended to two text files in that directory and the loss-curve image
    is regenerated from the full history.
    """

    def __init__(self, log_dir):
        """Create the timestamped output directory under ``log_dir``.

        Args:
            log_dir: Parent directory in which the ``loss_<timestamp>``
                folder is created.
        """
        import datetime

        # Let the base Callback set up its own attributes first.
        super().__init__()

        self.log_dir = log_dir
        self.time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
        self.losses = []
        self.val_loss = []

        # exist_ok: two callbacks created within the same second share a
        # timestamp and must not crash on the second makedirs call.
        os.makedirs(self.save_path, exist_ok=True)

    def _append_value(self, prefix, value):
        """Append ``str(value)`` as one line to ``<prefix><timestamp>.txt``."""
        file_name = prefix + str(self.time_str) + ".txt"
        with open(os.path.join(self.save_path, file_name), 'a') as f:
            f.write(str(value))
            f.write("\n")

    def on_epoch_end(self, batch, logs=None):
        """Record this epoch's losses and refresh the plot.

        Args:
            batch: Index passed by the training loop; unused. The name is
                kept for backward compatibility with existing callers.
            logs: Mapping of metric names to values. Missing ``loss`` or
                ``val_loss`` entries are recorded as ``None``.
        """
        logs = logs or {}
        loss = logs.get('loss')
        val_loss = logs.get('val_loss')

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        self._append_value("epoch_loss_", loss)
        self._append_value("epoch_val_loss_", val_loss)
        self.loss_plot()

    def loss_plot(self):
        """Render the recorded loss history to ``epoch_loss_<timestamp>.png``.

        Raw training/validation curves are always drawn. Savitzky-Golay
        smoothed curves (polynomial order 3) are added once enough epochs
        have been recorded to fill the smoothing window.
        """
        iters = range(len(self.losses))
        # Short histories get a small window, longer ones a wider one.
        window = 5 if len(self.losses) < 25 else 15

        fig, ax = plt.subplots()
        try:
            ax.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
            ax.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

            # The filter rejects windows longer than the data, so skip the
            # smoothed curves until enough epochs have been collected.
            if len(self.losses) >= window:
                smooth_specs = (
                    (self.losses, 'green', 'smooth train loss'),
                    (self.val_loss, '#8B4513', 'smooth val loss'),
                )
                for series, color, label in smooth_specs:
                    try:
                        smoothed = scipy.signal.savgol_filter(series, window, 3)
                    except (ValueError, TypeError):
                        # Smoothing is best-effort; keep the raw curves.
                        continue
                    ax.plot(iters, smoothed, color, linestyle='--',
                            linewidth=2, label=label)

            ax.grid(True)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('A Loss Curve')
            ax.legend(loc="upper right")

            fig.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))
        finally:
            # Close only this figure, and do so even if plotting or saving
            # failed, so figures do not accumulate across epochs.
            plt.close(fig)
class ExponentDecayScheduler(keras.callbacks.Callback):
    """Callback that multiplies the optimizer learning rate by a fixed factor
    at the end of every epoch.

    Every successfully applied learning rate is appended to
    ``self.learning_rates``.
    """

    def __init__(self,
                 decay_rate,
                 verbose=0):
        """Store the decay configuration.

        Args:
            decay_rate: Factor the current learning rate is multiplied by
                after each epoch.
            verbose: When greater than 0, print each newly applied
                learning rate.
        """
        super(ExponentDecayScheduler, self).__init__()
        self.decay_rate = decay_rate
        self.verbose = verbose
        self.learning_rates = []

    @staticmethod
    def _read_value(variable):
        """Return the numeric value of ``variable``, or ``variable`` itself
        when no known accessor works."""
        try:
            return keras.backend.get_value(variable)
        except Exception:
            # Backend helper missing or unsupported for this object; try
            # the object's own accessor next.
            pass
        try:
            return variable.numpy()
        except Exception:
            return variable

    @staticmethod
    def _write_value(variable, value):
        """Try to store ``value`` in ``variable``; return True on success."""
        try:
            keras.backend.set_value(variable, value)
            return True
        except Exception:
            # Backend helper missing or unsupported for this object; try
            # the object's own assign method next.
            pass
        try:
            variable.assign(value)
            return True
        except Exception:
            return False

    def on_epoch_end(self, batch, logs=None):
        """Decay the optimizer's learning rate by ``decay_rate``.

        Args:
            batch: Index passed by the training loop; unused. The name is
                kept for backward compatibility with existing callers.
            logs: Metrics mapping passed by the training loop; unused.
        """
        lr = self.model.optimizer.learning_rate
        current_lr = self._read_value(lr)
        new_lr = current_lr * self.decay_rate

        if not self._write_value(lr, new_lr):
            print("Warning: Could not set learning rate dynamically.")
            return

        # Only record and announce rates that were actually applied.
        self.learning_rates.append(new_lr)
        if self.verbose > 0:
            print('Setting learning rate to %s.' % (new_lr,))