Knowledge Distillation in Deep Learning - Keras Implementation

Moklesur Rahman
2 min read · Jan 25, 2023

Knowledge distillation is a technique used in deep learning to transfer the knowledge learned by a large, complex model (called the teacher model) to a smaller, simpler model (called the student model). The idea is to use the teacher model to “teach” the student model by training the student on the teacher's outputs, rather than only on the input-label pairs used to train the teacher. This allows the student model to learn from the teacher model’s expertise, and it often performs better than it would if trained on the labels alone. Training minimizes the difference between the outputs of the two models, and a technique called “temperature scaling” is used to soften the teacher's output: dividing the teacher's logits by a temperature greater than 1 before the softmax spreads probability across the classes, so the student can see how the teacher ranks the non-target classes.
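To make the effect of temperature concrete, here is a minimal sketch in plain Python. The function name `softmax_with_temperature` and the logit values are hypothetical and chosen only for illustration; they are not part of Keras.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores (logits) into probabilities, softened by `temperature`."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for a 3-class problem (illustrative values only).
teacher_logits = [5.0, 2.0, 0.5]

sharp = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=4.0)

print("T=1:", [round(p, 3) for p in sharp])
print("T=4:", [round(p, 3) for p in soft])
```

At the higher temperature the top class still wins, but the other classes receive noticeably more probability mass. That extra information about class similarity is what the student learns from.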

Figure: Knowledge Distillation

One way to implement knowledge distillation in Keras is to train the student model to mimic the output of the teacher model, rather than the ground-truth labels. The student keeps its usual softmax output layer; the class probabilities predicted by the teacher simply take the place of the one-hot labels as the training targets.

Here is some example code for training a student model using knowledge distillation in Keras:

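import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Placeholder data so the example runs end to end (an assumption; substitute your own dataset)
input_shape, num_classes = (20,), 3
x_train, x_val = np.random.rand(500, 20), np.random.rand(100, 20)
y_train = np.eye(num_classes)[np.random.randint(num_classes, size=500)]
y_val = np.eye(num_classes)[np.random.randint(num_classes, size=100)]

# Create the teacher model
teacher_inputs = Input(shape=input_shape)
x = Dense(64, activation='relu')(teacher_inputs)
x = Dense(64, activation='relu')(x)
teacher_outputs = Dense(num_classes, activation='softmax')(x)
teacher_model = Model(teacher_inputs, teacher_outputs)
teacher_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the teacher on the ground-truth labels first
teacher_model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val))

# Create the student model
student_inputs = Input(shape=input_shape)
x = Dense(32, activation='relu')(student_inputs)
x = Dense(32, activation='relu')(x)
student_outputs = Dense(num_classes, activation='softmax')(x)
student_model = Model(student_inputs, student_outputs)
student_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Use the trained teacher's predicted probabilities as soft targets
soft_targets = teacher_model.predict(x_train)

# Train the student to match the soft targets; validate against the true labels
student_model.fit(x_train, soft_targets, epochs=10, batch_size=32, validation_data=(x_val, y_val))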

This is a simplified example; you can use different types of layers, architectures, and loss functions. It’s important to note that the teacher model should be trained first, and then used to generate the targets for the…
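As one illustration of a different loss function, a commonly used variant blends two terms: ordinary cross-entropy against the true label, and a divergence between the temperature-softened teacher and student distributions. The sketch below is plain Python for a single example, not the code from this article. The names `distillation_loss`, `alpha`, and `temperature` are assumptions introduced here, and the function works on logits (pre-softmax scores) rather than on the softmax outputs used in the Keras models above.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over logits divided by `temperature`."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, true_class,
                      temperature=4.0, alpha=0.5):
    """Weighted sum of a hard-label loss and a soft-target loss (illustrative)."""
    # Hard term: cross-entropy between the student's prediction (T=1) and the true class.
    student_probs = softmax(student_logits)
    hard_loss = -math.log(student_probs[true_class])

    # Soft term: KL divergence from the softened teacher to the softened student.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    soft_loss = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))

    # Scaling by T^2 keeps the soft term's gradients comparable across temperatures.
    return alpha * hard_loss + (1.0 - alpha) * temperature ** 2 * soft_loss

# Hypothetical logits for a 3-class problem (illustrative values only).
teacher_logits = [5.0, 2.0, 0.5]
matching = distillation_loss([5.0, 2.0, 0.5], teacher_logits, true_class=0)
differing = distillation_loss([0.5, 2.0, 5.0], teacher_logits, true_class=0)
print("student agrees with teacher:", matching)
print("student disagrees with teacher:", differing)
```

Setting `alpha` to 0 recovers pure imitation of the teacher, as in the Keras example above, while `alpha` equal to 1 ignores the teacher entirely. The right balance depends on the task and is usually tuned.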

Moklesur Rahman

PhD student | Computer Science | University of Milan | Data science | AI in Cardiology | Writer | Researcher