Authors: Jose Gomez, Oscar Narvaez, Max Aragon
Affiliation: Department of Geoinformatics, University of Salzburg
Copyright: © 2022
License: MIT

Cloud Cover spatiotemporal prediction using deep learning


Introduction¶

In this Jupyter Notebook you will learn how to query ERA5 and Sentinel-3 data from WEkEO, explore it, and pre-process it for predicting the spatiotemporal variability of total cloud cover over France.

Data¶

Product Description | WEkEO HDA ID | WEkEO metadata
Sentinel-3 OLCI level-2 | EO:ESA:DAT:SENTINEL-3:OL_2_LFR___ | link
ERA5 | EO:ECMWF:DAT:REANALYSIS_ERA5_SINGLE_LEVELS | link

Data description¶

ERA5 is a reanalysis dataset providing hourly estimates of a large number of atmospheric, land and oceanic climate variables. The data covers the Earth on a 30 km grid and resolves the atmosphere using 137 levels from the surface up to a height of 80 km. ERA5 includes information about uncertainties for all variables at reduced spatial and temporal resolutions.

From this dataset we use the Total Cloud Cover (TCC) estimation. This parameter is the proportion of a grid box covered by cloud. Total cloud cover is a single level field calculated from the cloud occurring at different model levels through the atmosphere. Assumptions are made about the degree of overlap/randomness between clouds at different heights. Cloud fractions vary from 0 to 1.
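
To make the overlap assumption concrete, here is a small illustrative sketch (a toy example only; ERA5's actual overlap scheme is more involved). Two common bounding assumptions are maximum overlap, where clouds at different levels stack vertically, and random overlap, where levels are statistically independent:

import numpy

# Illustrative cloud fractions at three model levels
layers = numpy.array([0.3, 0.4, 0.2])

# Maximum overlap: clouds align vertically, so total cover equals the largest layer
tcc_max = layers.max() # 0.4

# Random overlap: independent levels, so combine their clear-sky fractions
tcc_random = 1 - numpy.prod(1 - layers) # 1 - 0.7*0.6*0.8 = 0.664

print(tcc_max, round(tcc_random, 3))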

In addition, we use the Sentinel-3 Ocean and Land Colour Instrument (OLCI) cloud mask product to compare TCC values against. OLCI has a resolution of 300 m, roughly 100 times finer than the ERA5 grid, allowing us to evaluate the performance of the reanalysis against finer observations.

Hold on, what's a climate reanalysis?¶

In [50]:
from IPython.display import YouTubeVideo

YouTubeVideo('FAGobvUGl24', width=800, height=450)
Out[50]:

Contents¶

  • Data access and download
    • Importing libraries
    • Creating and sending download query
  • Data exploration and pre-processing
    • Importing libraries
    • Loading data
    • Plotting temporal data
    • Plotting geographical data
    • Animating spatio-temporal data
  • Next-frame prediction
    • Importing libraries
    • Defining the model
    • Training the model
    • Making predictions

Data access and download ¶

Importing libraries ¶

We will use the Harmonized Data Access (HDA) service to search for and download our data. For this, WEkEO has created a Python client for the API which makes the process easier. Next we will import that client, define the query JSON and send a request to download ERA5 TCC data. Currently this data is provided worldwide, so it must later be cropped to a specific area of interest (AOI).

In [51]:
import hda # WEkEO's Python client for HDA
import os # To manage basic folder operations
import zipfile # To unzip files
from datetime import datetime # Python datetime library
import warnings
warnings.filterwarnings('ignore')

ERA5 download query ¶

Next we need to define a JSON: a dictionary of data used to communicate with the API. This JSON contains information such as the ERA5 collection ID, our variable of interest and the time range we are looking for. In the following block we generate lists of years, months and days to facilitate handling. For this example we will use hourly data from 2017 to 2021, including all months of the year. We will also define the directory into which we want the data to be downloaded.

In [52]:
# Sequences for years, months and days. Note that range() excludes its stop
# value, so the stop must be one past the last desired number.
Years=list(map(str,range(2017,2022,1))) 
months=list(map(str,range(1,13,1)))
days=list(map(str,range(1,32,1))) # days 1-31
# Working dir (download data)
workdir='.'
In [53]:
Years # What the years sequence looks like
Out[53]:
['2017', '2018', '2019', '2020', '2021']
In [54]:
query = {
  "datasetId": "EO:ECMWF:DAT:REANALYSIS_ERA5_SINGLE_LEVELS", # This is the ID of ERA5 dataset.

  "multiStringSelectValues": [
    {
      "name": "variable",
      "value": [
        "total_cloud_cover" # The name of our variable of interest.
      ]
    },
    {
      "name": "product_type",
      "value": [
        "reanalysis"
      ]
    },
    {
      "name": "year",
      "value": Years # We use the generated lists to fill these fields.
    },
    {
      "name": "month",
      "value": months
    },
    {
      "name": "day",
      "value": days
    },
    {
      "name": "time",
      "value": [
        "00:00",
        "01:00",
        "02:00",
        "03:00",
        "04:00",
        "05:00",
        "06:00",
        "07:00",
        "08:00",
        "09:00",
        "10:00",
        "11:00",
        "12:00",
        "13:00",
        "14:00",
        "15:00",
        "16:00",
        "17:00",
        "18:00",
        "19:00",
        "20:00",
        "21:00",
        "22:00",
        "23:00"
      ]
    }
  ],
  "stringChoiceValues": [
    {
      "name": "format",
      "value": "netcdf"
    }
  ]
}
Warning: To create the client instance, you must have a .hdarc file with your credentials in your home directory. If you don't have one, or want to make sure it is correct, run the following lines. You can also use them to update your credentials.
In [55]:
username='' # Fill in your WEkEO username
password='' # Fill in your WEkEO password
HOME=os.environ['HOME']
fileString=(f'url: https://wekeo-broker.apps.mercator.dpi.wekeo.eu/databroker \nuser: {username} \npassword: {password}')
file_path=os.path.join(HOME,'.hdarc')
if os.path.exists(file_path):
    os.remove(file_path)
with open(file_path,'w') as f:
    f.write(fileString)

Now we create an hda client instance and pass it the query. The next block will search for the data and download it; this might take a long time, as it only returns once all the data has been downloaded.

In [56]:
dir_name = workdir + '/data/ERA5' 
if os.path.isdir(dir_name):
    if not os.listdir(dir_name):
        c=hda.Client(debug=True)  # Creating the client
        matches=c.search(query) # Looking for data based on the query JSON
        cwd=os.getcwd()
        os.chdir(dir_name) # Download into the ERA5 data folder
        matches.download() # Actually making the download
        os.chdir(cwd) # Return to the original working directory
    else:    
        print("the ERA5 data has already been downloaded")
else:
    print("the given directory doesn't exist")
the ERA5 data has already been downloaded

Sentinel-3 OLCI download query¶

In [57]:
# The four scenes differ only in their acquisition window, so we build the
# queries with a small helper function.
def s3_query(start, end):
    return {
        "datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
        "boundingBoxValues": [{"name": "bbox","bbox": [0.8592794892475406,46.05897584821409,4.208065051226144,48.04894076163025]}],
        "dateRangeSelectValues": [{"name": "position", "start": start, "end": end}],
        "stringChoiceValues": [{"name": "productType","value": "LFR"},
                               {"name": "processingLevel","value": "LEVEL2"},
                               {"name": "timeliness","value": "Near+Real+Time"},
                               {"name": "orbitDirection","value": "descending"}]
    }

queries = [s3_query("2022-07-29T10:00:00.000Z", "2022-07-29T10:02:00.000Z"),
           s3_query("2022-06-05T10:00:00.000Z", "2022-06-05T10:02:00.000Z"),
           s3_query("2022-05-09T10:00:00.000Z", "2022-05-09T10:02:00.000Z"),
           s3_query("2022-04-12T09:59:00.000Z", "2022-04-12T10:02:00.000Z")]
In [59]:
c=hda.Client(debug=True)

dir_name = workdir + '/data/S3'
if os.path.isdir(dir_name):
    if not os.listdir(dir_name):
        for query in queries:
            matches = c.search(query)
            matches.download()
            for match in matches.results:
                fdst = match['filename']
                print(f"Found: {fdst}")        
                with zipfile.ZipFile(fdst, 'r') as zip_ref:
                    zip_ref.extractall(dir_name) # Extract into the S3 data folder
                    print(f'Unzipping of product {fdst} finished.')
                os.remove(fdst)
    else:    
        print("the S3 data has already been downloaded")
else:
    print("the given directory doesn't exist")
the S3 data has already been downloaded

Data exploration and pre-processing ¶

Importing libraries ¶

For this section we will use libraries to work with multidimensional data arrays stored in NetCDF4 format. We will also use some packages to plot them and create maps.

In [60]:
# To work with multidimensional arrays
import netCDF4
import xarray as xr
import numpy

# To create plots and maps
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = "jupyterlab"

import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib import animation
from matplotlib.axes import Axes

from shapely import geometry, vectorized
import eumartools

import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
from cartopy.mpl.geoaxes import GeoAxes
GeoAxes._pcolormesh_patched = Axes.pcolormesh

import holoviews as hv
from holoviews import opts, dim

import geoviews as gv
import geoviews.feature as gf
from cartopy import crs
gv.extension('bokeh')

from IPython.display import HTML
import coffee # Helper module with plotting utilities used in this notebook

import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

Additionally, we will make use of helper plotting functions to display multidimensional arrays on a map. Here we will plot the Sentinel-3 cloud mask, which is very useful for obtaining a high-resolution representation of clouds. First we define our region of interest (roi), followed by the selection of latitude and longitude.

In [33]:
roi =  [[-30.0, 30.0], [-30.0, 60.0], [30.0, 60.0], [30.0, 30.0], [-30.0, 30.0]]

# Read in the coordinate data and build a spatial mask
geo_fid = xr.open_dataset('data/S3/S3B_20220412/geo_coordinates.nc')
lat = geo_fid.get('latitude').data
lon = geo_fid.get('longitude').data
geo_fid.close()

# Now check the flag content for our polygon
flag_file = os.path.join('data/S3/S3B_20220412/lqsf.nc')
flag_variable = 'LQSF'
flags_to_use = ['CLOUD']
flag_mask = eumartools.flag_mask(flag_file, flag_variable, flags_to_use)

point_mask = vectorized.contains(geometry.Polygon(roi), lon,lat)
In [34]:
date = geo_fid.attrs['product_name'][16:24]
time = geo_fid.attrs['product_name'][25:29]; 

coffee.geoviz_s3(flag_mask, lat, lon, 'Clouds over France on: ' + date + ' ' + time)

Loading data ¶

Xarray extends numpy structures to handle labeled data. It makes it easy to work with coordinates, attributes and multiple dimensions, which makes it a powerful tool for gridded geographical data. For example, we can easily subset the worldwide data we downloaded to a specific AOI based on geographic coordinates. For our example the data is already cropped, but an example is given below of how to subset a non-cropped dataset.
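
As a quick illustration (hypothetical: the name global_tcc is assumed and does not appear in this notebook), label-based selection can crop a worldwide DataArray in one line. Note that ERA5 stores latitude in descending order, so the latitude slice runs from north to south:

# Hypothetical example of label-based cropping with xarray's sel():
# france_tcc = global_tcc.sel(latitude=slice(54, 40), longitude=slice(-5, 9))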

Tip: Sometimes the dataset may be too large and need to be downloaded in multiple parts. If that's the case, the following lines of code will open all the .nc files in a folder as an array and save them to a new file merged.nc
In [61]:
merged_file='data/ERA5/merged.nc'
if not os.path.exists(merged_file):
    # Open all downloaded parts lazily and concatenate them along time
    data = xr.open_mfdataset(os.path.join("data/ERA5/",'*.nc'), combine='nested', concat_dim="time")
    data.to_netcdf(merged_file)
else:
    print('Data already merged')
    data=xr.open_dataset(merged_file)
data=data.tcc
Data already merged

The dataset has a time depth of 38184 steps, one for every hour of the time range. Spatially it is a 55 × 55 grid of longitude and latitude cells over France.
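
As a quick sanity check (an addition to the original flow), the hourly depth should divide evenly into whole days, since we will later reshape the data into 24-frame daily sequences:

# Added sanity check: the hourly time axis should split into whole days
n_hours = data.sizes['time']
print(n_hours, 'hours =', n_hours // 24, 'full days') # 38184 hours = 1591 full days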

In [37]:
data
Out[37]:
<xarray.DataArray 'tcc' (time: 38184, latitude: 55, longitude: 55)>
dask.array<concatenate, shape=(38184, 55, 55), dtype=float32, chunksize=(7656, 55, 55), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 -4.75 -4.5 -4.25 -4.0 ... 8.0 8.25 8.5 8.75
  * latitude   (latitude) float32 53.75 53.5 53.25 53.0 ... 40.75 40.5 40.25
  * time       (time) datetime64[ns] 2017-01-01 ... 2021-11-29T23:00:00
Attributes:
    units:          (0 - 1)
    long_name:      Total cloud cover
    standard_name:  cloud_area_fraction

One of the main advantages of xarray is the ability to access and subset data based on spatial or temporal coordinates. The next block applies a spatial subset and creates a new array containing only information about our AOI. You can use these lines to crop a dataset to a given area; in our case it makes no difference, as the dataset is already cropped. To change the AOI, just adjust the longitude and latitude limits.

In [38]:
# Convert longitudes from the 0-360 convention to -180-180 and sort them
file_assigned = data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180)).sortby('longitude')
# Keep only the cells inside the AOI bounding box
data=file_assigned.where((file_assigned.latitude < 54) & (file_assigned.latitude > 40) & (file_assigned.longitude < 9) & (file_assigned.longitude > -5),drop=True)
In [39]:
data
Out[39]:
<xarray.DataArray 'tcc' (time: 38184, latitude: 55, longitude: 55)>
dask.array<where, shape=(38184, 55, 55), dtype=float32, chunksize=(7656, 55, 55), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 -4.75 -4.5 -4.25 -4.0 ... 8.0 8.25 8.5 8.75
  * latitude   (latitude) float32 53.75 53.5 53.25 53.0 ... 40.75 40.5 40.25
  * time       (time) datetime64[ns] 2017-01-01 ... 2021-11-29T23:00:00
Attributes:
    units:          (0 - 1)
    long_name:      Total cloud cover
    standard_name:  cloud_area_fraction

Plotting temporal data ¶

It is possible to use the dimensional nature of xarray to perform operations on the data. For example, using the temporal labels we can calculate the mean TCC over the AOI for each timestamp and analyze the resulting time series to identify patterns and other behaviors of interest. Since the time series is very long, interactive charts let us explore different time scales and move throughout the entire time window. You can check the create_timeseries function to see how it works: basically we pass the time information as the x axis and the mean value over our AOI as the y axis.
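
The create_timeseries helper lives in the accompanying coffee module and is not reproduced here. A minimal plotly-based equivalent (our assumption about its behaviour, not the actual implementation) might look like this:

import plotly.graph_objects as go

def create_timeseries_sketch(x, y):
    """Sketch: interactive mean-TCC line chart with a range slider."""
    fig = go.Figure(go.Scatter(x=x, y=y, mode='lines', name='mean TCC'))
    fig.update_layout(yaxis_title='Total cloud cover (0 - 1)',
                      xaxis_rangeslider_visible=True)
    fig.show()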

In [40]:
coffee.create_timeseries(x=list(data.time.values.astype('str')), y=list(data.mean(axis=(1,2)).values))

Plotting geographical data ¶

Using the visualize_pcolormesh function, we can select a specific moment of interest and plot it on a reference map to analyze the spatial distribution of clouds in that particular snapshot.

In [41]:
coffee.visualize_pcolormesh(data_array=data.sel(time='2017-01-01T02:00:00.000000000'),
                     longitude=data.longitude,
                     latitude=data.latitude,
                     projection=ccrs.PlateCarree(),
                     color_scale='Blues',
                     unit='(0 - 1)',
                     long_name= 'Total Cloud Cover ' + str(data.time.data[2]), # index 2 matches the 02:00 frame selected above
                     vmin=0, 
                     vmax=1,
                     lonmin=data.longitude.min().data,
                     lonmax=data.longitude.max().data,
                     latmin=data.latitude.min().data,
                     latmax=data.latitude.max().data,
                     set_global=False)
Out[41]:
(<Figure size 1440x720 with 2 Axes>,
 <GeoAxesSubplot:title={'center':'Total Cloud Cover 2017-01-01T02:00:00.000000000'}>)

Animating spatio-temporal data ¶

Going even further, we can use the temporal dimension of the data to create animations over temporal subsets. In this case the data can be thought of as a video where every hour is a new frame. Using the holoviews library we can browse through each frame. We choose the first week of the year; longer time periods can be selected, but they will demand more time and memory.

In [43]:
subset = data.sel(time=slice('2017-01-01T01:00:00.000000000', '2017-01-07T23:00:00.000000000'))
# Specify the dataset, its coordinates and requested variable 
dataset = gv.Dataset(subset.load(), ['longitude', 'latitude', 'time'], 'mag', crs=crs.PlateCarree())
images = dataset.to(gv.Image,dynamic=False)
# Slider location
hv.output(widget_location='bottom')
# Create stack of images grouped by time
images.opts(active_tools=['wheel_zoom', 'pan'], cmap="blues",colorbar=True, width=500, height=500, clim=(0.1,1.0))
Out[43]:

Now we want to transform the data in such a way that a deep learning model can handle it. For this we first convert the xarray into a numpy array without labels, in which only the TCC values are stored. This array has 3 dimensions: the first is temporal, in this case 38184 steps, each corresponding to a different hour; the remaining two correspond to the latitude-longitude grid.

In [40]:
tcc_array=data.to_numpy()
tcc_array.shape
Out[40]:
(38184, 55, 55)

This means that, as it is, we have a single sequence, or video, of n hours or frames. For our case study we transform it so that each day becomes one video consisting of 24 frames. For this we reshape the array, whose new dimensions are: number of days, 24 hours, and the two dimensions of the geographic grid. Additionally, a trailing dimension is added for the grayscale channel of the images.

In [41]:
samples=int(len(data.time)/24)
dataset=tcc_array.reshape([samples,24,55,55])
dataset = numpy.expand_dims(dataset, axis=-1)
dataset.shape
Out[41]:
(1591, 24, 55, 55, 1)

From the reshaped dataset we select data to train and evaluate the model, randomly assigning 90% of the sequences to training and 10% to validation.

In [42]:
# Split into train and validation sets using indexing to optimize memory.
indexes = numpy.arange(dataset.shape[0])
numpy.random.shuffle(indexes)
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]) :]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]

The model to be trained will predict the next frame of the video based on the previous one. If we consider $Y_t$ the frame to be predicted, we need a predictor $X_{t-1}$. This is defined in the create_shifted_frames function, which is then used to create $X_{train}$, $Y_{train}$, $X_{validation}$ and $Y_{validation}$ datasets.

In [43]:
def create_shifted_frames(data):
    x = data[:, 0 : data.shape[1] - 1, :, :]
    y = data[:, 1 : data.shape[1], :, :]
    return x, y
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)
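
For our 24-frame daily sequences this yields predictor and target arrays of 23 frames each, shifted by one hour. A quick shape check (the values follow from the 90/10 split above):

# x_train.shape -> (1431, 23, 55, 55, 1)   (1431 = int(0.9 * 1591) sequences)
# y_train.shape -> (1431, 23, 55, 55, 1)   (same frames shifted forward by one hour)
print(x_train.shape, y_train.shape)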

The following code block randomly selects one of the training samples and plots its 24 frames.

In [44]:
# Construct a figure on which we will visualize the images.
fig, axes = plt.subplots(4,6 , figsize=(12, 10))

# Plot each of the sequential images for one random data example.
data_choice = numpy.random.choice(range(len(train_dataset)), size=1)[0]
for idx, ax in enumerate(axes.flat):
    ax.imshow(numpy.squeeze(train_dataset[data_choice][idx]), cmap="GnBu")
    ax.set_title(f"Frame {idx + 1}")
    ax.axis("off")

# Print information and display the figure.
print(f"Displaying frames for example {data_choice}.")
plt.show()
Displaying frames for example 368.

Next-frame prediction ¶

Importing libraries ¶

In [62]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import  save_model, load_model

import io
import imageio
from IPython.display import  display
from ipywidgets import widgets, HBox

Defining the model ¶

The basis of the model for this exercise is the ConvLSTM2D architecture (Shi et al. 2015) in its Keras implementation (https://keras.io/examples/vision/conv_lstm/). This architecture is an extension of the long short-term memory (LSTM) networks commonly used to process data streams, as it retains their feedback connections. The main difference in ConvLSTM architectures is how operations are performed on the spatial domain of the data: through 2D convolutions instead of simple matrix operations. Additionally, batch normalization is used between layers to stabilize and speed up training. Finally, a Conv3D layer, which performs convolutions on volumes, is introduced to take into account the spatio-temporal nature of the outputs.

In [46]:
inp = layers.Input(shape=(None, *x_train.shape[2:]))

x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(5, 5),
    padding="same",
    return_sequences=True,
    activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(3, 3),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(1, 1),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.Conv3D(
    filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)


model = keras.models.Model(inp, x)
model.compile(loss=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.Adam(learning_rate=0.000001))
2022-09-04 10:28:02.360565: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

Training the model ¶

Before training the model, we define an early-stopping callback that ends training if the validation loss reaches a minimum and does not improve for 10 epochs. Training a deep learning model is always time and hardware intensive. We provide the lines needed to train a new model but, to let you skip this step, a pre-trained model is also supplied. The next block will train a model only if there is no model inside the model folder.

In [45]:
if os.path.exists('model/Model.h5'):
    print('Model already trained')
else:
    # Define some callbacks to improve training.
    early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)

    # Define modifiable training hyperparameters.
    epochs = 100
    batch_size = 1

    # Fit the model to the training data.
    model.fit(x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_val, y_val),
        callbacks=[early_stopping],
    )
Model already trained

After the training is complete, the model is saved to an h5 file, which stores weights, architecture and parameters in a single file.

In [46]:
if os.path.exists('model/Model.h5'):
    print('Model already saved')
else:
    save_model(model,'model/Model.h5') # Save to the same path that is checked above and loaded below
Model already saved

Making predictions ¶

The following code block randomly selects a sample from the validation dataset, predicts its 12 final frames one at a time, and compares the ground-truth frames with the predicted ones.

In [49]:
model = load_model('model/Model.h5', compile = True)
model.summary()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, None, 55, 55, 1)  0         
                             ]                                   
                                                                 
 conv_lstm2d_3 (ConvLSTM2D)  (None, None, 55, 55, 32)  105728    
                                                                 
 batch_normalization_2 (Batc  (None, None, 55, 55, 32)  128      
 hNormalization)                                                 
                                                                 
 conv_lstm2d_4 (ConvLSTM2D)  (None, None, 55, 55, 32)  73856     
                                                                 
 batch_normalization_3 (Batc  (None, None, 55, 55, 32)  128      
 hNormalization)                                                 
                                                                 
 conv_lstm2d_5 (ConvLSTM2D)  (None, None, 55, 55, 32)  8320      
                                                                 
 conv3d_1 (Conv3D)           (None, None, 55, 55, 1)   865       
                                                                 
=================================================================
Total params: 189,025
Trainable params: 188,897
Non-trainable params: 128
_________________________________________________________________
In [50]:
# Select a random example from the validation dataset.
examples = val_dataset[numpy.random.choice(range(len(val_dataset)), size=1)]

# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:
    # Split the example into the first 12 frames (input) and the last 12 (ground truth).
    frames = example[:12, ...]
    original_frames = example[12:, ...]
    new_predictions = numpy.zeros(shape=(12, *frames[0].shape))

    # Predict a new set of 12 frames, one step at a time.
    for i in range(12):
        # Extract the model's prediction and post-process it.
        frames = example[: 12 + i + 1, ...]
        new_prediction = model.predict(numpy.expand_dims(frames, axis=0))
        new_prediction = numpy.squeeze(new_prediction, axis=0)
        predicted_frame = numpy.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame

    # Create and save GIFs for each of the ground truth/prediction images.
    for frame_set in [original_frames, new_predictions]:
        # Construct a GIF from the selected video frames.
        current_frames = numpy.squeeze(frame_set)
        current_frames = current_frames[..., numpy.newaxis] * numpy.ones(3)
        current_frames = (current_frames * 255).astype(numpy.uint8)
        current_frames = list(current_frames)

        # Construct a GIF from the frames.
        with io.BytesIO() as gif:
            imageio.mimsave(gif, current_frames, "GIF", fps=5)
            predicted_videos.append(gif.getvalue())

# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):
    # Construct and display an `HBox` with the ground truth and prediction.
    box = HBox(
        [
            widgets.Image(value=predicted_videos[i],width=300,height=200),
            widgets.Image(value=predicted_videos[i + 1],width=300,height=200),
        ]
    )
    display(box)
1/1 [==============================] - 2s 2s/step
1/1 [==============================] - 1s 685ms/step
1/1 [==============================] - 0s 182ms/step
1/1 [==============================] - 0s 184ms/step
1/1 [==============================] - 0s 223ms/step
1/1 [==============================] - 0s 199ms/step
1/1 [==============================] - 0s 250ms/step
1/1 [==============================] - 0s 255ms/step
1/1 [==============================] - 0s 227ms/step
1/1 [==============================] - 0s 272ms/step
1/1 [==============================] - 0s 300ms/step
1/1 [==============================] - 0s 280ms/step
 Truth	Prediction
HBox(children=(Image(value=b'GIF89a7\x007\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

Finally, we use a date outside the original time range and make a prediction on it. You can download any day you want for this; we provide a sample date outside the training period.

In [49]:
validation_date=xr.open_dataset('data/ERA5/validation.nc')
validation_date=validation_date.tcc.to_numpy() # Take the TCC values of the newly opened day
nvalidation_date = numpy.expand_dims(validation_date, axis=-1)
Tip: Remember that you can change the duration of the training and the expected frames.
In [48]:
# Use the externally downloaded day as the single example.
examples = [nvalidation_date]

# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:
    # Split the example into the first 12 frames (input) and the last 12 (ground truth).
    frames = example[:12, ...]
    original_frames = example[12:, ...]
    new_predictions = numpy.zeros(shape=(12, *frames[0].shape))

    # Predict a new set of 12 frames, one step at a time.
    for i in range(12):
        # Extract the model's prediction and post-process it.
        frames = example[: 12 + i + 1, ...]
        new_prediction = model.predict(numpy.expand_dims(frames, axis=0))
        new_prediction = numpy.squeeze(new_prediction, axis=0)
        predicted_frame = numpy.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame

    # Create and save GIFs for each of the ground truth/prediction images.
    for frame_set in [original_frames, new_predictions]:
        # Construct a GIF from the selected video frames.
        current_frames = numpy.squeeze(frame_set)
        current_frames = current_frames[..., numpy.newaxis] * numpy.ones(3)
        current_frames = (current_frames * 255).astype(numpy.uint8)
        current_frames = list(current_frames)

        # Construct a GIF from the frames.
        with io.BytesIO() as gif:
            imageio.mimsave(gif, current_frames, "GIF", fps=5)
            predicted_videos.append(gif.getvalue())

# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):
    # Construct and display an `HBox` with the ground truth and prediction.
    box = HBox(
        [
            widgets.Image(value=predicted_videos[i],width=300,height=200),
            widgets.Image(value=predicted_videos[i + 1],width=300,height=200),
        ]
    )
    display(box)
1/1 [==============================] - 0s 120ms/step
1/1 [==============================] - 0s 133ms/step
1/1 [==============================] - 0s 97ms/step
1/1 [==============================] - 0s 125ms/step
1/1 [==============================] - 0s 150ms/step
1/1 [==============================] - 0s 134ms/step
1/1 [==============================] - 0s 157ms/step
1/1 [==============================] - 0s 127ms/step
1/1 [==============================] - 0s 142ms/step
1/1 [==============================] - 0s 158ms/step
1/1 [==============================] - 0s 148ms/step
1/1 [==============================] - 0s 159ms/step
 Truth	Prediction
HBox(children=(Image(value=b'GIF89a7\x007\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

References¶

  • Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. C. (2015). Convolutional LSTM network: A machine learning approach for precipitation nowcasting. Advances in neural information processing systems, 28.
  • https://keras.io/examples/vision/conv_lstm/

Now it's your turn ¶

With the tools we have given you, you can now start exploring and expanding. Here we propose some challenges that can be tackled with small changes to this notebook.

  1. Symbology and mapping - Change the maps and graphic layouts. If you want to go further, plot the predicted numpy arrays on maps, or even create animations rather than GIFs.
  2. Prediction - your own date: Download a particular day you are interested in and use the trained model to make predictions.
  3. Training - Pick a different date, or even a different variable, and train a new model.
  4. Training, advanced - Change model parameters, loss function, number of layers, number of filters, epochs, etc.