Authors: Jose Gomez, Oscar Narvaez, Max Aragon
Affiliation: Department of Geoinformatics, University of Salzburg
Copyright: © 2022
License: MIT
In this Jupyter Notebook you will learn how to query ERA5 and Sentinel-3 data from WEkEO, explore it, and pre-process it for predicting the spatiotemporal variability of total cloud cover over France.
ERA5 is a reanalysis dataset providing hourly estimates of a large number of atmospheric, land and oceanic climate variables. The data covers the Earth on a 30 km grid and resolves the atmosphere using 137 levels from the surface up to a height of 80 km. ERA5 includes information about uncertainties for all variables at reduced spatial and temporal resolutions.
From this dataset we use the Total Cloud Cover (TCC) estimation. This parameter is the proportion of a grid box covered by cloud. Total cloud cover is a single level field calculated from the cloud occurring at different model levels through the atmosphere. Assumptions are made about the degree of overlap/randomness between clouds at different heights. Cloud fractions vary from 0 to 1.
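To see why the overlap assumption matters, consider a toy calculation (an illustration only, not ERA5's actual overlap scheme) with two model levels that each cover half of the grid box:
# Toy illustration of cloud-overlap assumptions (not ERA5's actual scheme).
low_cloud, high_cloud = 0.5, 0.5  # fractional cover at two model levels
# Maximum overlap: the layers are assumed vertically aligned.
tcc_maximum = max(low_cloud, high_cloud)             # -> 0.5
# Random overlap: the layers are assumed statistically independent.
tcc_random = 1 - (1 - low_cloud) * (1 - high_cloud)  # -> 0.75
print(tcc_maximum, tcc_random)
The same column fractions thus yield a total cover anywhere between 0.5 and 0.75 depending on the assumed overlap.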
In addition, we use the Sentinel-3 Ocean and Land Colour Instrument (OLCI) cloud-mask product to compare against the TCC values. OLCI has a resolution of 300 m, roughly 100 times finer than the ERA5 grid, which allows us to evaluate the performance of the reanalysis against finer observations.
from IPython.display import YouTubeVideo
YouTubeVideo('FAGobvUGl24', width=800, height=450)
We will use the Harmonized Data Access (HDA) service to search for and download our data. For this, WEkEO provides a Python client for the API which makes the process easier. Next we will import that client, define the query JSON and send a request to download ERA5 TCC data. Currently this data is provided worldwide, so it must be cropped to a specific area of interest (AOI).
import hda                     # WEkEO's Python client for the HDA API
import os                      # basic folder operations
import zipfile                 # to unzip downloaded files
from datetime import datetime  # Python datetime library
import warnings
warnings.filterwarnings('ignore')
Next we need to define a JSON: a dictionary of data used to communicate with the API. This JSON contains information such as the ERA5 collection ID, our variable of interest and the time range we are interested in. In the following block we generate lists of years, months and days to facilitate handling. For this example we will use hourly data from 2017 to 2021, including all months of the year. We will also define the directory into which we want the data to be downloaded.
# Sequences for years, months and days. Note that range() excludes its end
# value, so the stop argument must be one past the last desired number.
Years = list(map(str, range(2017, 2022)))  # 2017-2021
months = list(map(str, range(1, 13)))      # 1-12
days = list(map(str, range(1, 32)))        # 1-31 (range(1, 31) would drop the 31st)
# Working directory (download target)
workdir = '.'
Years #How the years sequence looks like
['2017', '2018', '2019', '2020', '2021']
query = {
    "datasetId": "EO:ECMWF:DAT:REANALYSIS_ERA5_SINGLE_LEVELS",  # the ID of the ERA5 dataset
    "multiStringSelectValues": [
        {
            "name": "variable",
            "value": ["total_cloud_cover"]  # the name of our variable of interest
        },
        {
            "name": "product_type",
            "value": ["reanalysis"]
        },
        {
            "name": "year",
            "value": Years  # we use the generated lists to fill these fields
        },
        {
            "name": "month",
            "value": months
        },
        {
            "name": "day",
            "value": days
        },
        {
            "name": "time",
            "value": [f"{h:02d}:00" for h in range(24)]  # "00:00" ... "23:00"
        }
    ],
    "stringChoiceValues": [
        {
            "name": "format",
            "value": "netcdf"
        }
    ]
}
username=''
password=''
HOME=os.environ['HOME']
fileString=(f'url: https://wekeo-broker.apps.mercator.dpi.wekeo.eu/databroker \nuser: {username} \npassword: {password}')
file_path=os.path.join(HOME,'.hdarc')
if os.path.exists(file_path):
    os.remove(file_path)
with open(file_path, 'w') as f:
    f.write(fileString)
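As a quick sanity check, you can re-read the configuration file and print it with the password masked (a minimal sketch using only the file_path defined above):
# Re-read .hdarc and mask the password before printing.
with open(file_path) as f:
    for line in f:
        key, _, value = line.partition(':')
        if key.strip() == 'password':
            value = ' ********'
        print(f'{key.strip()}:{value.rstrip()}')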
Now we create an hda client instance and pass it the query. The next block searches for the data and downloads it to your working directory; this can take a long time, as it only finishes once all the data has been downloaded.
dir_name = workdir + '/data/ERA5'
if os.path.isdir(dir_name):
    if not os.listdir(dir_name):
        c = hda.Client(debug=True)  # creating the client
        matches = c.search(query)   # looking for data based on the query JSON
        os.chdir(workdir)
        matches.download()          # actually making the download
    else:
        print("the ERA5 data has already been downloaded")
else:
    print("the given directory doesn't exist")
the ERA5 data has already been downloaded
Next we define four similar queries for Sentinel-3 OLCI Level-2 land products over France, one per month from April to July 2022, each restricted to a roughly two-minute window around the morning overpass.
query_202207 = {
"datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
"boundingBoxValues": [{"name": "bbox","bbox": [0.8592794892475406,46.05897584821409,4.208065051226144,48.04894076163025]}],
"dateRangeSelectValues": [{"name": "position", "start": "2022-07-29T10:00:00.000Z","end": "2022-07-29T10:02:00.000Z"}],
"stringChoiceValues": [{"name": "productType","value": "LFR"},{"name": "processingLevel","value": "LEVEL2"},{"name": "timeliness","value": "Near+Real+Time"},{"name": "orbitDirection","value": "descending"}]
}
query_202206 = {
"datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
"boundingBoxValues": [{"name": "bbox","bbox": [0.8592794892475406,46.05897584821409,4.208065051226144,48.04894076163025]}],
"dateRangeSelectValues": [{"name": "position", "start": "2022-06-05T10:00:00.000Z","end": "2022-06-05T10:02:00.000Z"}],
"stringChoiceValues": [{"name": "productType","value": "LFR"},{"name": "processingLevel","value": "LEVEL2"},{"name": "timeliness","value": "Near+Real+Time"},{"name": "orbitDirection","value": "descending"}]
}
query_202205 = {
"datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
"boundingBoxValues": [{"name": "bbox","bbox": [0.8592794892475406,46.05897584821409,4.208065051226144,48.04894076163025]}],
"dateRangeSelectValues": [{"name": "position", "start": "2022-05-09T10:00:00.000Z","end": "2022-05-09T10:02:00.000Z"}],
"stringChoiceValues": [{"name": "productType","value": "LFR"},{"name": "processingLevel","value": "LEVEL2"},{"name": "timeliness","value": "Near+Real+Time"},{"name": "orbitDirection","value": "descending"}]
}
query_202204 = {
"datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
"boundingBoxValues": [{"name": "bbox","bbox": [0.8592794892475406,46.05897584821409,4.208065051226144,48.04894076163025]}],
"dateRangeSelectValues": [{"name": "position", "start": "2022-04-12T09:59:00.000Z","end": "2022-04-12T10:02:00.000Z"}],
"stringChoiceValues": [{"name": "productType","value": "LFR"},{"name": "processingLevel","value": "LEVEL2"},{"name": "timeliness","value": "Near+Real+Time"},{"name": "orbitDirection","value": "descending"}]
}
queries = [query_202207, query_202206, query_202205, query_202204]
c = hda.Client(debug=True)
dir_name = workdir + '/data/S3'
if os.path.isdir(dir_name):
    if not os.listdir(dir_name):
        for query in queries:
            matches = c.search(query)
            matches.download()
            for match in matches.results:
                fdst = match['filename']
                print(f"Found: {fdst}")
                with zipfile.ZipFile(fdst, 'r') as zip_ref:
                    zip_ref.extractall(dir_name)
                print(f'Unzipping of product {fdst} finished.')
                os.remove(fdst)
    else:
        print("the S3 data has already been downloaded")
else:
    print("the given directory doesn't exist")
the S3 data has already been downloaded
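The four Sentinel-3 queries above differ only in their time window. If you want to add more dates, a small builder function avoids the copy-pasting; this is a sketch with a hypothetical make_s3_query helper that reuses the same bounding box and product options:
def make_s3_query(start, end):
    # Build an HDA query for Sentinel-3 OLCI Level-2 land products over our AOI.
    return {
        "datasetId": "EO:ESA:DAT:SENTINEL-3:OL_2_LFR___",
        "boundingBoxValues": [{"name": "bbox",
                               "bbox": [0.8592794892475406, 46.05897584821409,
                                        4.208065051226144, 48.04894076163025]}],
        "dateRangeSelectValues": [{"name": "position", "start": start, "end": end}],
        "stringChoiceValues": [{"name": "productType", "value": "LFR"},
                               {"name": "processingLevel", "value": "LEVEL2"},
                               {"name": "timeliness", "value": "Near+Real+Time"},
                               {"name": "orbitDirection", "value": "descending"}],
    }

# Equivalent to query_202206 above:
q = make_s3_query("2022-06-05T10:00:00.000Z", "2022-06-05T10:02:00.000Z")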
In this section we will use libraries for working with multidimensional data arrays stored in NetCDF4 format, along with packages to plot them and create maps.
# To work with multidimensional arrays
import netCDF4
import xarray as xr
import numpy
# To create plots and maps
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = "jupyterlab"
import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib import animation
from matplotlib.axes import Axes
from shapely import geometry, vectorized
import eumartools
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
from cartopy.mpl.geoaxes import GeoAxes
GeoAxes._pcolormesh_patched = Axes.pcolormesh  # workaround for a known cartopy pcolormesh issue
import holoviews as hv
from holoviews import opts, dim
import geoviews as gv
import geoviews.feature as gf
from cartopy import crs
gv.extension('bokeh')
from IPython.display import HTML
import coffee  # helper module with the plotting functions used in this notebook
import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
Additionally, we will use the geoviz_s3 helper function to plot a multidimensional array on a map. Here we plot the Sentinel-3 cloud mask, which is very useful for obtaining a high-resolution representation of clouds. First we define our region of interest (roi), followed by the selection of latitude and longitude.
roi = [[-30.0, 30.0], [-30.0, 60.0], [30.0, 60.0], [30.0, 30.0], [-30.0, 30.0]]
# Read in the coordinate data and build a spatial mask
geo_fid = xr.open_dataset('data/S3/S3B_20220412/geo_coordinates.nc')
lat = geo_fid.get('latitude').data
lon = geo_fid.get('longitude').data
geo_fid.close()
# Now check the flag content for our polygon
flag_file = os.path.join('data/S3/S3B_20220412/lqsf.nc')
flag_variable = 'LQSF'
flags_to_use = ['CLOUD']
flag_mask = eumartools.flag_mask(flag_file, flag_variable, flags_to_use)
point_mask = vectorized.contains(geometry.Polygon(roi), lon, lat)
date = geo_fid.attrs['product_name'][16:24]
time = geo_fid.attrs['product_name'][25:29]
coffee.geoviz_s3(flag_mask, lat, lon, 'Clouds over France on: ' + date + ' ' + time)
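If you prefer a static map over the interactive coffee.geoviz_s3 view, a minimal matplotlib/cartopy sketch is possible (assuming flag_mask is a boolean array on the same grid as lat and lon; rendering a full OLCI granule this way can be slow):
# Static alternative to coffee.geoviz_s3.
fig = plt.figure(figsize=(10, 8))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle=':')
cloud = numpy.where(flag_mask, 1.0, numpy.nan)  # keep only flagged (cloudy) pixels
ax.pcolormesh(lon, lat, cloud, transform=ccrs.PlateCarree(), cmap='Greys_r', shading='auto')
ax.set_title('OLCI cloud flag on ' + date + ' ' + time)
plt.show()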
Xarray extends NumPy structures to handle labeled data. It makes it easy to work with coordinates, attributes and multiple dimensions, which makes it a powerful tool for gridded geographic data. For example, we can easily subset the worldwide data we downloaded to a specific AOI based on geographic coordinates. In our example the data is already cropped, but we show how a non-cropped dataset would be subset.
data = xr.open_mfdataset(os.path.join("data/ERA5/", '*nc'), combine='nested', concat_dim="time")
merged_file = 'data/ERA5/merged.nc'
if not os.path.exists(merged_file):
    data.to_netcdf(merged_file)
    data = data.tcc
else:
    print('Data already merged')
    data = xr.open_dataset(merged_file).tcc
Data already merged
The dataset has a time dimension of length 38184, one step for every hour in the time range. Spatially it consists of a 55 × 55 grid of longitude and latitude cells over France.
data
<xarray.DataArray 'tcc' (time: 38184, latitude: 55, longitude: 55)>
dask.array<concatenate, shape=(38184, 55, 55), dtype=float32, chunksize=(7656, 55, 55), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 -4.75 -4.5 -4.25 -4.0 ... 8.0 8.25 8.5 8.75
  * latitude   (latitude) float32 53.75 53.5 53.25 53.0 ... 40.75 40.5 40.25
  * time       (time) datetime64[ns] 2017-01-01 ... 2021-11-29T23:00:00
Attributes:
    units:          (0 - 1)
    long_name:      Total cloud cover
    standard_name:  cloud_area_fraction
One of the main advantages of xarray is the ability to access and subset data based on spatial or temporal coordinates. The next block performs a spatial subset, creating a new array that contains only our AOI. You can use these lines to crop a dataset to any given area; in our case it makes no difference, as the dataset is already cropped. To change the AOI, simply adjust the longitude and latitude limits.
file_assigned = data.assign_coords(longitude=(((data.longitude + 180) % 360) - 180)).sortby('longitude')
data = file_assigned.where((file_assigned.latitude < 54) & (file_assigned.latitude > 40) &
                           (file_assigned.longitude < 9) & (file_assigned.longitude > -5), drop=True)
data
<xarray.DataArray 'tcc' (time: 38184, latitude: 55, longitude: 55)>
dask.array<where, shape=(38184, 55, 55), dtype=float32, chunksize=(7656, 55, 55), chunktype=numpy.ndarray>
Coordinates:
  * longitude  (longitude) float32 -4.75 -4.5 -4.25 -4.0 ... 8.0 8.25 8.5 8.75
  * latitude   (latitude) float32 53.75 53.5 53.25 53.0 ... 40.75 40.5 40.25
  * time       (time) datetime64[ns] 2017-01-01 ... 2021-11-29T23:00:00
Attributes:
    units:          (0 - 1)
    long_name:      Total cloud cover
    standard_name:  cloud_area_fraction
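An equivalent and often more readable way to crop is label-based indexing with .sel (a sketch; note that the latitude coordinate is stored in descending order, so the slice runs from north to south):
# Label-based alternative to the .where() crop above.
subset_aoi = data.sel(latitude=slice(54, 40), longitude=slice(-5, 9))
print(subset_aoi.sizes)  # the same 55 x 55 grid as before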
It is possible to use the dimensional nature of xarray to perform operations on the data. For example, using the temporal labels we can calculate the mean TCC for each timestamp and analyze the corresponding time series to identify patterns and other behaviors of interest. Since the time series is very long, interactive charts allow us to explore different time scales and move through the entire time window. You can check the create_timeseries function to see how it works: we pass the time information as the x axis and the mean value over our AOI as the y axis.
coffee.create_timeseries(x=list(data.time.values.astype('str')), y=list(data.mean(axis=(1,2)).values))
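If you only need a quick static view of the same series, xarray's built-in plotting works too (a minimal sketch, restricted to one month to keep the line readable):
# Static alternative to coffee.create_timeseries: AOI-mean TCC per hour.
mean_tcc = data.mean(dim=['latitude', 'longitude'])
mean_tcc.sel(time=slice('2017-01-01', '2017-01-31')).plot(figsize=(12, 4))
plt.ylabel('Mean total cloud cover (0 - 1)')
plt.show()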
Using the visualize_pcolormesh function, we can select a specific moment of interest and plot it on a reference map to analyze the spatial distribution of clouds in that particular snapshot.
coffee.visualize_pcolormesh(data_array=data.sel(time='2017-01-01T01:00:00.000000000'),
                            longitude=data.longitude,
                            latitude=data.latitude,
                            projection=ccrs.PlateCarree(),
                            color_scale='Blues',
                            unit='(0 - 1)',
                            long_name='Total Cloud Cover ' + str(data.time.data[1]),
                            vmin=0,
                            vmax=1,
                            lonmin=data.longitude.min().data,
                            lonmax=data.longitude.max().data,
                            latmin=data.latitude.min().data,
                            latmax=data.latitude.max().data,
                            set_global=False)
(<Figure size 1440x720 with 2 Axes>, <GeoAxesSubplot:title={'center':'Total Cloud Cover 2017-01-01T01:00:00.000000000'}>)
Going even further, we can use the temporal dimension of the data to create animations over temporal subsets. In this case the data can be thought of as a video in which every hour is a new frame, and the HoloViews/GeoViews libraries let us step through the frames with a slider. We choose the first week of the year; longer periods can be selected, but they demand more time and memory.
subset = data.sel(time=slice('2017-01-01T01:00:00.000000000', '2017-01-07T23:00:00.000000000'))
# Specify the dataset, its coordinates and requested variable
dataset = gv.Dataset(subset.load(), ['longitude', 'latitude', 'time'], 'mag', crs=crs.PlateCarree())
images = dataset.to(gv.Image,dynamic=False)
# Slider location
hv.output(widget_location='bottom')
# Create stack of images grouped by time
images.opts(active_tools=['wheel_zoom', 'pan'], cmap="blues",colorbar=True, width=500, height=500, clim=(0.1,1.0))
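The same subset can also be rendered as a plain matplotlib animation using the animation module imported earlier (a sketch assuming the subset from the previous cell):
# Matplotlib alternative to the GeoViews slider.
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(subset.isel(time=0).values, cmap='Blues', vmin=0, vmax=1)
ax.set_title(str(subset.time.values[0]))

def update(t):
    # Swap in the data for frame t and refresh the title.
    im.set_array(subset.isel(time=t).values)
    ax.set_title(str(subset.time.values[t]))
    return [im]

anim = animation.FuncAnimation(fig, update, frames=len(subset.time), interval=200)
HTML(anim.to_jshtml())  # render inline in the notebook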
Now we want to transform the data so that a deep learning model can handle it. We first convert the xarray DataArray into an unlabeled NumPy array holding only the TCC values. This array has three dimensions: the first is temporal (38184 hourly steps), and the remaining two correspond to the latitude-longitude grid.
tcc_array=data.to_numpy()
tcc_array.shape
(38184, 55, 55)
This means that, as it stands, we have a single sequence, or video, of n hours (frames). For our case study we transform this so that each day becomes one video of 24 frames. To do so we reshape the array; the new dimensions are the number of dates, 24 hours, and the two geographic grid dimensions. Additionally, a trailing dimension is added for the single (grayscale) channel of the images.
samples=int(len(data.time)/24)
dataset=tcc_array.reshape([samples,24,55,55])
dataset = numpy.expand_dims(dataset, axis=-1)
dataset.shape
(1591, 24, 55, 55, 1)
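Since the reshape above raises an error if the record is not an exact number of days, a small sanity check is worthwhile before relying on the new shape:
# The hourly record must be an exact multiple of 24 for the day/hour reshape.
assert len(data.time) % 24 == 0, "time dimension is not a whole number of days"
print(f"{len(data.time)} hours -> {len(data.time) // 24} daily sequences of 24 frames")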
From the reshaped dataset, we select data to train and evaluate the model, randomly assigning 90% of the sequences to training and 10% to validation.
# Split into train and validation sets using indexing to optimize memory.
indexes = numpy.arange(dataset.shape[0])
numpy.random.shuffle(indexes)
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]) :]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]
The model to be trained will predict the next frame of the video based on the previous one. If we consider $Y_t$ the frame to be predicted, we need a predictor $X_{t-1}$. This is defined in the create_shifted_frames function, which is then used to create $X_{train}$, $Y_{train}$, $X_{validation}$ and $Y_{validation}$ datasets.
def create_shifted_frames(data):
    x = data[:, 0 : data.shape[1] - 1, :, :]
    y = data[:, 1 : data.shape[1], :, :]
    return x, y
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)
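A quick shape check confirms the shift: each predictor sequence x holds 23 frames of a day and each target y the same frames shifted one hour ahead (a minimal check):
# Verify the frame shift: x holds frames 0..22, y holds frames 1..23.
print("x_train:", x_train.shape)  # (samples, 23, 55, 55, 1)
print("y_train:", y_train.shape)
# The target at step t equals the input at step t + 1.
assert numpy.array_equal(x_train[:, 1:], y_train[:, :-1])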
The following code block randomly selects one of the training samples and plots its 24 frames.
# Construct a figure on which we will visualize the images.
fig, axes = plt.subplots(4, 6, figsize=(12, 10))
# Plot each of the sequential images for one random data example.
data_choice = numpy.random.choice(range(len(train_dataset)), size=1)[0]
for idx, ax in enumerate(axes.flat):
    ax.imshow(numpy.squeeze(train_dataset[data_choice][idx]), cmap="GnBu")
    ax.set_title(f"Frame {idx + 1}")
    ax.axis("off")
# Print information and display the figure.
print(f"Displaying frames for example {data_choice}.")
plt.show()
Displaying frames for example 368.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import save_model, load_model
import io
import imageio
from IPython.display import display
from ipywidgets import widgets, HBox
The basis of the model for this exercise is the ConvLSTM2D architecture (Shi et al. 2015) in its Keras implementation (https://keras.io/examples/vision/conv_lstm/). This architecture extends the long short-term memory (LSTM) networks commonly used to process sequential data, retaining their feedback connections. The main difference in ConvLSTM architectures is how operations are performed on the spatial domain of the data: through 2D convolutions instead of simple matrix multiplications. Additionally, batch normalization is used between layers to stabilize and speed up training. Finally, a Conv3D layer, which performs convolutions on volumes, is added to account for the spatio-temporal nature of the outputs.
inp = layers.Input(shape=(None, *x_train.shape[2:]))
x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(5, 5),
    padding="same",
    return_sequences=True,
    activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(3, 3),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
    filters=32,
    kernel_size=(1, 1),
    padding="same",
    return_sequences=True,
    activation="relu",
)(x)
x = layers.Conv3D(
    filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)
model = keras.models.Model(inp, x)
model.compile(loss=keras.losses.MeanSquaredError(),
              optimizer=keras.optimizers.Adam(learning_rate=0.000001))
2022-09-04 10:28:02.360565: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Before training the model, we define an early-stopping callback that ends training once the validation loss stops improving for 10 epochs. Training a deep learning model is always time- and hardware-intensive; we provide the lines needed to train a new model, but to let you skip this step a pre-trained model is also supplied. The next block trains a model only if no model exists inside the model folder.
if os.path.exists('model/Model.h5'):
    print('Model already trained')
else:
    # Define some callbacks to improve training.
    early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
    # Define modifiable training hyperparameters.
    epochs = 100
    batch_size = 1
    # Fit the model to the training data.
    model.fit(x_train,
              y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_val, y_val),
              callbacks=[early_stopping],
              )
Model already trained
After the training is complete, the model is saved to an h5 file that stores weights, architecture, and parameters in a single file.
if os.path.exists('model/Model.h5'):
    print('Model already saved')
else:
    save_model(model, 'model/Model.h5')
Model already saved
The following code block randomly selects a sample from the validation dataset, predicts its last 12 frames from the preceding ones, and compares the 12 final ground-truth frames with the predicted ones.
model = load_model('model/Model.h5', compile = True)
model.summary()
Model: "model_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, None, 55, 55, 1) 0 ] conv_lstm2d_3 (ConvLSTM2D) (None, None, 55, 55, 32) 105728 batch_normalization_2 (Batc (None, None, 55, 55, 32) 128 hNormalization) conv_lstm2d_4 (ConvLSTM2D) (None, None, 55, 55, 32) 73856 batch_normalization_3 (Batc (None, None, 55, 55, 32) 128 hNormalization) conv_lstm2d_5 (ConvLSTM2D) (None, None, 55, 55, 32) 8320 conv3d_1 (Conv3D) (None, None, 55, 55, 1) 865 ================================================================= Total params: 189,025 Trainable params: 188,897 Non-trainable params: 128 _________________________________________________________________
# Select a random example from the validation dataset.
examples = val_dataset[numpy.random.choice(range(len(val_dataset)), size=1)]

# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:
    # Pick the first/last twelve frames from the example.
    frames = example[:12, ...]
    original_frames = example[12:, ...]
    new_predictions = numpy.zeros(shape=(12, *frames[0].shape))

    # Predict a new set of 12 frames.
    for i in range(12):
        # Extract the model's prediction and post-process it.
        frames = example[: 12 + i + 1, ...]
        new_prediction = model.predict(numpy.expand_dims(frames, axis=0))
        new_prediction = numpy.squeeze(new_prediction, axis=0)
        predicted_frame = numpy.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame

    # Create and save GIFs for each of the ground-truth/prediction sequences.
    for frame_set in [original_frames, new_predictions]:
        # Convert the frames to 8-bit RGB.
        current_frames = numpy.squeeze(frame_set)
        current_frames = current_frames[..., numpy.newaxis] * numpy.ones(3)
        current_frames = (current_frames * 255).astype(numpy.uint8)
        current_frames = list(current_frames)

        # Construct a GIF from the frames.
        with io.BytesIO() as gif:
            imageio.mimsave(gif, current_frames, "GIF", fps=5)
            predicted_videos.append(gif.getvalue())

# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):
    # Construct and display an `HBox` with the ground truth and prediction.
    box = HBox(
        [
            widgets.Image(value=predicted_videos[i], width=300, height=200),
            widgets.Image(value=predicted_videos[i + 1], width=300, height=200),
        ]
    )
    display(box)
1/1 [==============================] - 2s 2s/step 1/1 [==============================] - 1s 685ms/step 1/1 [==============================] - 0s 182ms/step 1/1 [==============================] - 0s 184ms/step 1/1 [==============================] - 0s 223ms/step 1/1 [==============================] - 0s 199ms/step 1/1 [==============================] - 0s 250ms/step 1/1 [==============================] - 0s 255ms/step 1/1 [==============================] - 0s 227ms/step 1/1 [==============================] - 0s 272ms/step 1/1 [==============================] - 0s 300ms/step 1/1 [==============================] - 0s 280ms/step Truth Prediction
HBox(children=(Image(value=b'GIF89a7\x007\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…
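The GIFs give a qualitative impression; for a quantitative view you can score the predicted frames against the ground truth with the same mean squared error the model was trained on (a minimal sketch reusing original_frames and new_predictions from the loop above):
# Per-frame and overall MSE between ground truth and predictions.
truth = numpy.squeeze(original_frames)   # (12, 55, 55)
pred = numpy.squeeze(new_predictions)    # (12, 55, 55)
frame_mse = ((truth - pred) ** 2).mean(axis=(1, 2))
for h, err in enumerate(frame_mse, start=13):
    print(f"frame {h:2d}: MSE = {err:.4f}")
print(f"overall MSE = {frame_mse.mean():.4f}")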
Finally, we use a date outside the original time range and make a prediction on it. You can download any day you like for this; we provide a sample day outside the training period.
validation_date = xr.open_dataset('data/ERA5/validation.nc')
validation_date = validation_date.tcc.to_numpy()  # extract the TCC values for the sample day
nvalidation_date = numpy.expand_dims(validation_date, axis=-1)
# Use the out-of-sample day as a single example.
examples = [nvalidation_date]

# Iterate over the examples and predict the frames.
predicted_videos = []
for example in examples:
    # Pick the first/last twelve frames from the example.
    frames = example[:12, ...]
    original_frames = example[12:, ...]
    new_predictions = numpy.zeros(shape=(12, *frames[0].shape))

    # Predict a new set of 12 frames.
    for i in range(12):
        # Extract the model's prediction and post-process it.
        frames = example[: 12 + i + 1, ...]
        new_prediction = model.predict(numpy.expand_dims(frames, axis=0))
        new_prediction = numpy.squeeze(new_prediction, axis=0)
        predicted_frame = numpy.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame

    # Create and save GIFs for each of the ground-truth/prediction sequences.
    for frame_set in [original_frames, new_predictions]:
        # Convert the frames to 8-bit RGB.
        current_frames = numpy.squeeze(frame_set)
        current_frames = current_frames[..., numpy.newaxis] * numpy.ones(3)
        current_frames = (current_frames * 255).astype(numpy.uint8)
        current_frames = list(current_frames)

        # Construct a GIF from the frames.
        with io.BytesIO() as gif:
            imageio.mimsave(gif, current_frames, "GIF", fps=5)
            predicted_videos.append(gif.getvalue())

# Display the videos.
print(" Truth\tPrediction")
for i in range(0, len(predicted_videos), 2):
    # Construct and display an `HBox` with the ground truth and prediction.
    box = HBox(
        [
            widgets.Image(value=predicted_videos[i], width=300, height=200),
            widgets.Image(value=predicted_videos[i + 1], width=300, height=200),
        ]
    )
    display(box)
1/1 [==============================] - 0s 120ms/step 1/1 [==============================] - 0s 133ms/step 1/1 [==============================] - 0s 97ms/step 1/1 [==============================] - 0s 125ms/step 1/1 [==============================] - 0s 150ms/step 1/1 [==============================] - 0s 134ms/step 1/1 [==============================] - 0s 157ms/step 1/1 [==============================] - 0s 127ms/step 1/1 [==============================] - 0s 142ms/step 1/1 [==============================] - 0s 158ms/step 1/1 [==============================] - 0s 148ms/step 1/1 [==============================] - 0s 159ms/step Truth Prediction
HBox(children=(Image(value=b'GIF89a7\x007\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…
With the tools we have given you, you can now start exploring and expanding. Here we propose some challenges that can be tackled with a few changes to this notebook.