Siamese Neural Networks for One-shot Image Recognition

This paper from Koch et al. presents a technique for one-shot learning, i.e. learning from only one or a few examples per class. At a very high level, the idea appears similar to K-Nearest Neighbours, except that it compares features extracted by convolutional layers instead of raw pixel values. That is probably an oversimplification, though: the paper shows how to construct a model that jointly trains two identical convolutional neural networks to tell whether two images belong to the same category. Once the network is trained, it can be used on entirely new classes from an unknown distribution.
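To make the K-Nearest Neighbours analogy concrete, here is a minimal sketch of one-shot classification by 1-nearest-neighbour. The `embed` argument is a hypothetical stand-in for a trained convolutional twin; with the default identity it compares raw pixels, which is where the analogy starts. The paper's network goes further and learns a weighted L1 distance, so treat this only as an illustration of the idea.

```python
import numpy as np

def one_shot_1nn(query, support, embed=lambda x: x):
    """Return the index of the support example closest to `query`.

    support: a list with one example per candidate class (one-shot).
    embed:   hypothetical feature extractor; the identity default means
             plain pixel-space nearest neighbour.
    """
    q = np.ravel(embed(query))
    # L1 (Manhattan) distance between the query and every support example
    dists = [np.abs(q - np.ravel(embed(s))).sum() for s in support]
    return int(np.argmin(dists))

# Usage: three classes with one example each; the query is closest to class 1.
support = [np.zeros((2, 2)), np.ones((2, 2)), np.full((2, 2), 5.0)]
query = np.full((2, 2), 0.9)
print(one_shot_1nn(query, support))  # 1
```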
Figure: Proposed network architecture (each twin: 10x10 conv, 64 -> 7x7 conv, 128 -> 4x4 conv, 128 -> 4x4 conv, 256 -> fc 4096; the twin outputs are merged and followed by fc 1)
The code walkthrough covers my approach to dataset preparation, network implementation and test results.
# Comment next line to see the warnings
import warnings; warnings.simplefilter('ignore')

from glob import glob
import numpy as np
import cv2
# Pandas is only used to display data in nice tabular format
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import HTML, display

from keras.models import Sequential, Model
from keras.layers import Dense, Conv2D, Input, MaxPool2D, Flatten, merge
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
from keras.initializers import RandomNormal
from keras.optimizers import Adam
import keras.backend as K
Using TensorFlow backend.
# Turn debugger on/off
%pdb

Terms used in the paper

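alphabet: a sub-dataset of characters that belong to the same category (writing system). Each alphabet has N characters, and each character has 20 examples.

n-way one-shot learning: e.g. 20-way means a test image is compared against one example each of 20 distinct characters from an alphabet.

drawer: one of the people who drew the characters. Each of a character's 20 examples was drawn by a different drawer (e.g. 2 distinct drawers for 20 randomly selected characters of an alphabet give 2 images of each character).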

Dataset: Omniglot

We use the Omniglot dataset, which contains 1623 different handwritten characters from 50 different alphabets, to train a CNN that identifies which character is closest to an input character.

The image below shows sample images from each alphabet in the training set. The code is available here.

Dataset Preparation

We prepare the training and validation sets from the images in images_background.zip, available at https://github.com/brendenlake/omniglot/tree/master/python

dataset_dirname = '/mnt/bigdrive/datasets/omniglot/python'
alphabets = glob('{}/images_background/*'.format(dataset_dirname))
# create train and validation set
dataset={}
for alph in alphabets:
    dirs = glob('{}/*'.format(alph))
    alphabet = {}
    for dirname in dirs:
        alphabet[dirname] = glob('{}/*'.format(dirname))
    dataset[alph] = alphabet        
# Display number of distinct characters in each alphabet
d = {k.split('/')[-1]: len(v) for k,v in dataset.items()}
df = pd.DataFrame(list(d.items()), columns=['Alphabet', 'Distinct characters'])
display(HTML(df.to_html()))
Alphabet Distinct characters
0 Braille 26
1 Tifinagh 55
2 Tagalog 17
3 Alphabet_of_the_Magi 20
4 Bengali 46
5 Gujarati 48
6 Inuktitut_(Canadian_Aboriginal_Syllabics) 16
7 Armenian 41
8 Mkhedruli_(Georgian) 41
9 Hebrew 22
10 Early_Aramaic 22
11 Asomtavruli_(Georgian) 40
12 Malay_(Jawi_-_Arabic) 40
13 Ojibwe_(Canadian_Aboriginal_Syllabics) 14
14 Burmese_(Myanmar) 34
15 Latin 26
16 Korean 40
17 Cyrillic 33
18 Futurama 26
19 Greek 24
20 Arcadian 26
21 Japanese_(katakana) 47
22 Anglo-Saxon_Futhorc 29
23 N_Ko 33
24 Grantha 43
25 Japanese_(hiragana) 52
26 Balinese 24
27 Syriac_(Estrangelo) 23
28 Blackfoot_(Canadian_Aboriginal_Syllabics) 14
29 Sanskrit 42
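Summing the table is a quick sanity check on the data, and it also shows that four alphabets have fewer than 20 characters, so a 20-way trial cannot draw 20 distinct characters from them. A small sketch, with the counts copied by hand from the table:

```python
# Distinct characters per alphabet, copied from the table above (images_background)
counts = {
    'Braille': 26, 'Tifinagh': 55, 'Tagalog': 17, 'Alphabet_of_the_Magi': 20,
    'Bengali': 46, 'Gujarati': 48, 'Inuktitut_(Canadian_Aboriginal_Syllabics)': 16,
    'Armenian': 41, 'Mkhedruli_(Georgian)': 41, 'Hebrew': 22, 'Early_Aramaic': 22,
    'Asomtavruli_(Georgian)': 40, 'Malay_(Jawi_-_Arabic)': 40,
    'Ojibwe_(Canadian_Aboriginal_Syllabics)': 14, 'Burmese_(Myanmar)': 34,
    'Latin': 26, 'Korean': 40, 'Cyrillic': 33, 'Futurama': 26, 'Greek': 24,
    'Arcadian': 26, 'Japanese_(katakana)': 47, 'Anglo-Saxon_Futhorc': 29,
    'N_Ko': 33, 'Grantha': 43, 'Japanese_(hiragana)': 52, 'Balinese': 24,
    'Syriac_(Estrangelo)': 23, 'Blackfoot_(Canadian_Aboriginal_Syllabics)': 14,
    'Sanskrit': 42,
}
N_WAY = 20
total = sum(counts.values())
# alphabets that cannot supply N_WAY distinct characters for a single trial
too_small = sorted(name for name, n in counts.items() if n < N_WAY)
print(len(counts), total)  # 30 964
print(too_small)
```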

All 30 alphabets in the images_background folder are used for the training and validation sets. Each character contains 20 distinct images; we keep 12 of them in the training set and 4 in the validation set (the remaining 4 are unused).

IMAGE_SHAPE = (105, 105, 1)

# make train set
trainset = {}
for cat, alphabet_dict in dataset.items():
    alphabet = {}
    for subcat, fnames in alphabet_dict.items():
        # sample without replacement so no image appears twice
        alphabet[subcat] = np.random.choice(fnames, 12, replace=False)
    trainset[cat] = alphabet
# make valid set
validset = {}
for cat, alphabet_dict in dataset.items():
    alphabet = {}
    for subcat, fnames in alphabet_dict.items():
        train_fnames = trainset[cat][subcat]
        valid_fnames = list(set(fnames).difference(set(train_fnames)))
        alphabet[subcat] = np.random.choice(valid_fnames, 4, replace=False)
    validset[cat] = alphabet
# Validate the files in valid set do not exist in trainset
for alphabet, alphabet_dict in dataset.items():
    for char in alphabet_dict.keys():
        assert(len(set(trainset[alphabet][char]).intersection(set(validset[alphabet][char]))) == 0)
# Display train and validset setup (12 images in train and 4 in valid)
alphabet = np.random.choice(list(dataset.keys()))
df = pd.DataFrame(data = {
    'trainset': [(key.split('/')[-2:],len(value)) for key, value in trainset[alphabet].items()],
    'validset': [ (key.split('/')[-2:],len(value)) for key, value in validset[alphabet].items()]})
display(HTML(df.to_html()))
trainset validset
0 ([Alphabet_of_the_Magi, character02], 12) ([Alphabet_of_the_Magi, character02], 4)
1 ([Alphabet_of_the_Magi, character15], 12) ([Alphabet_of_the_Magi, character15], 4)
2 ([Alphabet_of_the_Magi, character19], 12) ([Alphabet_of_the_Magi, character19], 4)
3 ([Alphabet_of_the_Magi, character07], 12) ([Alphabet_of_the_Magi, character07], 4)
4 ([Alphabet_of_the_Magi, character04], 12) ([Alphabet_of_the_Magi, character04], 4)
5 ([Alphabet_of_the_Magi, character17], 12) ([Alphabet_of_the_Magi, character17], 4)
6 ([Alphabet_of_the_Magi, character08], 12) ([Alphabet_of_the_Magi, character08], 4)
7 ([Alphabet_of_the_Magi, character05], 12) ([Alphabet_of_the_Magi, character05], 4)
8 ([Alphabet_of_the_Magi, character06], 12) ([Alphabet_of_the_Magi, character06], 4)
9 ([Alphabet_of_the_Magi, character10], 12) ([Alphabet_of_the_Magi, character10], 4)
10 ([Alphabet_of_the_Magi, character01], 12) ([Alphabet_of_the_Magi, character01], 4)
11 ([Alphabet_of_the_Magi, character11], 12) ([Alphabet_of_the_Magi, character11], 4)
12 ([Alphabet_of_the_Magi, character16], 12) ([Alphabet_of_the_Magi, character16], 4)
13 ([Alphabet_of_the_Magi, character03], 12) ([Alphabet_of_the_Magi, character03], 4)
14 ([Alphabet_of_the_Magi, character09], 12) ([Alphabet_of_the_Magi, character09], 4)
15 ([Alphabet_of_the_Magi, character20], 12) ([Alphabet_of_the_Magi, character20], 4)
16 ([Alphabet_of_the_Magi, character18], 12) ([Alphabet_of_the_Magi, character18], 4)
17 ([Alphabet_of_the_Magi, character13], 12) ([Alphabet_of_the_Magi, character13], 4)
18 ([Alphabet_of_the_Magi, character12], 12) ([Alphabet_of_the_Magi, character12], 4)
19 ([Alphabet_of_the_Magi, character14], 12) ([Alphabet_of_the_Magi, character14], 4)
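Note that `np.random.choice` samples with replacement unless `replace=False` is passed, so a split can silently contain duplicate images. A minimal sketch of a per-character split that draws all 16 indices in a single call, making train and valid disjoint by construction (the helper name and filenames are hypothetical):

```python
import numpy as np

def split_character(fnames, n_train=12, n_valid=4, rng=None):
    """Split one character's image filenames into disjoint train/valid lists."""
    rng = np.random.default_rng() if rng is None else rng
    # replace=False: every index is drawn at most once
    picked = rng.choice(len(fnames), n_train + n_valid, replace=False)
    train = [fnames[i] for i in picked[:n_train]]
    valid = [fnames[i] for i in picked[n_train:]]
    return train, valid

fnames = ['img{:02d}.png'.format(i) for i in range(20)]  # hypothetical filenames
train, valid = split_character(fnames, rng=np.random.default_rng(0))
print(len(train), len(valid))  # 12 4
```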

Training

During training, the network is shown pairs of matching and non-matching images from a randomly selected alphabet.

Example:

  • matching pair: (alphabet1 -> character1 -> img001.jpg) (alphabet1 -> character1 -> img002.jpg)
  • non-matching pair: (alphabet1 -> character1 -> img001.jpg) (alphabet1 -> character3 -> img002.jpg)
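The pairing scheme can be sketched on toy, in-memory data. `sample_pairs` is a hypothetical helper with made-up image ids, not the notebook's generator. It draws two distinct characters and two distinct images per matching pair, so a non-matching pair can never accidentally show the same character.

```python
import random

def sample_pairs(alphabet, n_pairs, rng=random):
    """Return (left, right, label) triples from one alphabet.

    alphabet: dict mapping character name -> list of image ids.
    Half of the pairs match (label 1), half do not (label 0).
    """
    chars = list(alphabet)
    pairs = []
    for _ in range(n_pairs // 2):
        same, other = rng.sample(chars, 2)                 # two distinct characters
        anchor, positive = rng.sample(alphabet[same], 2)   # two distinct images
        pairs.append((anchor, positive, 1))                # matching pair
        pairs.append((anchor, rng.choice(alphabet[other]), 0))  # non-matching pair
    return pairs

# Toy alphabet: 5 characters with 12 image ids each; the prefix encodes the character.
toy = {'character{:02d}'.format(c): ['c{}_img{:02d}'.format(c, i) for i in range(12)]
       for c in range(1, 6)}
pairs = sample_pairs(toy, 10, random.Random(0))
```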
# Generate training data: batch_size pairs, half matching (target 1) and half non-matching (target 0)
def generate_oneshot_trials(dataset, batch_size=32):
    data_1 = []
    data_2 = []
    targets = []
    alphabet_names = list(dataset.keys())
    for alphabet_index in np.random.randint(0, len(dataset), batch_size//2):
        alphabet = alphabet_names[alphabet_index]
        alphabet_items = list(dataset[alphabet].items())
        # two distinct characters, so the non-matching pair really differs
        characters = np.random.choice(len(alphabet_items), 2, replace=False)

        # we kept 12 samples from each character in trainset and 4 in validset
        num_samples = len(alphabet_items[0][1])
        # two distinct sample indices, so the matching pair uses two different images
        drawers = np.random.choice(num_samples, 2, replace=False)
        # the same anchor image is the left image of both the matching and the non-matching pair
        img_list = [alphabet_items[characters[0]][1][drawers[0]]] * 2
        data_1.extend(img_list)

        # matching pair: same character, different image
        data_2.append(alphabet_items[characters[0]][1][drawers[1]])
        targets.append(1)
        # non-matching pair: a different character
        data_2.append(alphabet_items[characters[1]][1][drawers[1]])
        targets.append(0)

    return [data_1, data_2], np.array(targets)
x_train, y_train = generate_oneshot_trials(trainset, batch_size=10)
# display train data
fig, axes = plt.subplots(10, 3, figsize=(8,8))
for i in range(10):
    axes[i, 0].imshow(cv2.imread(x_train[0][i], cv2.IMREAD_GRAYSCALE), cmap='gray')
    axes[i, 0].axis('off')
    axes[i, 1].imshow(cv2.imread(x_train[1][i], cv2.IMREAD_GRAYSCALE), cmap='gray')
    axes[i, 1].axis('off')
    axes[i, 2].text(0.5, 0.5, 'same' if y_train[i] == 1 else 'different')
    axes[i, 2].axis('off')
    
plt.show()


Validation

During validation, n-way (e.g. 20-way) trials are generated by drawing n distinct characters from a randomly selected alphabet.

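E.g. one 20-way trial with 2 drawers:

  1. choose a random alphabet
  2. select 20 distinct characters from the alphabet
  3. pick 2 distinct drawers: for each character, the first drawer's image is a test image and the second drawer's image is a support image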
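Each test image is scored against every support image, and it counts as correct when its own character wins. Assuming those scores are collected into an n_way x n_way matrix (a hypothetical layout; `validate_model` computes one row per `model.predict` call), the trial accuracy reduces to an argmax per row:

```python
import numpy as np

def n_way_accuracy(scores):
    """Accuracy of one n-way trial.

    scores[i, j] is the model's similarity between test image i and support
    image j, where support image j shows the same character as test image j
    (so the correct answer for row i is column i).
    """
    scores = np.asarray(scores)
    predicted = scores.argmax(axis=1)  # best-matching support image per test image
    return float(np.mean(predicted == np.arange(len(scores))))

scores = [[0.9, 0.1, 0.2],
          [0.3, 0.2, 0.8],   # test image 1 is matched to the wrong support image
          [0.1, 0.05, 0.7]]
print(n_way_accuracy(scores))  # 2 of 3 correct
```

Random guessing in a 20-way trial is correct 1/20 = 5% of the time on average, which is the baseline to read the validation accuracies against.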
# Generate a validation trial: n_way test images and n_way support images of n_way distinct characters
def generate_oneshot_validation_trials(dataset, n_way=20):
    # only alphabets with at least n_way characters can supply n_way distinct characters
    candidates = [name for name, chars in dataset.items() if len(chars) >= n_way]
    alphabet = candidates[np.random.randint(0, len(candidates))]
    alphabet_items = list(dataset[alphabet].items())

    # n_way distinct characters (sampling without replacement)
    char_indices = np.random.choice(len(alphabet_items), n_way, replace=False)
    # we kept 12 samples from each character in trainset and 4 in validset
    num_samples = len(alphabet_items[0][1])
    # two distinct sample indices ("drawers"): one for the test images, one for the support images
    drawers = np.random.choice(num_samples, 2, replace=False)

    test_images = []
    train_images = []
    for i in range(n_way):
        test_images.append(alphabet_items[char_indices[i]][1][drawers[0]])
        train_images.append(alphabet_items[char_indices[i]][1][drawers[1]])

    return (test_images, train_images)
# load images from disk
def get_images(fnames, shape):
    batch = np.empty((len(fnames), *shape), dtype=np.uint8)
    for i, fname in enumerate(fnames):
        batch[i,...] = np.expand_dims(cv2.imread(fname, cv2.IMREAD_GRAYSCALE), axis=2)
    
    return batch

The training and validation sets contain only filenames. The following code loads the images from disk into a numpy array of shape (batch_size, height, width, channels).

# Load images from the list of paired image filenames
def get_paired_images(fnames, shape):
    batch_left = np.empty((len(fnames[0]), *shape), dtype=np.uint8)
    batch_right = np.empty((len(fnames[1]), *shape), dtype=np.uint8)
    for i, fname in enumerate(zip(fnames[0], fnames[1])):
        batch_left[i,...] = np.expand_dims(cv2.imread(fname[0], cv2.IMREAD_GRAYSCALE), axis=2)
        batch_right[i,...] = np.expand_dims(cv2.imread(fname[1], cv2.IMREAD_GRAYSCALE), axis=2)
    
    return [batch_left, batch_right]
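Because `cv2.imread` with `cv2.IMREAD_GRAYSCALE` returns a 2-D (height, width) array, the loaders add a trailing channel axis so the batch matches the (batch_size, height, width, channels) layout the Conv2D input expects. A numpy-only sketch of those shapes, using blank stand-in images instead of files, and also showing the `np.tile` trick `validate_model` uses to pair one test image with every support image:

```python
import numpy as np

IMAGE_SHAPE = (105, 105, 1)
# Stand-ins for cv2.imread(fname, cv2.IMREAD_GRAYSCALE): 2-D uint8 arrays, no files read.
grays = [np.full(IMAGE_SHAPE[:2], 255, dtype=np.uint8) for _ in range(3)]
# Add the channel axis and stack into (batch, height, width, channels).
batch = np.stack([np.expand_dims(g, axis=2) for g in grays])
# Repeat one test image 20 times so it can be compared with 20 support images in one call.
tiled = np.tile(batch[0], (20, 1, 1, 1))
print(batch.shape, tiled.shape)  # (3, 105, 105, 1) (20, 105, 105, 1)
```

Note that the pixel values stay in the 0-255 uint8 range; the loaders above do not rescale them.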

The following code validates the model on unseen data and visualises its predictions. It is used during both training and evaluation.

# Validate model by testing it against unseen dataset 
def validate_model(model, dataset=validset, num_iters=1):
    # run num_iters independent 20-way trials and average their accuracies
    total_corrects = 0
    for n in range(num_iters):
        (test_image_names, train_image_names) = generate_oneshot_validation_trials(dataset, 20)
        test_images = get_images(test_image_names, IMAGE_SHAPE)
        train_images = get_images(train_image_names, IMAGE_SHAPE)    
        
        corrects = 0
        for i in range(len(test_images)):
            images = np.tile(test_images[i], (len(train_images), 1, 1, 1))
            preds = model.predict([images, train_images])
            pred_idx = np.argmax(preds, axis=0)[0]
            # check they belong to same character
            test_char_name = test_image_names[i].split('/')[-2]
            pred_char_name = train_image_names[pred_idx].split('/')[-2]
            corrects += 1 if test_char_name == pred_char_name else 0
        
        total_corrects += corrects/len(train_image_names)
    
    return total_corrects/num_iters
# Display test character along with n-way characters, predicted as well as correct labels
def display_validation_test(model, dataset=validset):
    (test_image_names, train_image_names) = generate_oneshot_validation_trials(dataset, 20)
    test_images = get_images(test_image_names, IMAGE_SHAPE)
    train_images = get_images(train_image_names, IMAGE_SHAPE)
    
    idx = np.random.randint(0, len(test_images))
    
    images = np.tile(test_images[idx], (len(train_images), 1, 1, 1))
    preds = model.predict([images, train_images])
    pred_idx = np.argmax(preds, axis=0)[0]
    print('Alphabet: {}'.format(test_image_names[idx].split('/')[-3]))
    
    # check they belong to same character
    test_char_name = test_image_names[idx].split('/')[-2]
    pred_char_name = train_image_names[pred_idx].split('/')[-2]
    correct = 1 if test_char_name == pred_char_name else 0

    # reshape to (w,h) and display
    plt.imshow(images[idx].reshape((IMAGE_SHAPE[0], IMAGE_SHAPE[1])), cmap='gray')
    plt.axis('off')
    plt.annotate(test_image_names[idx].split('/')[-2], (0.0, 0.0), textcoords='axes fraction', size=10)
    plt.show()

    print('*'*100)
    print('Which of the following match above?')
    fig, axes = plt.subplots(4, 5, figsize=(10,10))
    ctr = 0
    print('Predicted index: {}, name: {}, actual: {}'.format(pred_idx, pred_char_name, test_char_name))
    
    for i in range(4):
        for j in range(5):
            class_name = train_image_names[ctr].split('/')[-2]
            axes[i, j].imshow(train_images[ctr].reshape((IMAGE_SHAPE[0], IMAGE_SHAPE[1])), cmap='gray')
            axes[i, j].annotate(class_name, (0.0, 0.0), textcoords='axes fraction', size=10)
            axes[i, j].axis('off')                         
            ctr += 1
    
    # Highlight prediction and correct label
    col = pred_idx%5
    row = pred_idx//5
    if correct:
        axes[row, col].text(0, 5, 'Success', style='italic', 
                        bbox={'facecolor':'deepskyblue', 'alpha':0.5, 'pad':10})
    else:
        axes[row, col].text(0, 5, 'Predicted', style='italic', 
                        bbox={'facecolor':'red', 'alpha':0.5, 'pad':10})        
        correct_indices = [idx for idx, fname in enumerate(train_image_names) if fname.split('/')[-2] == test_char_name]
        for idx in correct_indices:
            c = idx%5
            r = idx//5
            axes[r, c].text(0, 5, 'Ground truth', style='italic', 
                        bbox={'facecolor':'aquamarine', 'alpha':0.5, 'pad':10})

    plt.show()    

Siamese Network

The image from the paper shows the network architecture of one of the twins. Both twins use the same architecture and share their weights, hence "siamese" :)
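`build_network` joins the twins with a small head: the component-wise absolute difference of the two 4096-dimensional feature vectors feeds a single sigmoid unit. A numpy sketch of that computation, with random stand-in features and arbitrary, untrained head parameters `w` and `b`:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def siamese_score(h1, h2, w, b):
    """Match probability for two twin feature vectors h1, h2.

    Mirrors merge(mode=l1_distance) followed by Dense(1, activation='sigmoid'):
    component-wise |h1 - h2|, then a weighted sum plus bias through a sigmoid.
    """
    return sigmoid(np.dot(w, np.abs(h1 - h2)) + b)

rng = np.random.default_rng(0)
h1, h2 = rng.random(4096), rng.random(4096)   # stand-ins for the two twin outputs
w, b = rng.normal(0.0, 0.01, 4096), 0.0       # arbitrary head parameters
p = siamese_score(h1, h2, w, b)
```

Because |h1 - h2| is symmetric, the score does not depend on which image goes into which twin. Also note that the functional `merge(..., mode=...)` used in `build_network` is the legacy Keras 1 API; it was removed in Keras 2.2, where a `Lambda` layer computing `K.abs(x[0] - x[1])` is the usual replacement.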

Siamese Network 1

We start with the network architecture described in the paper, but without its weight and bias initialization and L2 regularization. This helps us understand their impact on model performance.

# Convolutional Neural Network
def conv_model(input_shape):
    input = Input(shape=input_shape)
    x = Conv2D(64, (10, 10), activation='relu')(input)
    x = MaxPool2D(2,2)(x)    
    x = Conv2D(128, (7, 7), activation='relu')(x)        
    x = MaxPool2D((2,2))(x)                  
    x = Conv2D(128, (4, 4), activation='relu')(x)
    x = MaxPool2D((2,2))(x)                  
    x = Conv2D(256, (4, 4), activation='relu')(x)                  
    x = Flatten()(x)
    x = Dense(4096, activation='sigmoid')(x)
    
    return Model(input, x)
# Siamese Network
def build_network(conv_model):
    # Build two networks
    input_shape = (105, 105, 1)
    input1 = Input(input_shape)
    input2 = Input(input_shape)
    
    model = conv_model(input_shape)

    model_output_left = model(input1)
    model_output_right = model(input2)

    def l1_distance(x): 
        return K.abs(x[0] - x[1])

    def l1_distance_shape(x): 
        print(x)
        return x[0]

    # Merge is for layers, merge is for tensors

    # merged_model = Merge([model1, model2], mode=l1_distance, output_shape=lambda x: x[0])
    merged_model = merge([model_output_left, model_output_right], mode=l1_distance, output_shape=l1_distance_shape)
    output = Dense(1, activation='sigmoid')(merged_model)
    siamese_model = Model([input1, input2], output)
    return siamese_model
# Build and compile
siamese_model1 = build_network(conv_model)
siamese_model1.compile(loss='binary_crossentropy', optimizer=Adam(0.00006), metrics=['acc'])
siamese_model1.summary()
[(None, 4096), (None, 4096)]
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_10 (InputLayer)           (None, 105, 105, 1)  0                                            
__________________________________________________________________________________________________
input_11 (InputLayer)           (None, 105, 105, 1)  0                                            
__________________________________________________________________________________________________
model_7 (Model)                 (None, 4096)         38947648    input_10[0][0]                   
                                                                 input_11[0][0]                   
__________________________________________________________________________________________________
merge_4 (Merge)                 (None, 4096)         0           model_7[1][0]                    
                                                                 model_7[2][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 1)            4097        merge_4[0][0]                    
==================================================================================================
Total params: 38,951,745
Trainable params: 38,951,745
Non-trainable params: 0
__________________________________________________________________________________________________
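The 38,947,648 parameters reported for the shared twin can be reproduced by hand, which is a useful check that the layer sizes match the paper's figure. A sketch assuming Keras' Conv2D defaults (stride 1, 'valid' padding) and MaxPool2D(2, 2) after the first three convolutions, as in `conv_model`:

```python
def conv_out(size, kernel):       # 'valid' padding, stride 1
    return size - kernel + 1

def conv_params(k, c_in, c_out):  # kernel weights + one bias per filter
    return k * k * c_in * c_out + c_out

size, channels, total = 105, 1, 0
for i, (k, filters) in enumerate([(10, 64), (7, 128), (4, 128), (4, 256)]):
    total += conv_params(k, channels, filters)
    size, channels = conv_out(size, k), filters
    if i < 3:                     # MaxPool2D(2, 2) halves the spatial size
        size //= 2
flat = size * size * channels      # Flatten
total += flat * 4096 + 4096        # Dense(4096)
print(size, flat, total)           # 6 9216 38947648
print(total + 4096 + 1)            # plus the Dense(1) head: 38951745
```

Almost all of the parameters (about 37.7 million) sit in the 9216 x 4096 dense layer.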
siamese_model1.metrics_names
['loss', 'acc']

Train siamese network 1

%time 
image_shape = (105, 105, 1)
# hyperparams
# note: 'epochs' counts training iterations (one batch each), not full passes over the data
epochs = 20000
batch_size = 128
print_every = int(.1 * epochs)
validate_every = int(.1 * epochs)
best_val = 0.0

for i in range(epochs):
    (fnames, targets) = generate_oneshot_trials(trainset, batch_size)
    batches = get_paired_images(fnames, image_shape)  
    hist = siamese_model1.train_on_batch(batches, targets)
    if i % print_every == 0:
        print('{} train loss:{}, accuracy:{}'.format(i, *hist))
        
    if i % validate_every == 0:
        acc = validate_model(siamese_model1, num_iters=30)
        print('Validation acc: {}'.format(acc))
        if acc > best_val:
            weights_fname = 'weights/siamese_n1.hdf5'
            siamese_model1.save_weights(weights_fname)
            print('Save model: {}'.format(weights_fname))
            best_val = acc
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 14.3 µs
0 train loss:0.7143945693969727, accuracy:0.4296875
Validation acc: 0.4783333333333333
Save model: weights/siamese_n1.hdf5
2000 train loss:0.3331770598888397, accuracy:0.8828125
Validation acc: 0.6433333333333334
Save model: weights/siamese_n1.hdf5
4000 train loss:0.2253703624010086, accuracy:0.953125
Validation acc: 0.7433333333333333
Save model: weights/siamese_n1.hdf5
6000 train loss:0.09980868548154831, accuracy:0.9765625
Validation acc: 0.7033333333333333
8000 train loss:0.10853322595357895, accuracy:0.984375
Validation acc: 0.7783333333333332
Save model: weights/siamese_n1.hdf5
10000 train loss:0.13877654075622559, accuracy:0.9609375
Validation acc: 0.8400000000000001
Save model: weights/siamese_n1.hdf5
12000 train loss:0.18388020992279053, accuracy:0.96875
Validation acc: 0.8216666666666667
14000 train loss:0.1518031358718872, accuracy:0.9609375
Validation acc: 0.8333333333333333
16000 train loss:0.16200870275497437, accuracy:0.9609375
Validation acc: 0.7999999999999999
18000 train loss:0.11673762649297714, accuracy:0.984375
Validation acc: 0.8266666666666665

Evaluate using validation set

siamese_model1.load_weights('weights/siamese_n1.hdf5')
display_validation_test(siamese_model1)
Alphabet: Malay_(Jawi_-_Arabic)


****************************************************************************************************
Which of the following match above?
Predicted index: 4, name: character38, actual: character38


val_acc = validate_model(siamese_model1, num_iters=30)
print('Validation accuracy is: {:.2f}'.format(val_acc))
Validation accuracy is: 0.81

Siamese Network 2

The architecture stays the same, but this time we add the weight and bias initialization and L2 regularization described in the paper. Note: I did not find a recommended regularization strength in the paper; the values used here are based on the blog referenced below.

# def weight_initializer(shape, name=None):
#     w = np.random.normal(loc=0, scale=1e-2, size=shape)
#     return K.variable(w, name=name)

# def bias_initializer(shape, name=None):
#     b = np.random.normal(loc=0.5, scale=1e-2, size=shape)
#     return K.variable(b, name=name)
#
# def conv_model_with_initialisers(input_shape):
#     input = Input(shape=input_shape)
#     x = Conv2D(64, (10, 10), activation='relu', kernel_initializer=weight_initializer,
#                kernel_regularizer=l2(2e-4))(input)
#     x = MaxPool2D(2,2)(x)    
#     x = Conv2D(128, (7, 7), activation='relu', kernel_initializer=weight_initializer, 
#                kernel_regularizer=l2(2e-4))(x)        
#     x = MaxPool2D((2,2))(x)                  
#     x = Conv2D(128, (4, 4), activation='relu', kernel_initializer=weight_initializer, 
#                kernel_regularizer=l2(2e-4), bias_initializer=bias_initializer)(x)                  
#     x = MaxPool2D((2,2))(x)                  
#     x = Conv2D(256, (4, 4), activation='relu', kernel_initializer=weight_initializer, 
#                kernel_regularizer=l2(2e-4), bias_initializer=bias_initializer)(x)                  
#     x = Flatten()(x)
#     x = Dense(4096, activation='sigmoid', kernel_initializer=weight_initializer, 
#               kernel_regularizer=l2(2e-3), bias_initializer=bias_initializer)(x)
    
#     return Model(input, x)
# create conv model with l2 regularization, weights and bias initialization as specified in the paper.
def conv_model_with_initialisers(input_shape):
    input = Input(shape=input_shape)
    x = Conv2D(64, (10, 10), activation='relu', 
               kernel_initializer=RandomNormal(0.0, 1e-2),
               bias_initializer=RandomNormal(0.5, 1e-2),
               kernel_regularizer=l2(2e-4))(input)
    x = MaxPool2D(2,2)(x)    
    x = Conv2D(128, (7, 7), activation='relu', 
               kernel_initializer=RandomNormal(0.0, 1e-2), 
               bias_initializer=RandomNormal(0.5, 1e-2),               
               kernel_regularizer=l2(2e-4))(x)        
    x = MaxPool2D((2,2))(x)                  
    x = Conv2D(128, (4, 4), activation='relu', 
               kernel_initializer=RandomNormal(0.0, 1e-2),
               bias_initializer=RandomNormal(0.5, 1e-2),
               kernel_regularizer=l2(2e-4))(x)                  
    x = MaxPool2D((2,2))(x)                  
    x = Conv2D(256, (4, 4), activation='relu', 
               kernel_initializer=RandomNormal(0.0, 1e-2),
               bias_initializer=RandomNormal(0.5, 1e-2), 
               kernel_regularizer=l2(2e-4))(x)                  
    x = Flatten()(x)
    x = Dense(4096, activation='sigmoid', 
              kernel_initializer=RandomNormal(0.0, 0.2),
              bias_initializer=RandomNormal(0.5, 1e-2),
              kernel_regularizer=l2(2e-3))(x)
    
    return Model(input, x)
# Build and compile
siamese_model2 = build_network(conv_model_with_initialisers)
siamese_model2.compile(loss='binary_crossentropy', optimizer=Adam(0.00006), metrics=['acc'])
siamese_model2.summary()
[(None, 4096), (None, 4096)]
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            (None, 105, 105, 1)  0                                            
__________________________________________________________________________________________________
input_8 (InputLayer)            (None, 105, 105, 1)  0                                            
__________________________________________________________________________________________________
model_5 (Model)                 (None, 4096)         38947648    input_7[0][0]                    
                                                                 input_8[0][0]                    
__________________________________________________________________________________________________
merge_3 (Merge)                 (None, 4096)         0           model_5[1][0]                    
                                                                 model_5[2][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 1)            4097        merge_3[0][0]                    
==================================================================================================
Total params: 38,951,745
Trainable params: 38,951,745
Non-trainable params: 0
__________________________________________________________________________________________________

Train second network

#2
%time 
image_shape = (105, 105, 1)
# hyperparams
# note: 'epochs' counts training iterations (one batch each), not full passes over the data
epochs = 20000
batch_size = 128
print_every = int(.1 * epochs)
validate_every = int(.1 * epochs)
best_val = 0.0

for i in range(epochs):
    (fnames, targets) = generate_oneshot_trials(trainset, batch_size)
    batches = get_paired_images(fnames, image_shape)  
    hist = siamese_model2.train_on_batch(batches, targets)
    if i % print_every == 0:
        print('{} train loss:{}, accuracy:{}'.format(i, *hist))
        
    if i % validate_every == 0:
        acc = validate_model(siamese_model2, num_iters=30)
        print('Validation acc: {}'.format(acc))
        if acc > best_val:
            weights_fname = 'weights/siamese_n2.hdf5'
            siamese_model2.save_weights(weights_fname)
            print('Save model: {}'.format(weights_fname))
            best_val = acc
CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 13.4 µs
0 train loss:3021.04736328125, accuracy:0.4765625
Validation acc: 0.275
Save model: weights/siamese_n2.hdf5
2000 train loss:1180.6741943359375, accuracy:0.796875
Validation acc: 0.28833333333333333
Save model: weights/siamese_n2.hdf5
4000 train loss:428.400634765625, accuracy:0.9140625
Validation acc: 0.5800000000000002
Save model: weights/siamese_n2.hdf5
6000 train loss:134.56378173828125, accuracy:0.921875
Validation acc: 0.655
Save model: weights/siamese_n2.hdf5
8000 train loss:35.539493560791016, accuracy:0.9375
Validation acc: 0.7449999999999999
Save model: weights/siamese_n2.hdf5
10000 train loss:7.924860000610352, accuracy:0.9453125
Validation acc: 0.7366666666666666
12000 train loss:1.6623940467834473, accuracy:0.9453125
Validation acc: 0.8083333333333332
Save model: weights/siamese_n2.hdf5
14000 train loss:0.5135393142700195, accuracy:0.953125
Validation acc: 0.825
Save model: weights/siamese_n2.hdf5
16000 train loss:0.23554538190364838, accuracy:0.9765625
Validation acc: 0.8616666666666666
Save model: weights/siamese_n2.hdf5
18000 train loss:0.3464904725551605, accuracy:0.921875
Validation acc: 0.8516666666666665
siamese_model2.load_weights('weights/siamese_n2.hdf5')
display_validation_test(siamese_model2)
Alphabet: Early_Aramaic


****************************************************************************************************
Which of the following match above?
Predicted index: 1, name: character17, actual: character17


val_acc = validate_model(siamese_model2, num_iters=30)
print('Validation accuracy is: {:.2f}'.format(val_acc))
Validation accuracy is: 0.87

Test

We test the models with the best validation accuracy on the test set, built from the images_evaluation folder. It contains alphabets and characters the networks have never seen.

test_alphabets = glob('{}/images_evaluation/*'.format(dataset_dirname))
# create the test set from the evaluation alphabets
testset={}
for alph in test_alphabets:
    dirs = glob('{}/*'.format(alph))
    alphabet = {}
    for dirname in dirs:
        alphabet[dirname] = glob('{}/*'.format(dirname))
    testset[alph] = alphabet        
# Let's give it a spin and see how the model performs on unseen alphabets
display_validation_test(siamese_model2, testset)
Alphabet: Syriac_(Serto)


****************************************************************************************************
Which of the following match above?
Predicted index: 5, name: character20, actual: character20


val_acc = validate_model(siamese_model2, dataset=testset, num_iters=30)
print('Network 2 accuracy on testset is: {:.2f}'.format(val_acc))
Network 2 accuracy on testset is: 0.67
val_acc = validate_model(siamese_model1, dataset=testset, num_iters=30)
print('Network 1 accuracy on testset is: {:.2f}'.format(val_acc))
Network 1 accuracy on testset is: 0.65
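The two test-set accuracies (0.67 for network 2, 0.65 for network 1) are close. Each comes from 30 trials x 20 test images = 600 predictions, so a two-proportion z statistic gives a rough sense of whether the gap is meaningful. This is a sketch under a strong assumption: it treats all 600 predictions as independent, although images within one trial share an alphabet, so the real uncertainty is larger than this suggests.

```python
import math

def two_proportion_z(p1, p2, n1, n2):
    """z statistic for the difference of two independent accuracies (normal approximation)."""
    se = math.sqrt(p1 * (1 - p1) / n1 + p2 * (1 - p2) / n2)
    return (p1 - p2) / se

# rounded accuracies reported above, 600 predictions per model
z = two_proportion_z(0.67, 0.65, 600, 600)
print(round(z, 2))  # 0.73
```

A |z| well below 1.96 means the difference is not significant at the usual 5% level, even under this optimistic independence assumption.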

Conclusion

The initialization and regularization from the paper do not seem to make a big difference to model performance (0.67 vs 0.65 accuracy on the test set). The next step will be to add data augmentation to help the model generalize better.
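The paper augments its training pairs with random affine distortions. As a much smaller first step, here is a numpy sketch of random translation only, not the paper's full affine scheme. The function names are hypothetical, and the white fill value assumes Omniglot's dark strokes on a white background, as loaded by `cv2.imread`.

```python
import numpy as np

def shift_image(image, dy, dx, fill=255):
    """Translate an (H, W) or (H, W, C) image by (dy, dx) pixels.

    Pixels uncovered by the shift are set to `fill` (255 assumes a white background).
    """
    m = max(abs(dy), abs(dx))
    pad = [(m, m), (m, m)] + [(0, 0)] * (image.ndim - 2)
    padded = np.pad(image, pad, mode='constant', constant_values=fill)
    h, w = image.shape[:2]
    top, left = m - dy, m - dx          # output[y, x] = image[y - dy, x - dx]
    return padded[top:top + h, left:left + w]

def random_shift(image, max_shift=5, fill=255, rng=None):
    """Random translation of up to max_shift pixels along each axis."""
    rng = np.random.default_rng() if rng is None else rng
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    return shift_image(image, int(dy), int(dx), fill)

img = np.full((105, 105, 1), 255, dtype=np.uint8)
img[50, 50] = 0                         # a single dark "stroke" pixel
moved = shift_image(img, 3, -2)         # 3 pixels down, 2 pixels left
```

Keras' `ImageDataGenerator`, already imported at the top of the notebook, covers rotation, shifts, shear and zoom through its `rotation_range`, `width_shift_range`, `height_shift_range`, `shear_range` and `zoom_range` arguments, which is closer to the paper's affine distortions.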

Ideas to try:

  1. Instead of learning from scratch, we could use transfer learning (i.e. a network pre-trained on ImageNet), which should let the network train faster.

  2. Even with the above architecture, we could use batch normalization and cyclical learning rates to train the network faster.
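Cyclical learning rates here refer to the triangular schedule from Leslie Smith's 2017 paper "Cyclical Learning Rates for Training Neural Networks". A minimal sketch; the `base_lr`, `max_lr` and `step_size` values are hypothetical placeholders (chosen to bracket the notebook's Adam rate of 6e-5), not tuned settings:

```python
import math

def triangular_clr(iteration, base_lr=1e-5, max_lr=1e-4, step_size=2000):
    """Triangular cyclical learning rate: rises linearly from base_lr to max_lr
    over step_size iterations, falls back over the next step_size, and repeats."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

print(triangular_clr(0), triangular_clr(2000), triangular_clr(4000))
```

In the Keras 2.x training loops above, one way to apply it would be `K.set_value(siamese_model1.optimizer.lr, triangular_clr(i))` before each `train_on_batch` call.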

References:

I found the following resources useful during my implementation:

  1. Siamese Neural Networks for One-Shot Image Recognition (the paper)
  2. Gregory Koch's thesis, which helped me understand how the training and validation sets were created
  3. A blog post that implements the paper and explains the approach in detail

originally published 20 Mar 2018 and updated 30 Mar 2018