Ensembles of data-efficient vision transformers as a new paradigm for automated classification in ecology
Data
We tested our models on ten publicly available datasets; Fig. 4 shows example images from each of them. When applicable, the training and test splits were kept the same as in the original dataset. The ZooScan, Kaggle, EILAT, and RSMAS datasets, however, lack a specific training and test set; in these cases, benchmarks come from k-fold cross-validation51,52, and we followed exactly the same procedures to allow for a fair comparison.
Figure 4. Examples of images from each of the datasets: (a) RSMAS, (b) EILAT, (c) ZooLake, (d) WHOI, (e) Kaggle, (f) ZooScan, (g) NA-Birds, (h) Stanford Dogs, (i) Sri Lankan Beetles, (j) Florida Wildtrap.
RSMAS
This is a small coral dataset of 766 RGB image patches of 256 × 256 pixels each53. The patches were cropped out of larger images obtained by the University of Miami’s Rosenstiel School of Marine and Atmospheric Sciences, captured with various cameras at various locations. The data are divided into 14 unbalanced classes, whose labels correspond to the Latin names of the coral species. The current SOTA for this dataset is from Ref.52, which uses an ensemble of the 11 best-performing CNN models, selected with a sequential forward feature selection (SFFS) approach. Since an independent test set is not available, they use 5-fold cross-validation to benchmark performance.
EILAT
This is a coral dataset of 1123 RGB image patches of 64 × 64 pixels53, created from larger images taken on coral reefs near Eilat in the Red Sea. The dataset is partitioned into eight classes with an unequal distribution of data. The class names are abbreviated versions of the scientific names of the coral species.
As for RSMAS, the current SOTA52 uses an ensemble of the 11 best-performing CNN models and 5-fold cross-validation for benchmarking.
ZooLake
This dataset consists of 17,943 images of lake plankton from 35 classes, acquired with a Dual-magnification Scripps Plankton Camera (DSPC) in Lake Greifensee (Switzerland) between 2018 and 2020 14,54. The images are in color, with a black background and an uneven class distribution. The current SOTA22 on this dataset is a stacking ensemble of 6 CNN models, evaluated on an independent test set.
WHOI
This dataset55 contains images of marine plankton acquired by Imaging FlowCytobot56 in Woods Hole Harbor water. Sampling was done between late fall and early spring in 2004 and 2005. It contains 6600 greyscale images of different sizes from 22 manually categorized plankton classes, with an equal number of samples per class. Most classes correspond to phytoplankton at the genus level. This dataset was later extended to 3.4M images and 103 classes. The WHOI subset that we use was previously used for benchmarking plankton classification models51,52. The current SOTA22 on this dataset is an average ensemble of 6 CNN models, evaluated on an independent test set.
Kaggle-plankton
The original Kaggle-plankton dataset consists of plankton images acquired with In Situ Ichthyoplankton Imaging System (ISIIS) technology from May to June 2014 in the Straits of Florida. The dataset was published on Kaggle (https://www.kaggle.com/c/datasciencebowl), with images originating from the Hatfield Marine Science Center at Oregon State University. A subset of the original dataset was published in Ref.51 to benchmark plankton classification tasks. This subset comprises 14,374 greyscale images from 38 classes; the class distribution is not uniform, but each class has at least 100 samples.
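For the datasets without a fixed test split (RSMAS, EILAT, Kaggle, and ZooScan), performance is reported through k-fold cross-validation. As a minimal sketch of what a stratified k-fold split does, assuming integer class labels and a fixed random seed, the NumPy code below shuffles the samples of each class and deals them round-robin into k folds, so that every fold roughly preserves the class proportions. It is an illustration only: it does not reproduce the exact fold assignments of Refs.51,52, and the function name `stratified_kfold` is hypothetical.

```python
import numpy as np


def stratified_kfold(labels, k, seed=0):
    """Yield (train_idx, test_idx) for k stratified folds (hypothetical helper)."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    fold_of = np.empty(len(labels), dtype=int)
    offset = 0
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)  # random order within the class
        # Deal this class's samples round-robin into folds 0..k-1, continuing
        # from where the previous class stopped so fold sizes stay balanced.
        fold_of[idx] = (np.arange(len(idx)) + offset) % k
        offset += len(idx)
    for f in range(k):
        # Fold f is the test set; all remaining samples form the training set.
        yield np.flatnonzero(fold_of != f), np.flatnonzero(fold_of == f)
```

Each sample lands in exactly one test fold, and the reported score is then the average over the k train/test rounds; with k = 2 this corresponds to the 2-fold protocol used for ZooScan.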
The current SOTA22 uses an average ensemble of 6 CNN models and benchmarks performance with 5-fold cross-validation.
ZooScan
The ZooScan dataset consists of 3771 greyscale plankton images acquired with the Zooscan technology in the Bay of Villefranche-sur-mer57. This dataset was used for benchmarking classification models in previous plankton recognition papers51,52. It consists of 20 classes, with the number of samples per class ranging from 28 to 427. The current SOTA22 uses an average ensemble of 6 CNN models and benchmarks performance with 2-fold cross-validation.
NA-Birds
NA-Birds58 is a collection of 48,000 annotated pictures of North America’s 400 most commonly seen bird species. Over 100 images are available for each species, with distinct annotations for males, females, and juveniles, totaling 555 visual categories. The current SOTA59, called TransFG, modifies the pure ViT model with a part selection module, which replaces the original input sequence of the transformer layer with tokens corresponding to informative regions, and with contrastive feature learning, which enlarges the distance between the representations of confusing subcategories. An independent test set is used for benchmarking model performance.
Stanford Dogs
The Stanford Dogs dataset comprises 20,580 color images of 120 dog breeds from around the globe, split into 12,000 training images and 8,580 test images60. The current SOTA59 is TransFG, the modified ViT model described above for NA-Birds, benchmarked on an independent test set.
Sri Lankan Beetles
The arboreal tiger beetle dataset61 consists of 380 images taken between August 2017 and September 2020 at 22 locations in Sri Lanka, covering all climatic zones and provinces, as well as 14 districts.
The nine species found belong to the genera Tricondyla (3 species), Derocrania (5 species), and Neocollyris (1 species), and six of them are endemic. The current SOTA61 uses the CNN-based SqueezeNet architecture, trained starting from ImageNet pre-trained weights, and benchmarks model performance on an independent test set.
Florida Wild Traps
The wildlife camera trap62 classification dataset comprises 104,495 images featuring visually similar species, varied lighting conditions, a skewed class distribution, and samples of endangered species such as Florida panthers. The images were collected at two locations in Southwestern Florida and are categorized into 22 classes. The current SOTA62 uses the CNN-based ResNet-50 architecture, and its performance was benchmarked on an independent test set.
Models
Vision transformers (ViTs)31 are an adaptation to computer vision of Transformers, which were originally developed for natural language processing30. Their distinguishing feature is that, instead of exploiting translational symmetry as CNNs do, they rely on an attention mechanism that identifies the most relevant parts of an image. ViTs have recently outperformed CNNs in image classification tasks where vast amounts of training data and computing resources are available30,63. However, for the vast majority of use cases and users, where data and/or computational resources are limited, ViTs are essentially untrainable, even when the network architecture is fixed and no architectural optimization is required. To address this issue, Data-efficient Image Transformers (DeiTs) were proposed32. These are transformer models designed to be trained with much less data and far fewer computing resources32. In DeiTs, the transformer architecture has been modified to allow native distillation64, in which a student neural network learns from the outputs of a teacher model.
Here, a CNN is used as the teacher model and the pure vision transformer as the student network. All the DeiT models we report on here are DeiT-Base models32. The ViTs are ViT-B16, ViT-B32, and ViT-L32 models31.
Implementation
To train our models, we used transfer learning65: we took a model already pre-trained on the ImageNet43 dataset, replaced the last layers according to the number of classes, and then fine-tuned the whole network with a very low learning rate. All the models were trained on two NVIDIA GeForce RTX 2080 Ti GPUs.
DeiTs
We used the DeiT-Base32 architecture through the Python package TIMM66, which includes many well-known deep learning architectures along with their weights pre-trained on the ImageNet dataset43. We resized the input images to 224 × 224 pixels and, to prevent the model from overfitting at the pixel level and help it generalize better, applied typical image augmentations during training, such as horizontal and vertical flips, rotations up to 180 degrees, small zooms up to 20%, a small Gaussian blur, and shearing up to 10%. To handle class imbalance, we used class reweighting, which weights the error on each example inversely to the frequency of its class in the dataset67. We used sklearn utilities68 to calculate the class weights employed during training. Training started from the default PyTorch69 initial conditions (Kaiming uniform initializer) and used an AdamW optimizer with cosine annealing70, a base learning rate of $10^{-4}$, a weight decay of 0.03, a batch size of 32, and a cross-entropy loss. We trained with early stopping: whenever the validation F1-score did not improve for 5 epochs, training was interrupted and the learning rate was reduced by a factor of 10, and we iterated until the learning rate reached its final value of $10^{-6}$. This procedure amounted to around 100 epochs in total, independent of the dataset.
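Two ingredients of this DeiT recipe can be sketched without a deep-learning framework. The NumPy code below is a minimal sketch under stated assumptions, not the authors' training code. It computes "balanced" class weights with the heuristic n_samples / (n_classes × count_c), the formula behind scikit-learn's `compute_class_weight(class_weight="balanced", ...)` (we assume this is the sklearn utility meant above); plugs them into a weighted cross-entropy normalized by the sum of weights, as PyTorch's `CrossEntropyLoss(weight=...)` does with mean reduction; and implements the step-down rule (divide the learning rate by 10 after 5 epochs without F1 improvement, stop at a plateau once $10^{-6}$ is reached). The helper names and the `StepDownOnPlateau` class are hypothetical, the cosine annealing within each stage is omitted, and stopping at the final learning rate is our reading of the procedure.

```python
import numpy as np


def balanced_class_weights(labels, n_classes):
    """Weight for class c: n_samples / (n_classes * count_c) (rare -> larger)."""
    counts = np.bincount(labels, minlength=n_classes)
    return len(labels) / (n_classes * counts)


def weighted_cross_entropy(logits, labels, class_weights):
    """Cross-entropy with each example scaled by its class weight,
    normalized by the total weight of the batch."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    w = class_weights[labels]
    return float((w * nll).sum() / w.sum())


class StepDownOnPlateau:
    """Hypothetical controller: divide the LR by `factor` when the validation
    F1 has not improved for `patience` epochs; stop on a plateau at final_lr."""

    def __init__(self, base_lr=1e-4, final_lr=1e-6, factor=10.0, patience=5):
        self.lr = base_lr
        self.final_lr = final_lr
        self.factor = factor
        self.patience = patience
        self.best = -np.inf
        self.bad_epochs = 0

    def step(self, val_f1):
        """Update after one epoch; return False when training should stop."""
        if val_f1 > self.best:
            self.best = val_f1
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        if self.bad_epochs < self.patience:
            return True
        if self.lr <= self.final_lr * (1 + 1e-9):  # tolerance for float error
            return False  # plateau reached at the final learning rate
        self.lr = max(self.lr / self.factor, self.final_lr)
        self.bad_epochs = 0
        return True
```

With a base rate of $10^{-4}$ and a final rate of $10^{-6}$, the controller passes through three learning-rate stages ($10^{-4}$, $10^{-5}$, $10^{-6}$) before stopping.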
The training time depended on the size of the dataset, ranging from 20 min (Sri Lankan Beetles) to 9 h (Florida Wildtrap). We used the same procedure for all the datasets, so no extra time was needed for hyperparameter tuning.
ViTs
We implemented the ViT-B16, ViT-B32, and ViT-L32 models using the Python package vit-keras (https://github.com/faustomorales/vit-keras), which is built on the TensorFlow library71 and provides weights pre-trained on the ImageNet43 dataset. First, we resized the input images to 128 × 128 pixels and applied typical image augmentations during training, such as horizontal and vertical flips, rotations up to 180 degrees, small zooms up to 20%, a small Gaussian blur, and shearing up to 10%. To handle class imbalance, we calculated the class weights and used them during training. Following a transfer learning approach, we imported the pre-trained model and froze all of its layers. We removed the last layer and replaced it with a dense layer with $n_c$ outputs ($n_c$ being the number of classes), preceded and followed by a dropout layer. We used Keras Tuner72 with Bayesian optimization search73 to determine the best set of hyperparameters, including the dropout rate, the learning rate, and the dense layer parameters (10 trials and 100 epochs). The model with the best hyperparameters was then trained from the default TensorFlow71 initial conditions (Glorot uniform initializer) for 150 epochs with early stopping, halting training if the validation loss did not decrease for 50 epochs and retaining the model parameters with the lowest validation loss.
CNNs
The CNNs included the DenseNet38, MobileNet39, EfficientNet-B240, EfficientNet-B540, EfficientNet-B640, and EfficientNet-B740 architectures.
We followed the training procedure described in Ref.22 and carried out the training in TensorFlow.
Ensemble learning
We adopted average ensembling, which takes the confidence vectors of the different learners and produces a prediction based on their average. With this procedure, all individual models contribute equally to the final prediction, irrespective of their validation performance. Ensembling usually results in superior overall classification metrics and model robustness74,75.
Given a set of $n$ models with prediction vectors $\vec c_i$ ($i=1,\ldots,n$), these are typically aggregated through an arithmetic average. The components of the ensembled confidence vector $\vec c_{AA}$ related to each class $\alpha$ are then
$$c_{AA,\alpha} = \frac{1}{n}\sum_{i=1}^{n} c_{i,\alpha}\,. \tag{2}$$
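Equation (2) amounts to a mean over learners of the confidence vectors, followed by an argmax over classes. A minimal NumPy sketch, assuming the confidences of the $n$ learners are stacked in an array of shape (n_models, n_samples, n_classes), a layout chosen here for illustration; the function name `arithmetic_ensemble` and the example numbers are hypothetical:

```python
import numpy as np


def arithmetic_ensemble(confidences):
    """Arithmetic average of Eq. (2).

    confidences: array of shape (n_models, n_samples, n_classes), e.g. the
    softmax outputs of each learner. Returns the averaged confidences c_AA
    and the predicted class (argmax over alpha) for every sample.
    """
    c = np.asarray(confidences, dtype=float)
    c_aa = c.mean(axis=0)  # c_{AA,alpha} = (1/n) * sum_i c_{i,alpha}
    return c_aa, c_aa.argmax(axis=1)


# Made-up example: two learners, one sample, three classes.
c_aa, pred = arithmetic_ensemble([[[0.6, 0.3, 0.1]],
                                  [[0.1, 0.6, 0.3]]])
# c_aa -> [[0.35, 0.45, 0.2]], pred -> [1]
```

Every learner enters with the same weight $1/n$, which is what makes each model contribute equally regardless of its validation performance.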
Another option is to use a geometric average,
$$c_{GA,\alpha} = \sqrt[n]{\prod_{i=1}^{n} c_{i,\alpha}}\,. \tag{3}$$
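The geometric average of Eq. (3) can be sketched in NumPy, again assuming the confidences are stacked in an array of shape (n_models, n_samples, n_classes). This is an illustration rather than the authors' code: it evaluates the $n$-th root of the product in log space (as the exponential of the mean log-confidence), which avoids underflow when many small confidences are multiplied, and it clips confidences at a small `eps`, an assumption added so that a confidence of exactly zero does not produce log(0). Because the exponential and the $n$-th root are monotonic, the predicted class is the same as with the plain product of confidences. The made-up two-model, three-class example shows the two averages disagreeing: the second learner gives class 0 a confidence of only 0.01, so the geometric average rules that class out even though the arithmetic average prefers it.

```python
import numpy as np


def geometric_ensemble(confidences, eps=1e-12):
    """Geometric average of Eq. (3).

    confidences: array of shape (n_models, n_samples, n_classes).
    Computed as exp(mean_i log c_{i,alpha}); eps (an assumption, not part
    of Eq. 3) guards against log(0).
    """
    c = np.clip(np.asarray(confidences, dtype=float), eps, None)
    c_ga = np.exp(np.log(c).mean(axis=0))
    return c_ga, c_ga.argmax(axis=1)


# Veto illustration with made-up confidences (2 learners, 1 sample, 3 classes).
conf = np.array([[[0.90, 0.05, 0.05]],
                 [[0.01, 0.60, 0.39]]])
aa_pred = conf.mean(axis=0).argmax(axis=1)  # arithmetic average -> class 0
_, ga_pred = geometric_ensemble(conf)       # geometric average  -> class 1
```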
We could normalize the vector $\vec c_{GA}$, but this is not relevant, since we are interested in its largest component, i.e. in the class $\arg\max_\alpha c_{GA,\alpha}$, and normalization rescales all components by the same factor. Likewise, the $n$th root is monotonic and does not change the relative ordering of the components, so instead of $\vec c_{GA}$ we can use a product rule: $\arg\max_\alpha c_{GA,\alpha} = \arg\max_\alpha c_{PROD,\alpha}$, with $c_{PROD,\alpha} = \prod_{i=1}^{n} c_{i,\alpha}$.
While these two kinds of averaging are equivalent in the case of two models and two classes (with confidences that sum to one), they generally differ in all other cases33. For example, the geometric average penalizes more strongly the classes to which at least one learner assigns a very low confidence, a property termed veto mechanism36 (while Ref.36 uses the term veto only when the confidence value is exactly zero, here we use it in a slightly looser sense). More