Deep Learning Assignment

Table of Contents

  • Introduction
  • Data analysis and preparation
  • Overall experiments summary
  • Training process
  • Result interpretation
  • Making predictions

Introduction

This notebook serves as a report for the Deep Learning Course assignment. In the following, I will summarize the results obtained in the task of identifying oil palm crops in PlanetScope imagery via Deep Learning.

For this task, several methods learned throughout the semester were tested in combination with external research. The proposed framework relies on high-level APIs to simplify the training (fine-tuning) and evaluation of different state-of-the-art architectures for binary image classification.

The notebook is organized as follows. The first section presents and briefly describes the data, along with the operations applied directly to the dataframe to prepare it for training. The second section describes the overall experiment framework and workflow and summarizes some results. Those results determine the best model for the task, which is presented in the third section along with the detailed training process.

Installing dependencies

Warning: Besides base Python and other common libraries, this notebook relies on the Fast.ai and Timm packages. Fast.ai is a research group that aims to democratize deep learning; its slogan is "Making neural nets uncool again", and it provides a Python package that serves as a high-level API for training and evaluating deep models. Timm, on the other hand, is one of the largest and most popular repositories of pre-trained models, from which different weights are taken and used. Both packages are installed through pip, and PyTorch is assumed to be already installed.
In [1]:
!pip install fastai
!pip install timm
Collecting fastai
  Downloading fastai-2.7.10-py3-none-any.whl (240 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 240.9/240.9 kB 3.0 MB/s eta 0:00:00a 0:00:01
Collecting fastprogress>=0.2.4
  Downloading fastprogress-1.0.3-py3-none-any.whl (12 kB)
Collecting fastcore<1.6,>=1.4.5
  Downloading fastcore-1.5.27-py3-none-any.whl (67 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.1/67.1 kB 14.1 MB/s eta 0:00:00
Requirement already satisfied: torchvision>=0.8.2 in /usr/local/lib/python3.9/dist-packages (from fastai) (0.13.0+cu116)
Requirement already satisfied: pillow>6.0.0 in /usr/local/lib/python3.9/dist-packages (from fastai) (9.2.0)
Requirement already satisfied: pip in /usr/local/lib/python3.9/dist-packages (from fastai) (22.2.2)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.9/dist-packages (from fastai) (5.4.1)
Collecting fastdownload<2,>=0.0.5
  Downloading fastdownload-0.0.7-py3-none-any.whl (12 kB)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from fastai) (3.5.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from fastai) (1.4.3)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.9/dist-packages (from fastai) (1.1.1)
Requirement already satisfied: torch<1.14,>=1.7 in /usr/local/lib/python3.9/dist-packages (from fastai) (1.12.0+cu116)
Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from fastai) (1.8.1)
Requirement already satisfied: spacy<4 in /usr/local/lib/python3.9/dist-packages (from fastai) (3.4.0)
Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from fastai) (2.28.1)
Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from fastai) (21.3)
Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (3.3.0)
Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (0.4.2)
Requirement already satisfied: wasabi<1.1.0,>=0.9.1 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (0.9.1)
Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (2.0.7)
Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (1.0.7)
Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.9 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (3.0.9)
Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (1.0.2)
Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (4.64.0)
Requirement already satisfied: thinc<8.2.0,>=8.1.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (8.1.0)
Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (3.0.6)
Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (2.0.6)
Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (63.1.0)
Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (2.4.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (3.1.2)
Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (0.6.2)
Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.10.0,>=1.7.4 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (1.9.1)
Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.9/dist-packages (from spacy<4->fastai) (1.23.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.9/dist-packages (from packaging->fastai) (3.0.9)
Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->fastai) (2.8)
Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.9/dist-packages (from requests->fastai) (2.1.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->fastai) (2019.11.28)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->fastai) (1.26.10)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch<1.14,>=1.7->fastai) (4.3.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->fastai) (0.11.0)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->fastai) (2.8.2)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->fastai) (4.34.4)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->fastai) (1.4.3)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->fastai) (2022.1)
Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->fastai) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->fastai) (3.1.0)
Requirement already satisfied: smart-open<6.0.0,>=5.2.1 in /usr/local/lib/python3.9/dist-packages (from pathy>=0.3.5->spacy<4->fastai) (5.2.1)
Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->fastai) (1.14.0)
Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.9/dist-packages (from thinc<8.2.0,>=8.1.0->spacy<4->fastai) (0.7.8)
Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.9/dist-packages (from typer<0.5.0,>=0.3.0->spacy<4->fastai) (8.1.3)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->spacy<4->fastai) (2.1.1)
Installing collected packages: fastprogress, fastcore, fastdownload, fastai
Successfully installed fastai-2.7.10 fastcore-1.5.27 fastdownload-0.0.7 fastprogress-1.0.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 549.1/549.1 kB 30.7 MB/s eta 0:00:00
Requirement already satisfied: torch>=1.7 in /usr/local/lib/python3.9/dist-packages (from timm) (1.12.0+cu116)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.9/dist-packages (from timm) (5.4.1)
Requirement already satisfied: torchvision in /usr/local/lib/python3.9/dist-packages (from timm) (0.13.0+cu116)
Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.9/dist-packages (from timm) (0.8.1)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.7->timm) (4.3.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from huggingface-hub->timm) (4.64.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub->timm) (3.7.1)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub->timm) (21.3)
Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from huggingface-hub->timm) (2.28.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.9/dist-packages (from torchvision->timm) (9.2.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from torchvision->timm) (1.23.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.9/dist-packages (from packaging>=20.9->huggingface-hub->timm) (3.0.9)
Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->huggingface-hub->timm) (2.8)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->huggingface-hub->timm) (1.26.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->huggingface-hub->timm) (2019.11.28)
Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.9/dist-packages (from requests->huggingface-hub->timm) (2.1.0)
Installing collected packages: timm
Successfully installed timm-0.6.12
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
In [105]:
# Importing fastai. The star import is neither my favorite style nor an efficient one, but it is the way the library suggests
import fastai
from fastai.vision.all import *
from fastai.vision.learner import has_pool_type

# Importing timm pre-trained model library
import timm

# Other imports
import os
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import class_weight

# An auxiliary function to plot the metrics of the training process

@patch
@delegates(subplots)
def plot_metrics(self: Recorder, nrows=None, ncols=None, figsize=None, **kwargs):
    metrics = np.stack(self.values)
    names = self.metric_names[1:-1]
    n = len(names) - 1
    if nrows is None and ncols is None:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n / nrows))
    elif nrows is None: nrows = int(np.ceil(n / ncols))
    elif ncols is None: ncols = int(np.ceil(n / nrows))
    figsize = figsize or (ncols * 6, nrows * 4)
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    axs = [ax if i < n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    for i, (name, ax) in enumerate(zip(names, [axs[0]] + axs)):
        ax.plot(metrics[:, i], color='#1f77b4' if i == 0 else '#ff7f0e', label='valid' if i > 0 else 'train')
        ax.set_title(name if i > 1 else 'losses')
        ax.legend(loc='best')
    plt.show()

Data analysis and preparation

The data is provided as a subset of the dataset proposed by Kaggle for the Women in Data Science (WiDS) Datathon 2019, where Kaggle partnered with Planet Labs to obtain a set of labeled imagery containing images both with and without oil palm crops.

This turns the task into a binary classification task in computer vision. Before training any model, however, a first look at the data gives hints about the nature of the phenomenon as well as the challenges in the dataset.

In [3]:
labels = Path('traindata.csv') # Assuming the .csv is in the same folder
df = pd.read_csv(labels)

The data is provided as a table with an image identifier followed by the label. This will be important for the construction of the labeler during the training phase. A label of 0 means no oil palm and 1 means oil palm is present.

In [4]:
df.head()
Out[4]:
img_id has_oilpalm
0 train/img_0000.jpg 0
1 train/img_0001.jpg 0
2 train/img_0002.jpg 0
3 train/img_0003.jpg 0
4 train/img_0004.jpg 0

There is a clear imbalance between the two classes, roughly 88-12. This may add complexity to the classification task and limit the predictive capacity of the model.
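
To see why this imbalance matters, consider a trivial classifier that always answers "no oil palm". The sketch below uses the class counts of the training table (6650 images without oil palm and 944 with it, as printed by the next cell); the function name `majority_baseline` is only illustrative and not part of the notebook.

```python
def majority_baseline(n_negative: int, n_positive: int) -> dict:
    """Accuracy and positive-class recall of a classifier that always predicts the negative class."""
    total = n_negative + n_positive
    accuracy = n_negative / total  # every negative is right, every positive is wrong
    recall = 0 / n_positive        # no positive image is ever detected
    return {"accuracy": accuracy, "recall": recall}


baseline = majority_baseline(n_negative=6650, n_positive=944)
print(f"accuracy = {baseline['accuracy']:.4f}, recall = {baseline['recall']:.1f}")
```

An accuracy close to 88% is reachable without detecting a single oil palm image, which is one reason why a metric such as the F1 score is more informative than accuracy for this dataset.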

In [5]:
print(f'The total number of observations in the train set is {df.has_oilpalm.count()} from which {df.has_oilpalm.value_counts()[0]} ({df.has_oilpalm.value_counts()[0]/df.has_oilpalm.count()*100}%) images dont present oil palm crops while {df.has_oilpalm.value_counts()[1]} ({df.has_oilpalm.value_counts()[1]/df.has_oilpalm.count()*100}%) do. ')
The total number of observations in the train set is 7594 from which 6650 (87.56913352646826%) images dont present oil palm crops while 944 (12.430866473531736%) do. 
In [6]:
df.has_oilpalm.value_counts().plot(kind="bar",title="Has Oil Palm", rot=15,xlabel="Cut",ylabel="Count")
Out[6]:
<AxesSubplot:title={'center':'Has Oil Palm'}, xlabel='Cut', ylabel='Count'>

A few samples of each class are displayed to assess which distinguishing characteristics we are looking for.

In general, the color information in both classes is quite similar, as both contain mostly vegetation. The most distinguishing characteristics might be related to the borders and patterns created by anthropogenic activity, such as paths between crops or regularity in the position of the trees.

In [7]:
fig = plt.figure(figsize=(10, 7))
images=np.hstack([np.array((df[df['has_oilpalm'] == 0].sample(n=3,random_state=2)).img_id),np.array((df[df['has_oilpalm'] == 1].sample(n=3,random_state=2)).img_id)])

for i in range(len(images)):
    image=Image.open(images[i])
    title='0' if i<=int(len(images)/2)-1 else '1'
    
    fig.add_subplot(2, int(len(images)/2), i+1)
    plt.imshow(image)
    plt.axis('off')
    plt.title(title)

Two relevant decisions follow from the previous analysis. First, the image labels are transformed into True (1) and False (0), as these are more meaningful and appropriate for later stages. Second, and most important, to assess the impact of the class imbalance, a simple oversampling of the underrepresented class is performed by replicating its records until the classes are roughly balanced (in the code below, the majority class is also subsampled to 40% of its training rows). These decisions are later tested against their counterparts, namely no label renaming and no oversampling. The oversampling is combined with a stratified split into training and validation sets. Furthermore, the weight of each class is calculated.
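
As a worked example of the class-weight calculation, the 'balanced' heuristic assigns each class the weight n_samples / (n_classes * count of that class). The sketch below reproduces this arithmetic in plain Python with the class counts of the training table; `balanced_class_weights` is a hypothetical helper, not a library function.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight of class c = n_samples / (n_classes * number of samples in c)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


# Same class counts as the training table: 6650 'False' and 944 'True'
labels = ["False"] * 6650 + ["True"] * 944
weights = balanced_class_weights(labels)
print(weights)
```

The rare class receives a weight of about 4.02 and the frequent one about 0.57, so each class contributes the same total weight to a weighted loss.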

In [8]:
df['has_oilpalm']=['True' if element==1 else 'False' for element in df['has_oilpalm']] # Rename labels: 1 -> 'True', 0 -> 'False'

# Estimating class weights (stored under a new name so the imported class_weight module is not shadowed)
y_train=np.array(df.has_oilpalm)
class_weights = dict(zip(np.unique(y_train), class_weight.compute_class_weight('balanced',classes=np.unique(y_train),y=y_train)))

# Stratified 85/15 split into training and validation sets (the oversampling itself is done in a later cell)

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
train_index, test_index = next(sss.split(df.img_id, df.has_oilpalm))
df_train=df.iloc[list(train_index)]
df_val=df.iloc[list(test_index)]
# Flag each row so that ColSplitter('is_valid') can separate the two sets later
df_train.insert(2,'is_valid',np.tile(False,len(df_train)))
df_val.insert(2,'is_valid',np.tile(True,len(df_val)))
In [9]:
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
train=list(df_train.has_oilpalm.value_counts()*100/len(df_train))
test=list(df_val.has_oilpalm.value_counts()*100/len(df_val))

X = np.arange(2)

ax.bar(X + 0.00, train, color = 'b', width = 0.25)
ax.bar(X + 0.25, test, color = 'g', width = 0.25)
plt.xticks(X+0.15, ['False','True'])
ax.set_title('Percentage of True/False in the split datasets')
ax.legend(labels=['Train', 'Validation'])  # blue bars: training set, green bars: validation set
Out[9]:
<matplotlib.legend.Legend at 0x7fc864cb59a0>
In [10]:
df2 = df_train[df_train.has_oilpalm != 'False']
df3 = df_train[df_train.has_oilpalm == 'False']
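df3 = df3.sample(frac = 0.4) # Keep a random 40% of the majority ('False') training rows

# Validation rows are kept once; minority ('True') training rows are repeated three times
df_balance = pd.concat([df3,df_val,df2,df2,df2], ignore_index=True)
df_balance = df_balance.sample(frac = 1) # Shuffle the data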

df = pd.concat([df_train,df_val], ignore_index=True)
In [11]:
print(f'The total number of observations in the oversampled train set is {df_balance.has_oilpalm.count()} from which {df_balance.has_oilpalm.value_counts()[0]} ({df_balance.has_oilpalm.value_counts()[0]/df_balance.has_oilpalm.count()*100}%) images dont present oil palm crops while {df_balance.has_oilpalm.value_counts()[1]} ({df_balance.has_oilpalm.value_counts()[1]/df_balance.has_oilpalm.count()*100}%) do. ')
The total number of observations in the oversampled train set is 5807 from which 3259 (56.121921818494926%) images dont present oil palm crops while 2548 (43.87807818150508%) do. 
In [12]:
len(df_balance)
Out[12]:
5807
In [13]:
df.has_oilpalm.value_counts().plot(kind="bar",title="Has Oil Palm", rot=15,xlabel="Cut",ylabel="Count")
Out[13]:
<AxesSubplot:title={'center':'Has Oil Palm'}, xlabel='Cut', ylabel='Count'>
In [14]:
df_balance.has_oilpalm.value_counts().plot(kind="bar",title="Has Oil Palm", rot=15,xlabel="Cut",ylabel="Count")
Out[14]:
<AxesSubplot:title={'center':'Has Oil Palm'}, xlabel='Cut', ylabel='Count'>
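
The order of the two operations above matters: the split is made first and only the training positives are replicated, so copies of the same image never end up on both sides of the split. The following is a simplified, standard-library sketch of that idea on made-up rows. The helper `split_then_oversample` and the toy data are assumptions for illustration, and the sketch omits the subsampling of the majority class done in the notebook.

```python
import random


def split_then_oversample(rows, valid_fraction=0.15, copies=3, seed=0):
    """Stratified hold-out split, then replicate only the training positives."""
    rng = random.Random(seed)
    train, valid = [], []
    for label in (False, True):
        group = [r for r in rows if r["has_oilpalm"] == label]
        rng.shuffle(group)
        n_valid = round(len(group) * valid_fraction)
        valid += group[:n_valid]   # held out before any duplication
        train += group[n_valid:]
    positives = [r for r in train if r["has_oilpalm"]]
    negatives = [r for r in train if not r["has_oilpalm"]]
    balanced_train = negatives + positives * copies
    rng.shuffle(balanced_train)
    return balanced_train, valid


# Toy table: 800 rows, one positive in every eight
rows = [{"img_id": f"img_{i:04d}.jpg", "has_oilpalm": i % 8 == 0} for i in range(800)]
balanced_train, valid = split_then_oversample(rows)
train_ids = {r["img_id"] for r in balanced_train}
valid_ids = {r["img_id"] for r in valid}
print(len(balanced_train), len(valid), train_ids.isdisjoint(valid_ids))
```

If the replication were done before the split, duplicates of a positive image could land in the validation set and inflate the validation metrics.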

Overall experiments summary

This section describes the overall framework and workflow followed to build the final predictive model.

There are some relevant points to mention about the approach followed to solve the task:

  • Transfer learning: To achieve the best possible accuracy with low training effort, the pre-trained weights of state-of-the-art models are used as the starting point for the new model. Several architectures were tested and compared in a deterministic fashion (fixed random states), following the same pipeline, which is described in the next section.
  • Data Augmentation: To reduce overfitting, the data is augmented with common vision transformations.
  • Efficiency: The model architectures chosen are the ones that promise excellent results at the expense of relatively little energy and time consumption. More complex models were discarded.
  • Best practices: The recommendations of the libraries used in the framework and their researchers are followed to guarantee the best possible result with the cleanest possible implementation. More details on this matter are given in the next section.
  • Simplicity: As this task has been approached several times, no additional complexity is added. I follow the common practices in deep learning and binary image classification through high-level APIs and well-known SOTA models.
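
To illustrate the data augmentation point, the sketch below applies flips and a rotation to a tiny 2x2 "image" stored as nested lists. It only shows the kind of label-preserving transformation involved and assumes that an overhead tile has no canonical "up" direction. It is not the augmentation code used in the training pipeline, and the function names are illustrative.

```python
def flip_horizontal(image):
    """Mirror each row (left-right flip)."""
    return [row[::-1] for row in image]


def flip_vertical(image):
    """Reverse the order of the rows (top-bottom flip)."""
    return image[::-1]


def rotate_90(image):
    """Rotate the image 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]


tile = [[1, 2],
        [3, 4]]
variants = [tile, flip_horizontal(tile), flip_vertical(tile), rotate_90(tile)]
for variant in variants:
    print(variant)
```

Every variant shows the same ground content from a different orientation, so the label (oil palm or not) is unchanged while the model sees more varied inputs.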

Three basic tools (packages) are used to train and compare the models and to obtain the final result in line with the points above.

  • PyTorch: It is currently the dominant Deep Learning framework and the one studied in class. Furthermore, all pre-trained models can be found implemented in PyTorch.
  • Fastai: This library greatly simplifies the interaction with PyTorch. It provides APIs for easy and direct data loading, augmentation, and feeding, as well as single-line model training and fine-tuning, along with useful auxiliary functions. It is important to mention that this implementation was done with fastai-2.7.10-py3. Despite being a great tool, the inconsistency between versions and the abundance of material (generally a great thing) make it hard to reach straightforward results. Once the version is fixed and the appropriate documentation and materials are found, the solution becomes evident.
  • timm: This repository contains a very large collection of pre-trained SOTA models for image classification implemented in PyTorch and interoperates well with fastai.

A baseline model was trained with the following specifications: a pre-trained Resnet50 as feature extractor, 10 training epochs for the classifier layers, no oversampling, data augmentation on the dataset, and cross-entropy as loss function. From that point, several variations were attempted to improve the results. As described below, none of them achieved a significant improvement in accuracy on the validation set. Those changes were:

  • Cross-validation: A cross-validation procedure was attempted with some of the tested architectures. The approach was to implement a voting ensemble, but in no case did this approach improve the accuracy of the predictions.
  • Balancing data: Both the over-sampled balanced dataset and the original dataset were used to train each model architecture. The balanced training dataset did not perform better than the imbalanced one in any case.
  • One neuron BCE vs two neurons CE output: For a binary classification problem, it is possible to output a single neuron followed by a sigmoid and a binary cross-entropy (BCE) loss, or two neurons followed by either a softmax function and a cross-entropy (CE) loss or one BCE per neuron. In the end, although the general practice is to use a single neuron with BCE, the results did not show any change between this approach (using several thresholds) and the two-neuron CE output. As far as I understand, if the threshold is not meaningful, both approaches, and even the third one with two neurons and individual BCE, lead to the same result. The Fast.ai API treats every classification as a multiclass, single-label problem, so the number of output neurons equals the number of classes. As mentioned before, for simplicity this is not changed, since it has no impact on the final result and would imply modifying working code from the library. I am aware that this might not be the best practice in general, but this is how Fast.ai proceeds. It is important to remark that this decision was taken after every output combination was re-implemented and tested.
  • Loss function change: Other loss functions were tested along with their weighted versions to account for class imbalance. The three main functions used were binary cross-entropy, categorical cross-entropy, and focal loss. All of them performed similarly, but the best and most stable results were obtained with a weighted focal loss.
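
The near-equivalence of the one-neuron and two-neuron outputs discussed in the list above can be checked numerically: the softmax probability of the positive class over two logits equals the sigmoid of their difference. A minimal sketch in plain Python, with arbitrary example logits:

```python
import math


def softmax_positive(z0: float, z1: float) -> float:
    """Probability of class 1 computed from two logits with a softmax."""
    m = max(z0, z1)  # subtract the maximum for numerical stability
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
    return e1 / (e0 + e1)


def sigmoid(z: float) -> float:
    """Probability of class 1 computed from a single logit."""
    return 1.0 / (1.0 + math.exp(-z))


logit_pairs = [(0.3, 1.2), (2.0, -1.0), (0.0, 0.0)]  # arbitrary example values
results = [(softmax_positive(z0, z1), sigmoid(z1 - z0)) for z0, z1 in logit_pairs]
for p_two_neurons, p_one_neuron in results:
    print(f"{p_two_neurons:.6f} {p_one_neuron:.6f}")
```

A two-neuron softmax head therefore represents the same family of probabilities as a single sigmoid neuron whose logit is z1 - z0, which is consistent with the observation that both setups led to the same results.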

The comparisons between the best-performing models are shown below. All of them were trained with the same dataset and the same loss function, but with different numbers of epochs and learning rates.

224
Architecture            Validation Loss   F1-Score   Number of parameters
Resnet-50               0.32              0.72       25.56 million
EfficientNetB3 pruned   0.34              0.69       9.86 million
Resnext50r              0.36              0.63224 million
ConvNext Nano           0.4               0.67       15.59 million

Training process

Fast.ai offers several tools to make data ingestion into PyTorch easy. The main structures are DataBlocks and DataLoaders. A DataBlock defines the type of data being ingested (images), the nature of the task (classification), several auxiliary functions to read the labels and to get the data from where it is stored, and the transformations to be applied on CPU or GPU. This object is then used to build a DataLoader, which takes charge of actually feeding that data into the training loop.

In [48]:
#def get_y(row): return [row['has_oilpalm']]
palm1 = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=ColSplitter('is_valid'),
                   get_x=ColReader('img_id',pref=""),
                   get_y=ColReader('has_oilpalm'),
                   batch_tfms=aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)+[Normalize.from_stats(*imagenet_stats)])
dls1 = palm1.dataloaders(df,bs=64)
dls1.show_batch()
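
As a conceptual aid, the roles played by the column reader, the column splitter, and the category vocabulary in the cell above can be mimicked in a few lines of plain Python. This is only a sketch of the idea, not how fastai implements it; the helpers `col_reader` and `col_splitter` and the three rows are made up for illustration.

```python
def col_reader(column):
    """Return a function that reads one column from a row (the role of get_x / get_y)."""
    return lambda row: row[column]


def col_splitter(rows, column="is_valid"):
    """Split row indices into (train, valid) according to a boolean column."""
    train = [i for i, r in enumerate(rows) if not r[column]]
    valid = [i for i, r in enumerate(rows) if r[column]]
    return train, valid


# Made-up rows with the same three columns as the notebook's dataframe
rows = [
    {"img_id": "train/img_a.jpg", "has_oilpalm": "False", "is_valid": False},
    {"img_id": "train/img_b.jpg", "has_oilpalm": "True", "is_valid": False},
    {"img_id": "train/img_c.jpg", "has_oilpalm": "False", "is_valid": True},
]

get_x, get_y = col_reader("img_id"), col_reader("has_oilpalm")
vocab = sorted({get_y(r) for r in rows})  # class names in a fixed order
train_idx, valid_idx = col_splitter(rows)
train_items = [(get_x(rows[i]), vocab.index(get_y(rows[i]))) for i in train_idx]
print(vocab, train_items, valid_idx)
```

With the labels sorted this way, index 0 corresponds to 'False' and index 1 to 'True', which is the order in which per-class weights and predicted class indices are interpreted.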

With Fast.ai, the training process can be summarized in three lines of code or less. In the following lines, two definitions are made. First, the loss function, which can be chosen among the losses included in the Fast.ai library or defined as a custom function. Second, the F1 score metric, which is monitored during training. Finally, the Learner takes the DataLoader along with a model name (if timm is installed, every model in that library is usable), the loss function, and the flag that defines whether pre-trained weights are used.

The exact same structure was used to run the experiments summarized above, just by changing either the loss function or the string with the model name. Here only the selected model is shown.
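
Focal loss is one of the losses that was swapped in during the experiments. As a reminder of what it does, the sketch below computes it for a single sample as alpha * (1 - p_t)^gamma * (-log p_t), where p_t is the predicted probability of the correct class. The values gamma = 2 and alpha = 1 and the two example probabilities are assumptions chosen for illustration, not the settings used in the experiments.

```python
import math


def cross_entropy(p_true: float) -> float:
    """Cross-entropy of one sample given the probability assigned to its true class."""
    return -math.log(p_true)


def focal_loss(p_true: float, gamma: float = 2.0, alpha: float = 1.0) -> float:
    """Focal loss of one sample: cross-entropy scaled down for well-classified samples."""
    return alpha * (1.0 - p_true) ** gamma * cross_entropy(p_true)


easy, hard = 0.95, 0.30  # hypothetical probabilities of the correct class
for name, p in [("easy", easy), ("hard", hard)]:
    print(f"{name}: CE = {cross_entropy(p):.4f}, focal = {focal_loss(p):.4f}")
```

The easy sample keeps only a small fraction of its cross-entropy, while the hard sample keeps about half of it, so training focuses on the difficult images. Setting gamma to 0 recovers the (weighted) cross-entropy.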

In [49]:
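# Weighted cross-entropy: the two values are the 'balanced' class weights for 'False' and 'True' (7594/(2*6650) and 7594/(2*944))
fl = CrossEntropyLossFlat(torch.tensor([0.5709774436090226,4.022245762711864]).to("cuda:0"))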
f1_score=F1Score()
learn = vision_learner(dls1,'resnet50',loss_func=fl,pretrained=True, metrics=[f1_score])
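
To make the effect of the class weights concrete, the following sketch computes a weighted cross-entropy by hand for four hypothetical samples. It follows the convention of PyTorch's mean reduction, where the weighted sum of the per-sample losses is divided by the sum of the weights of the targets. The probabilities and targets are invented for illustration.

```python
import math


def weighted_cross_entropy(probs_true, targets, class_weights):
    """Weighted mean of -log(p_true), normalised by the sum of the sample weights."""
    weights = [class_weights[t] for t in targets]
    losses = [-math.log(p) for p in probs_true]
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


class_weights = [0.5709774436090226, 4.022245762711864]  # ['False', 'True'], as in the cell above
probs_true = [0.9, 0.9, 0.9, 0.4]  # hypothetical probability given to the correct class
targets = [0, 0, 0, 1]             # three negatives, one poorly predicted positive

plain = weighted_cross_entropy(probs_true, targets, [1.0, 1.0])
weighted = weighted_cross_entropy(probs_true, targets, class_weights)
print(f"unweighted = {plain:.3f}, weighted = {weighted:.3f}")
```

The poorly predicted positive sample dominates the weighted loss, so errors on the rare class are penalised much more strongly than in the unweighted case.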

The Fast.ai interface also offers a convenient function to find a suitable learning rate. The approach consists of training the model on mini-batches with increasing learning rates. The suggested learning rate is based on the slope and inflection point of the resulting loss curve.

In [44]:
lr_head_1=learn.lr_find()
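
The "increasing learning rates" of the range test described above are typically spaced geometrically, so that every order of magnitude gets the same number of mini-batches. A minimal sketch of such a schedule follows; the start value, end value, and number of iterations are assumptions for illustration, and the function is not part of fastai.

```python
def exponential_lr_schedule(start_lr: float, end_lr: float, num_it: int):
    """Learning rates that grow geometrically from start_lr to end_lr over num_it steps."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]


# Assumed range for illustration: from 1e-7 up to 10 in 100 mini-batches
lrs = exponential_lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100)
print(lrs[0], lrs[len(lrs) // 2], lrs[-1])
```

One mini-batch is trained at each of these rates and the loss is recorded; the region where the loss decreases steadily, before it blows up, indicates a reasonable learning rate.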

This learning rate is used together with the fit one-cycle policy, which uses the concept of cyclical learning rates to converge faster and with higher accuracy than a fixed rate. The number of epochs was tuned based on the moment when each model started overfitting.
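
The shape of a one-cycle schedule can be sketched as a warm-up towards the maximum learning rate followed by a long annealing phase far below it. The cosine interpolation and the parameter values below (`pct_start`, `div`, `div_final`) are assumptions chosen for illustration and may differ from what the library does internally.

```python
import math


def cosine_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct = 0) to end (pct = 1)."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def one_cycle_lr(step, total_steps, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at a given step: warm up to lr_max, then anneal far below it."""
    pct = step / (total_steps - 1)
    if pct < pct_start:
        return cosine_anneal(lr_max / div, lr_max, pct / pct_start)
    return cosine_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))


schedule = [one_cycle_lr(s, 101, lr_max=1e-3) for s in range(101)]
print(schedule[0], max(schedule), schedule[-1])
```

The rate starts low, peaks a quarter of the way through training, and ends several orders of magnitude below the peak, which lets the model explore early and settle late.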

In [45]:
learn.fit_one_cycle(10, lr_head_1.valley)
epoch train_loss valid_loss f1_score time
0 0.777247 0.639669 0.433898 00:51
1 0.663270 0.442420 0.552036 00:51
2 0.487880 0.412646 0.660969 00:51
3 0.408410 0.365458 0.664879 00:51
4 0.372332 0.370365 0.689459 00:51
5 0.366100 0.357190 0.687324 00:51
6 0.339802 0.337202 0.696629 00:51
7 0.341606 0.312336 0.699422 00:51
8 0.331789 0.325696 0.700565 00:52
9 0.319722 0.321493 0.715116 00:51
In [47]:
learn.recorder.plot_metrics()
In [55]:
learn.export('1stage_finetuned_resnet_10')

Result interpretation

From the confusion matrix, it is clear that the problematic behavior is false positives. This behavior was predominant across all model architectures and the other experiments. Even the number of misclassified images is similar between experiments. That was the main reason for not being able to push the F1 score much beyond 0.7 overall.

In [53]:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(title='Confusion matrix')
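
To show how false positives limit the F1 score, the sketch below derives precision, recall, and F1 from confusion-matrix counts. The counts are hypothetical and are NOT the values of the matrix plotted above; they only illustrate that adding false positives lowers precision, and therefore F1, while recall stays the same.

```python
def precision_recall_f1(tp: int, fp: int, fn: int):
    """Precision, recall and F1 of the positive class from confusion-matrix counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical counts, not taken from the notebook's confusion matrix
few_fp = precision_recall_f1(tp=120, fp=20, fn=22)
many_fp = precision_recall_f1(tp=120, fp=80, fn=22)
print("few false positives :", [round(v, 3) for v in few_fp])
print("many false positives:", [round(v, 3) for v in many_fp])
```

Because the positive class is small, even a modest number of false positives among the many negative images is large relative to the true positives, which pulls the F1 score down.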

It is possible that these images have been mislabeled, but it is not certain. Several of these samples, especially the false positives, might have characteristics not observable in this simple RGB imagery. The phenological state of the palm or the presence of different plants cannot be distinguished without further data, for example infrared or other bands.

Since it is impossible to be completely sure that this data has bad labels, the original dataset is not modified. If that is the case, then the model is capable of correcting such behavior by detecting bad labels and suggesting the change. If not, there is no way, or at least I could not find one, to improve the classification of the given data without improving its quality or adding covariates.

In [54]:
interp.plot_top_losses(k=60)

Making predictions

In [99]:
dir_list = os.listdir('test/')
dir_list=['test/'+filename for filename in dir_list]
d={'img_id':dir_list}
df_test = pd.DataFrame(data=d)
#df_test['has_oilpalm']=['True' if element==1 else 'False' for element in df_test['has_oilpalm']] # Labels re-name

# Build a test DataLoader that reuses the training-time transforms, then predict
test_dl = learn.dls.test_dl(df_test)
preds, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)
# Store the predicted class index as the label (0 = 'False', 1 = 'True' in the sorted vocabulary);
# the previous version filled the column with False and therefore always wrote 0
df_test.insert(1,'has_oilpalm',decoded.numpy().astype(int))
In [100]:
df_test.to_csv('testdata.csv')
In [102]:
df_test.head()
Out[102]:
img_id has_oilpalm
0 test/img_5989.jpg 0
1 test/img_2836.jpg 0
2 test/img_2855.jpg 0
3 test/img_5765.jpg 0
4 test/img_6675.jpg 0
In [103]:
import matplotlib.pyplot as plt
plt.hist(decoded)
Out[103]:
(array([6336.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
        1301.]),
 array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),
 <BarContainer object of 10 artists>)
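
As a quick sanity check, the two histogram bins above can be turned into a predicted positive share and compared with the positive share of the training table (944 out of 7594). The sketch only redoes this arithmetic with the numbers already shown in the notebook.

```python
# Bin counts copied from the histogram output above: class index -> number of test images
counts = {0: 6336, 1: 1301}
total = sum(counts.values())
predicted_positive_share = counts[1] / total

# Positive share in the training table, from the counts printed earlier
train_positive_share = 944 / 7594

print(f"test images: {total}")
print(f"predicted positive share: {predicted_positive_share:.1%}")
print(f"training positive share:  {train_positive_share:.1%}")
```

The model predicts oil palm for about 17% of the test images, against about 12% in the training table. This would be consistent with the tendency towards false positives discussed earlier, although the true class balance of the test folder is unknown.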