Trains a simple CNN-Capsule Network on the CIFAR10 small images dataset
Jump to navigation
Jump to search
The printable version is no longer supported and may have rendering errors. Please update your browser bookmarks and please use the default browser print function instead.
Source: https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn_capsule.py
# This example trains a simple CNN-Capsule Network on the CIFAR10 data set.
#
# Without Data Augmentation:
# It gets to 75% validation accuracy in 10 epochs, 79% after 15 epochs,
# and starts overfitting after 20 epochs.
#
# With Data Augmentation:
# It gets to 75% validation accuracy in 10 epochs, 79% after 15 epochs,
# and 83% after 30 epochs.
# The highest achieved validation accuracy is 83.79% after 50 epochs.
#
# This is a fast implementation that takes just 20s/epoch on a GTX 1070 GPU.
#
# The paper "Dynamic Routing Between Capsules": https://arxiv.org/abs/1710.09829

from __future__ import print_function

from keras import activations
from keras import backend as K
from keras import layers
from keras import utils
from keras.datasets import cifar10
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator


def squash(x, axis=-1):
    """Apply the capsule "squashing" non-linearity along one axis.

    Every vector taken along `axis` is multiplied by
    `norm / (0.5 + norm ** 2)`, where `norm` is the Euclidean length of
    that vector (computed with a small epsilon added to the squared length).

    # Arguments
        x: Input tensor.
        axis: Integer axis along which the squashing function is applied.

    # Returns
        Tensor with the same shape as `x`, holding the rescaled values.
    """
    # Adding epsilon keeps the squared length strictly positive, so the
    # square root below is never taken at exactly zero.
    squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    norm = K.sqrt(squared_norm)
    return (norm / (0.5 + squared_norm)) * x


def margin_loss(y_true, y_pred):
    """Margin loss over capsule lengths.

    For a class that is present, the prediction is penalised for falling
    below `1 - margin`; for a class that is absent, it is penalised (with a
    reduced weight) for rising above `margin`.

    # Arguments
        y_true: Tensor of true targets.
        y_pred: Tensor of predicted targets.

    # Returns
        Tensor with one scalar loss entry per sample (sum over the last axis).
    """
    down_weight, margin = 0.5, 0.1
    present_term = y_true * K.square(K.relu(1 - margin - y_pred))
    absent_term = down_weight * (1 - y_true) * K.square(K.relu(y_pred - margin))
    return K.sum(present_term + absent_term, axis=-1)
class Capsule(layers.Layer):
    """Capsule layer with dynamic routing, implemented as a Keras layer.

    There are two versions of the layer. One is similar to a dense layer
    (for fixed-shape input, `share_weights=False`), and the other is similar
    to a time-distributed dense layer (for inputs of varied length,
    `share_weights=True`).

    The input shape must be
    `(batch_size, input_num_capsule, input_dim_capsule)`
    and the output shape is
    `(batch_size, num_capsule, dim_capsule)`.

    The Capsule implementation is from https://github.com/bojone/Capsule/

    # Arguments
        num_capsule: An integer, the number of output capsules.
        dim_capsule: An integer, the dimension of each output capsule.
        routings: An integer (>= 1), the number of routing iterations.
        share_weights: A boolean. If True, one transformation kernel is
            shared by all input capsules; otherwise every input capsule
            gets its own kernel.
        activation: The activation applied to the routed capsules. The
            string `'squash'` selects the module-level `squash` function;
            anything else is resolved through `keras.activations.get`.

    # Raises
        ValueError: If `routings` is smaller than 1.
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 share_weights=True,
                 activation='squash',
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        # With zero routing iterations `call` would never produce an output,
        # so reject the value here with a clear message.
        if routings < 1:
            raise ValueError('`routings` must be at least 1, '
                             'got {}'.format(routings))
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        """Create the transformation kernel for the given input shape."""
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            # A single kernel (leading dimension 1) is applied to every
            # input capsule.
            kernel_shape = (1,
                            input_dim_capsule,
                            self.num_capsule * self.dim_capsule)
        else:
            # One kernel per input capsule, so the capsule count must be
            # known when the layer is built.
            input_num_capsule = input_shape[-2]
            kernel_shape = (input_num_capsule,
                            input_dim_capsule,
                            self.num_capsule * self.dim_capsule)
        self.kernel = self.add_weight(
            name='capsule_kernel',
            shape=kernel_shape,
            initializer='glorot_uniform',
            trainable=True)
        super(Capsule, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Route the input capsules to the output capsules.

        Follows the routing algorithm from Hinton's paper, but replaces
        b = b + <u,v> with b = <u,v>. This change can improve the feature
        representation of the capsule.

        However, you can replace
            b = K.batch_dot(o, hat_inputs, [2, 3])
        with
            b += K.batch_dot(o, hat_inputs, [2, 3])
        to get standard routing.

        # Arguments
            inputs: Tensor of shape
                `(batch_size, input_num_capsule, input_dim_capsule)`.

        # Returns
            The activated output capsules of the last routing iteration.
        """
        # Project every input capsule into the space of all output capsules.
        if self.share_weights:
            hat_inputs = K.conv1d(inputs, self.kernel)
        else:
            hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1])

        # Split the projected features per output capsule and move the
        # output-capsule axis in front of the input-capsule axis:
        # (batch, num_capsule, input_num_capsule, dim_capsule).
        batch_size = K.shape(inputs)[0]
        input_num_capsule = K.shape(inputs)[1]
        hat_inputs = K.reshape(hat_inputs,
                               (batch_size, input_num_capsule,
                                self.num_capsule, self.dim_capsule))
        hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))

        # Routing logits start at zero, one per (output, input) capsule pair.
        b = K.zeros_like(hat_inputs[:, :, :, 0])
        for i in range(self.routings):
            # Coupling coefficients: softmax over the output-capsule axis.
            c = K.softmax(b, 1)
            o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
            # The logits only need refreshing if another iteration follows.
            if i < self.routings - 1:
                b = K.batch_dot(o, hat_inputs, [2, 3])
                # NOTE(review): the nesting of this Theano branch was
                # reconstructed from a whitespace-collapsed source; confirm
                # against the upstream Keras example.
                if K.backend() == 'theano':
                    o = K.sum(o, axis=1)
        return o

    def compute_output_shape(self, input_shape):
        """Return `(None, num_capsule, dim_capsule)`."""
        return (None, self.num_capsule, self.dim_capsule)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        if self.activation is squash:
            activation = 'squash'
        else:
            activation = activations.serialize(self.activation)
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'share_weights': self.share_weights,
            'activation': activation,
        }
        base_config = super(Capsule, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Training hyper-parameters.
batch_size = 128
num_classes = 10
epochs = 100

# Load CIFAR10, scale the pixel values by 1/255 and one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)

# A simple Conv2D model
input_image = layers.Input(shape=(None, None, 3))
x = layers.Conv2D(64, (3, 3), activation='relu')(input_image)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.AveragePooling2D((2, 2))(x)
x = layers.Conv2D(128, (3, 3), activation='relu')(x)
x = layers.Conv2D(128, (3, 3), activation='relu')(x)

# Now, we reshape it to (batch_size, input_num_capsule, input_dim_capsule)
# and then connect a capsule layer.
# The output of the final model is the lengths of 10 capsules, each of which
# has 16 dimensions.
# The length of a capsule's output vector expresses the probability that the
# entity exists, so the task becomes 10 binary classification problems.
x = layers.Reshape((-1, 128))(x)
capsule = Capsule(num_capsule=10,
                  dim_capsule=16,
                  routings=3,
                  share_weights=True)(x)
output = layers.Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)))(capsule)
model = Model(inputs=input_image, outputs=output)

# Margin loss is used
model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])
model.summary()

# Compare the performance with and without data augmentation
data_augmentation = True

if data_augmentation:
    print('Using real-time data augmentation.')
    # This will do preprocessing and real-time data augmentation:
    datagen = ImageDataGenerator(
        # set input mean to 0 over the dataset
        featurewise_center=False,
        # set each sample mean to 0
        samplewise_center=False,
        # divide inputs by dataset std
        featurewise_std_normalization=False,
        # divide each input by its std
        samplewise_std_normalization=False,
        # apply ZCA whitening
        zca_whitening=False,
        # epsilon for ZCA whitening
        zca_epsilon=1e-06,
        # randomly rotate images in 0 to 180 degrees
        rotation_range=0,
        # randomly shift images horizontally
        width_shift_range=0.1,
        # randomly shift images vertically
        height_shift_range=0.1,
        # set range for random shear
        shear_range=0.,
        # set range for random zoom
        zoom_range=0.,
        # set range for random channel shifts
        channel_shift_range=0.,
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        # value used for fill_mode = "constant"
        cval=0.,
        # randomly flip images horizontally
        horizontal_flip=True,
        # randomly flip images vertically
        vertical_flip=False,
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation
        # (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        workers=4)
else:
    print('Not using data augmentation.')
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
        shuffle=True)