# Source code for data_generator
from keras import backend as K
from library.image import ImageDataGenerator
# Executed once at import time: sets the Keras backend's image dimension ordering
# to 'th' (Theano-style) -- presumably channels-first, which would match the
# channel-axis check in DataGenerator._validate; confirm. This is a backend-wide
# setting (presumably process-global), not scoped to this module.
# NOTE(review): set_image_dim_ordering looks like an older Keras API; newer Keras
# releases may only expose set_image_data_format -- verify against the pinned version.
K.set_image_dim_ordering('th')
class DataGenerator:
    """Validate a labelled image data set and serve augmented batches from it.

    Thin wrapper around ``library.image.ImageDataGenerator``: the generator is
    configured by :meth:`config_augmentation`, fitted on the data by :meth:`fit`,
    and then queried through :meth:`get_next_batch` or :meth:`generate`.

    Attributes:
        time_delay: Expected length of the time axis (axis 1) of 5-D image input,
            or None when the images carry no time axis.
        images: Images stored by the last successful :meth:`fit`, or None.
        labels: Labels stored by the last successful :meth:`fit`, or None.
        data_gen: The underlying ``ImageDataGenerator`` instance.
    """

    def __init__(self, time_delay=None):
        """Create an unfitted generator with the default augmentation settings.

        Args:
            time_delay: Optional time-axis length; see :meth:`_validate`.
        """
        self.time_delay = time_delay
        self.images = None
        self.labels = None
        self.config_augmentation(time_delay=time_delay)

    def config_augmentation(self, zca_whitening=False, rotation_angle=90, shift_range=0.2, horizontal_flip=True,
                            time_delay=None):
        """(Re)build the underlying ImageDataGenerator with the given augmentation settings.

        Feature-wise centering and feature-wise std normalization are always enabled.

        Args:
            zca_whitening: Forwarded as ``zca_whitening``.
            rotation_angle: Forwarded as ``rotation_angle``.
            shift_range: Used for both ``width_shift_range`` and ``height_shift_range``.
            horizontal_flip: Forwarded as ``horizontal_flip``.
            time_delay: Forwarded as ``time_delay``. When None, the instance's current
                ``time_delay`` is reused so that reconfiguring augmentation does not
                silently drop the time axis the data is validated against. A non-None
                value also replaces ``self.time_delay``.

        Returns:
            self, to allow chaining.
        """
        if time_delay is None:
            # Keep the generator and _validate() in agreement about the time axis.
            time_delay = getattr(self, "time_delay", None)
        self.data_gen = ImageDataGenerator(featurewise_center=True,
                                           featurewise_std_normalization=True,
                                           zca_whitening=zca_whitening,
                                           rotation_angle=rotation_angle,
                                           width_shift_range=shift_range,
                                           height_shift_range=shift_range,
                                           horizontal_flip=horizontal_flip,
                                           time_delay=time_delay)
        # Record only after the generator was built successfully.
        self.time_delay = time_delay
        return self

    def fit(self, images, labels):
        """Validate and store the data set, then fit the underlying generator on the images.

        Args:
            images: Image array; see :meth:`_validate` for the accepted shapes.
            labels: One label entry per image.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: If the data fails validation; nothing is stored in that case.
        """
        self._validate(images, labels)
        self.images = images
        self.labels = labels
        # Presumably computes the feature-wise statistics enabled in config_augmentation.
        self.data_gen.fit(self.images)
        return self

    def get_next_batch(self, batch_size=10, target_dimensions=None):
        """Return the first ``(images, labels)`` batch produced from the fitted data.

        Args:
            batch_size: Forwarded to ``flow`` as ``batch_size``.
            target_dimensions: Forwarded to ``flow`` as ``target_dimensions``.

        Returns:
            An ``(images, labels)`` pair, or None if ``flow`` yields no batch at all.

        Raises:
            ValueError: If :meth:`fit` has not been called successfully yet.
        """
        self._check_model_has_been_fit()
        batches = self.data_gen.flow(self.images, self.labels, batch_size=batch_size,
                                     target_dimensions=target_dimensions)
        # Only the first batch is needed; a fresh flow is created on every call.
        for images, labels in batches:
            return images, labels
        return None

    def generate(self, target_dimensions=None, batch_size=10):
        """Return the batch iterator produced by the underlying generator's ``flow``.

        Args:
            target_dimensions: Forwarded to ``flow`` (see the note below).
            batch_size: Forwarded to ``flow`` as ``batch_size``.

        Returns:
            Whatever ``ImageDataGenerator.flow`` returns for the fitted data.

        Raises:
            ValueError: If :meth:`fit` has not been called successfully yet.
        """
        self._check_model_has_been_fit()
        # NOTE(review): this call passes ``target_dimension`` (singular) while
        # get_next_batch passes ``target_dimensions``; only one of them can match
        # flow()'s signature unless it accepts both -- confirm against library.image.
        return self.data_gen.flow(self.images, self.labels, batch_size=batch_size,
                                  target_dimension=target_dimensions)

    def _validate(self, images, labels):
        """Check that ``images`` and ``labels`` form a usable data set.

        Assumes ``images`` is array-like exposing ``ndim`` and ``shape``
        (e.g. a NumPy array) -- plain lists would fail with AttributeError.

        Raises:
            ValueError: If the image and label counts differ; if the images have
                fewer than 4 dimensions; if ``time_delay`` is set but the images are
                not 5-D; or if 5-D images have an axis-1 (time) length different
                from ``time_delay`` -- including when ``time_delay`` is None.
        """
        if len(images) != len(labels):
            raise ValueError("Samples are not labeled properly")
        if images.ndim < 4:
            raise ValueError("Channel Axis should have value")
        if self.time_delay and images.ndim != 5:
            raise ValueError("Time_delay parameter was set but Images say otherwise")
        # Axis 1 of 5-D input is the time axis; its length must equal time_delay.
        if images.ndim == 5 and images.shape[1] != self.time_delay:
            raise ValueError("Images have time axis length {given} "
                             "but time_delay parameter was set to {set}"
                             .format(given=images.shape[1], set=self.time_delay))

    def _check_model_has_been_fit(self):
        """Raise ValueError unless :meth:`fit` has stored both images and labels."""
        if self.images is None or self.labels is None:
            raise ValueError("Model is not fit to any data set yet")