Transfer learning is a technique in which a model developed for one task is reused as the starting point for a model on a second task. It is especially popular in computer vision, where pre-trained models such as ResNet50, trained on the large ImageNet dataset, can serve as ready-made feature extractors.
The kerasnip package makes it easy to incorporate these pre-trained Keras Applications directly into a tidymodels workflow. This vignette will demonstrate how to:
1. Define a kerasnip model that uses a pre-trained ResNet50 as a frozen base layer.
2. Train and evaluate that model within a tidymodels workflow.

First, we load the necessary packages.
library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
#> ✔ broom 1.0.8 ✔ recipes 1.3.0
#> ✔ dials 1.4.0 ✔ rsample 1.3.0
#> ✔ dplyr 1.1.4 ✔ tibble 3.2.1
#> ✔ ggplot2 3.5.2 ✔ tidyr 1.3.1
#> ✔ infer 1.0.8 ✔ tune 1.3.0
#> ✔ modeldata 1.4.0 ✔ workflows 1.2.0
#> ✔ parsnip 1.3.1 ✔ workflowsets 1.1.0
#> ✔ purrr 1.0.4 ✔ yardstick 1.3.2
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
library(keras3)
#>
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#>
#> get_weights
We'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing). keras3 provides a convenient function, dataset_cifar10(), to download it.
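The ResNet50 model was pre-trained on ImageNet, whose 1,000 classes differ from those in CIFAR-10. Our goal is to reuse its learned features and train a new classification head for the 10 CIFAR-10 classes. Because the ResNet50 weights stay frozen, this is feature extraction rather than fine-tuning in the strict sense, which would also update some of the base weights.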
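CIFAR-10 labels arrive as integer codes 0 to 9, so the factors built in the next step have levels "0" through "9". If you prefer readable labels, the following optional sketch (not part of the workflow in this vignette) maps the codes to the class names in the dataset's standard order. The example codes are illustrative only.

```r
# Class names in CIFAR-10 label order: code 0 = airplane, ..., code 9 = truck
cifar10_classes <- c("airplane", "automobile", "bird", "cat", "deer",
                     "dog", "frog", "horse", "ship", "truck")

# A few illustrative integer codes, in the form they take in y_train[, 1]
codes <- c(6L, 9L, 9L, 4L, 1L)

# levels = 0:9 fixes the code-to-name mapping even if some classes are absent
labelled <- factor(codes, levels = 0:9, labels = cifar10_classes)
print(as.character(labelled))
```

On the real data, the equivalent call would be factor(y_train[, 1], levels = 0:9, labels = cifar10_classes).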
# Load CIFAR-10 dataset
cifar10 <- dataset_cifar10()
# Separate training and test data
x_train <- cifar10$train$x
y_train <- cifar10$train$y
x_test <- cifar10$test$x
y_test <- cifar10$test$y
# Rescale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255
# Convert outcomes to factors for tidymodels
y_train_factor <- factor(y_train[, 1])
y_test_factor <- factor(y_test[, 1])
# For tidymodels, it's best to work with data frames.
# We'll use a list-column to hold the image arrays.
train_df <- tibble::tibble(
x = lapply(seq_len(nrow(x_train)), function(i) x_train[i, , , , drop = TRUE]),
y = y_train_factor
)
test_df <- tibble::tibble(
x = lapply(seq_len(nrow(x_test)), function(i) x_test[i, , , , drop = TRUE]),
y = y_test_factor
)
# Use a smaller subset for faster vignette execution
train_df_small <- train_df[1:500, ]
test_df_small <- test_df[1:100, ]
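You can check the list-column layout without downloading CIFAR-10. This sketch builds a tiny synthetic stand-in: four made-up "images" with the same 32x32x3 shape and 0 to 255 range, used purely for illustration. It then applies the same rescaling and per-row slicing as the code above.

```r
# Synthetic stand-in for CIFAR-10 (random pixels, illustration only):
# 4 images, 32 x 32 pixels, 3 colour channels, integer values 0 to 255
set.seed(42)
x_fake <- array(sample(0:255, 4 * 32 * 32 * 3, replace = TRUE),
                dim = c(4, 32, 32, 3))
# Labels as an n x 1 matrix, matching the y_train[, 1] indexing used above
y_fake <- matrix(c(3L, 8L, 0L, 3L), ncol = 1)

# Same preprocessing as above: rescale pixels, convert labels to a factor
x_fake <- x_fake / 255
fake_df <- tibble::tibble(
  # One 32 x 32 x 3 array per row; drop = TRUE removes the image-index dimension
  x = lapply(seq_len(dim(x_fake)[1]), function(i) x_fake[i, , , , drop = TRUE]),
  y = factor(y_fake[, 1])
)

dim(fake_df$x[[1]])   # 32 32 3
```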
A common way to set up transfer learning is the Keras Functional API. We will define a model where:
1. The base is a pre-trained ResNet50 with its final classification layer removed (include_top = FALSE).
2. The weights of the base are frozen (trainable = FALSE), so only our new layers are trained.
3. A new classification "head" is added, consisting of a flatten layer and a dense output layer.
# Input block: shape is determined automatically from the data
input_block <- function(input_shape) {
layer_input(shape = input_shape)
}
# ResNet50 base block
resnet_base_block <- function(tensor) {
# The base model is not trainable; we use it for feature extraction.
resnet_base <- application_resnet50(
weights = "imagenet",
include_top = FALSE
)
resnet_base$trainable <- FALSE
resnet_base(tensor)
}
# New classification head
flatten_block <- function(tensor) {
tensor |> layer_flatten()
}
output_block_functional <- function(tensor, num_classes) {
tensor |> layer_dense(units = num_classes, activation = "softmax")
}
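Only the head is trained, and it is small. This back-of-the-envelope calculation assumes what is standard for this architecture: ResNet50 without its top reduces each spatial dimension by a total factor of 32 and outputs 2048 channels. Under that assumption, a 32x32 CIFAR-10 image becomes a 1x1x2048 feature map. (32x32 is also the smallest input size the Keras ResNet50 application accepts with include_top = FALSE.)

```r
# Size of the trainable head for 32 x 32 CIFAR-10 inputs.
# Assumptions: total spatial downsampling of 32 and 2048 output channels
# for ResNet50 with include_top = FALSE.
input_side    <- 32
total_stride  <- 32
base_channels <- 2048
num_classes   <- 10

feature_side <- input_side / total_stride               # 1 x 1 feature map
flat_units   <- feature_side^2 * base_channels          # units after layer_flatten()
head_params  <- flat_units * num_classes + num_classes  # dense weights + biases

cat("Flattened features:", flat_units, "\n")
cat("Trainable parameters in the head:", head_params, "\n")
```

That gives about twenty thousand trainable parameters, compared with more than 20 million frozen weights in the ResNet50 base. This is why training the head alone is cheap.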
To build the kerasnip specification, we connect these blocks using create_keras_functional_spec().
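That call could look like the following minimal sketch. It assumes the four block functions defined above are in scope. It also assumes kerasnip's inp_spec() helper which, as we understand it, rewires a block's first argument (here tensor) to the output of the named upstream block. The block names (input, resnet_base, flatten, output) are hypothetical choices. model_name = "resnet_transfer" matches the function used in the next step. We assume kerasnip supplies input_shape and num_classes from the data, as the input block's comment indicates for the shape.

```r
library(kerasnip)
library(keras3)

# Assumes input_block(), resnet_base_block(), flatten_block() and
# output_block_functional() from the code above are already defined.
create_keras_functional_spec(
  model_name = "resnet_transfer",
  layer_blocks = list(
    # Block names define the graph; inp_spec() (assumed behaviour) maps each
    # block's `tensor` argument to the output of the named upstream block.
    input       = input_block,
    resnet_base = inp_spec(resnet_base_block, "input"),
    flatten     = inp_spec(flatten_block, "resnet_base"),
    output      = inp_spec(output_block_functional, "flatten")
  ),
  mode = "classification"
)
```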
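Now we can use our new resnet_transfer() specification within a tidymodels workflow. Arguments prefixed with fit_ are passed on to Keras's fit(). Here they request 5 training epochs, with 20% of the training rows held out for validation.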
spec_functional <- resnet_transfer(
fit_epochs = 5,
fit_validation_split = 0.2
) |>
set_engine("keras")
rec_functional <- recipe(y ~ x, data = train_df_small)
wf_functional <- workflow() |>
add_recipe(rec_functional) |>
add_model(spec_functional)
fit_functional <- fit(wf_functional, data = train_df_small)
# Evaluate on the test set
predictions <- predict(fit_functional, new_data = test_df_small)
#> 4/4 - 4s - 962ms/step
bind_cols(predictions, test_df_small) |>
accuracy(truth = y, estimate = .pred_class)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.11
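An accuracy of 0.11 is essentially chance level: guessing uniformly at random among 10 classes is right about 10% of the time. That is not surprising for a run this small (500 training images, 5 epochs, and only 100 test images). The example demonstrates the workflow rather than a competitive score. For meaningful accuracy, you would train on more of the data for more epochs. It may also help to preprocess the images the way the ImageNet weights expect. The Keras ResNet50 weights assume "caffe"-style inputs, with RGB converted to BGR and each channel zero-centered using ImageNet means on the original 0 to 255 scale, rather than pixels rescaled to [0, 1].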
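We can measure how far the reported 0.11 is from guessing. A uniform random guess is correct with probability 1/10 on every image, whatever the class balance. An exact binomial test of 11 correct out of 100 against p = 0.1 therefore asks whether the model beats chance. This sketch uses base R's binom.test(), with the counts taken from the accuracy output above.

```r
# Counts from the evaluation above: accuracy 0.11 on 100 test images
n_test    <- 100
n_correct <- 11
n_classes <- 10

# Expected accuracy of uniform random guessing, independent of class balance
chance <- 1 / n_classes

# Exact two-sided binomial test of H0: true accuracy equals chance
chance_test <- binom.test(n_correct, n_test, p = chance)
chance_test$p.value
```

With 100 test images, one extra correct prediction moves accuracy by 0.01. The p-value here is well above 0.05, so this run cannot be distinguished from guessing.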
This vignette demonstrated how kerasnip bridges the world of pre-trained Keras applications with the structured, reproducible workflows of tidymodels.
The Functional API offers a direct way to perform transfer learning: attach a new head to a frozen base model. This lets you reuse features learned on massive datasets. Once the head is trained with enough data and epochs, this can substantially improve results on smaller, domain-specific tasks.