!module list
Currently Loaded Modulefiles:
1) openmpi/4.1.5 2) singularity 3) NCI-ai-ml/25.07 4) pbs
Walltime (hours): 1
Queue: gpuvolta
Compute Size: 1gpu
Project: <xy01>
Storage: gdata/dk92+scratch/<xy01>
Module directories: /g/data/dk92/apps/Modules/modulefiles
Modules: NCI-ai-ml/25.07
Jobfs size: 10GB
Copy the tested notebooks from any or all of the following paths to your own working directory. If your working directory is different from “/scratch/
/g/data/dk92/data/aardvark-weather/sample_data.
/g/data/dk92/data/aardvark-weather/training_data since the default data loader doesn’t read directly from the netCDF files in the training datasets. The transformed data requires approximately 200 GiB of disk space. If you need help with the data transformation, please get in touch through the NCI help desk.
/g/data/dk92/data/aardvark-weather/trained_model.
Aardvark Weather is an end-to-end data-driven model for weather prediction. Unlike most other data-driven forecasting models, it maps past observations directly to future forecasts. The architecture consists of three components: an encoder, a processor, and a decoder. The encoder ingests raw observations to estimate the gridded initial state of the atmosphere. The processor then advances the estimated atmospheric state in time. Finally, the decoder generates the prediction at the point of interest from the gridded state forecast.
Each component was initially trained independently. The encoder was trained on level 1B or 1C satellite data (ASCAT, AMSU-A & B, HIRS, IASI, GRIDSAT) and in-situ records (HadISD, ICOADS, IGRA). The processor was trained on the ERA5 dataset regridded to 1.5°, using surface variables (t2m, u10, v10, mslp) and pressure-level variables (q, z, t, u, v) at the 200, 500, 700, and 850 hPa levels. The decoder was trained on HadISD station data using near-surface temperature (t2m) and wind speed (ws). Once these blocks had been trained individually, they were chained in sequence and fine-tuned into the end-to-end model.
In the paper [1], the authors systematically benchmark Aardvark Weather and demonstrate its potential to replace the full numerical weather prediction (NWP) pipeline.
For global gridded forecasts, Aardvark achieves a latitude-weighted RMSE comparable to that of the Integrated Forecasting System (IFS) in its high-resolution configuration (HRES) from the European Centre for Medium-Range Weather Forecasts (ECMWF) and of the Global Forecast System (GFS) from the National Centers for Environmental Prediction, across variables including t2m, u10, v10, mslp, t850, u700, and q700 at lead times of up to 10 days. Spatially, the model successfully reproduces the large-scale features of the atmospheric state in both the mid-latitudes and the tropics.
For station forecasting, Aardvark delivers a global mean absolute error (MAE) on par with the station-corrected HRES for temperature predictions up to 10 days ahead and for wind speed up to 6 days ahead. Notably, in resource-limited areas such as West Africa and the Pacific, Aardvark consistently outperforms the station-corrected IFS-HRES across all lead times.
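The two headline metrics above can be illustrated with a small sketch. It assumes the common convention of weighting each grid row by the cosine of its latitude; the exact definition used in the paper is not reproduced here, and the function names and the synthetic fields are hypothetical. The grid matches the 1.5° resolution mentioned earlier (121 latitudes by 240 longitudes).

```python
import numpy as np


def latitude_weights(lat_deg):
    """Cosine-of-latitude weights, normalised so that their mean is 1."""
    w = np.cos(np.deg2rad(lat_deg))
    return w / w.mean()


def lat_weighted_rmse(forecast, truth, lat_deg):
    """RMSE over a (lat, lon) field, each row weighted by cos(latitude)."""
    w = latitude_weights(lat_deg)[:, None]  # broadcast across longitude
    return float(np.sqrt(np.mean(w * (forecast - truth) ** 2)))


def mae(forecast, truth):
    """Unweighted mean absolute error, as used for point (station) values."""
    return float(np.mean(np.abs(forecast - truth)))


# 1.5 degree global grid: 121 latitudes x 240 longitudes
lat = np.linspace(-90.0, 90.0, 121)
lon = np.arange(0.0, 360.0, 1.5)

# Synthetic fields (illustrative only): the forecast is offset by 0.5 everywhere
rng = np.random.default_rng(0)
truth = rng.normal(size=(lat.size, lon.size))
forecast = truth + 0.5

print(lat_weighted_rmse(forecast, truth, lat))
print(mae(forecast, truth))
```

Because the weights are normalised to a mean of 1, a spatially constant error gives the same value with or without weighting; the weighting only matters when errors differ between high and low latitudes.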
At NCI, we configured the NCI-ai-ml environment to support Aardvark Weather [2], enabling researchers to explore its datasets and build training and inference pipelines on Gadi. To assist users, four Jupyter notebooks have been developed, showcasing the satellite and synoptic observation datasets [3], the model training workflow, and the forecast evaluation procedures.
References:
1. Allen, A., Markou, S., Tebbutt, W. et al. (2025). End-to-end data-driven weather prediction. Nature, 641, 1172–1179. https://doi.org/10.1038/s41586-025-08897-0
2. Allen, A. (2025). aardvark-weather-public [Software]. GitHub. https://github.com/anna-allen/aardvark-weather-public
3. Aardvark Weather Observational Dataset, av555/aardvark-weather, Hugging Face, doi:10.57967/hf/4274, https://huggingface.co/datasets/av555/aardvark-weather
Notes from NCI: This notebook is adapted from the aardvark-weather-public repository so that NCI users can run it on Gadi. A visualisation of GRIDSAT has been added for completeness; it uses the netCDF file from the training datasets provided by the author in the Hugging Face repository av555/aardvark-weather.
This notebook provides an example of the data used to generate a forecast with Aardvark Weather. We explore a single time slice containing all the observations required to generate a forecast. This sample data is the output of the loader WeatherDatasetE2E in ../aardvark/loader.py.
Open a sample of data required to produce a forecast.
['sat_x_current',
'icoads_x_current',
'igra_x_current',
'amsua_x_current',
'amsub_x_current',
'iasi_x_current',
'ascat_x_current',
'hirs_x_current',
'era5_x_current']
Multiple different datasets are utilised as input to create a forecast, each with multiple channels including observations and metadata. Example channels for each of these are plotted below. The plot_channel variable in each cell can be adjusted to visualise different channels.
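The indexing pattern used in the plotting cells that follow can be sketched on a synthetic array. This is an illustration only: it uses NumPy instead of torch tensors, and it assumes a layout of (batch, longitude, latitude, channel) with hypothetical dimension sizes.

```python
import numpy as np

# Hypothetical stand-in for one gridded input: (batch, lon, lat, channel)
n_lon, n_lat, n_channels = 360, 180, 26
field = np.arange(n_lon * n_lat * n_channels, dtype=np.float32).reshape(
    1, n_lon, n_lat, n_channels
)

plot_channel = 11

# Index 0 selects the first batch element, the ellipsis keeps both spatial
# axes, and the last index selects a single channel.
channel_lonlat = field[0, ..., plot_channel]  # shape (360, 180)

# matplotlib's contourf(x, y, Z) expects Z with shape (len(y), len(x)),
# so a (lon, lat) slice is transposed to (lat, lon) before plotting.
channel_latlon = channel_lonlat.T  # shape (180, 360)

print(channel_lonlat.shape, channel_latlon.shape)
```

An input that is already stored as (latitude, longitude, channel) would not need the transpose, which is one reason the cells below do not all apply `.T`.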
First visualise the satellite data from HIRS, AMSU-A, AMSU-B, IASI and ASCAT
fig = plt.figure()
plot_channel = 11
p = plt.contourf(
lon,
lat,
data["assimilation"]["hirs_current"][0,...,plot_channel].cpu().T,
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"HIRS channel {plot_channel}")
plt.show()
fig = plt.figure()
plot_channel = 8
p = plt.contourf(
lon,
lat[:-1],
data["assimilation"]["amsua_current"][0,...,plot_channel].cpu(),
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"AMSU-A channel {plot_channel}")
plt.show()
fig = plt.figure()
plot_channel = 10
p = plt.contourf(
lon,
lat,
data["assimilation"]["amsub_current"][0,...,plot_channel].T.cpu(),
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"AMSU-B channel {plot_channel}")
plt.show()
fig = plt.figure()
plot_channel = 10
p = plt.contourf(
lon,
lat,
data["assimilation"]["iasi_current"][0,...,plot_channel].T.cpu(),
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"IASI channel {plot_channel}")
plt.show()
print(f"There are {data['assimilation']['ascat_current'].shape[-1]} channels in this dataset.\nPlotting channel 5.\nModify plot_channel to any integer between 0 and 16 to inspect other channels.")
fig = plt.figure()
plot_channel = 5
p = plt.contourf(
lon,
lat,
data["assimilation"]["ascat_current"][0,...,plot_channel].T.cpu(),
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"ASCAT channel {plot_channel}")
plt.show()
There are 17 channels in this dataset.
Plotting channel 5.
Modify plot_channel to any integer between 0 and 16 to inspect other channels.

Notes from NCI: Two channels of the GRIDSAT dataset are also used to train Aardvark Weather.
<xarray.Dataset> Size: 4GB
Dimensions: (time: 4746, latitude: 200, longitude: 514)
Coordinates:
* time (time) datetime64[ns] 38kB 2007-01-02 ... 2019-12-30
* longitude (longitude) float32 2kB 0.215 0.915 1.615 ... 358.8 359.5
* latitude (latitude) float32 800B -69.68 -68.99 -68.28 ... 68.92 69.61
Data variables:
gridsat_6p7 (time, latitude, longitude) float32 2GB ...
gridsat_10p3 (time, latitude, longitude) float32 2GB ...
lat_sat = nc.latitude.values
lon_sat = nc.longitude.values
data_sat = nc["gridsat_6p7"].values
print(f"There are 2 channels in this dataset.\nPlotting channel `gridsat_6p7`.\nModify the varname used in data_sat to inspect the other channel `gridsat_10p3`.")
fig = plt.figure()
p = plt.contourf(
lon_sat,
lat_sat,
data_sat[0,:],
levels=100,
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised radiance')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"GRIDSAT channel 6p7")
plt.show()
There are 2 channels in this dataset.
Plotting channel `gridsat_6p7`.
Modify the varname used in data_sat to inspect the other channel `gridsat_10p3`.

Next, we visualise the synoptic observations from land stations (HadISD), marine platforms (ICOADS) and radiosonde profiles (IGRA).
fig = plt.figure()
plot_channel = 0
p = plt.scatter(
data["assimilation"]["x_context_hadisd_current"][plot_channel][0,0,:].cpu(),
data["assimilation"]["x_context_hadisd_current"][plot_channel][0,1,:].cpu(),
c = data["assimilation"]["y_context_hadisd_current"][plot_channel][0,:].cpu(),
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised value')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"HadISD channel {plot_channel}")
plt.show()
plot_channel = 1
fig = plt.figure()
p = plt.scatter(
data["assimilation"]["icoads_x_current"][0][0,:].cpu(),
data["assimilation"]["icoads_x_current"][1][0,:].cpu(),
c = data["assimilation"]["icoads_current"][0,plot_channel,:].cpu(),
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised value')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"ICOADS channel {plot_channel}")
plt.show()
plot_channel = 1
fig = plt.figure()
p = plt.scatter(
data["assimilation"]["igra_x_current"][0][0,:].cpu(),
data["assimilation"]["igra_x_current"][1][0,:].cpu(),
c = data["assimilation"]["igra_current"][0,plot_channel,:].cpu(),
cmap="magma")
cbar = fig.colorbar(p)
cbar.set_label('Normalised value')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"IGRA channel {plot_channel}")
plt.show()
Notes from NCI: This notebook is adapted from the aardvark-weather-public repository so that NCI users can run it on Gadi. We fixed several issues in the source code so that it runs under torch 2.7.0.
Here we load the trained Aardvark Weather model and produce a global and station forecast using the sample data.
import sys
wdir = "/g/data/dk92/notebooks/examples-aiml/aardvark-weather/aardvark-weather-public"
ddir = "/g/data/dk92/data/aardvark-weather"
sys.path.append(f"{wdir}/aardvark")
import pickle
import numpy as np
from e2e_model import *
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams["mathtext.fontset"] = "stix"
matplotlib.rcParams["font.family"] = "STIXGeneral"
Check a GPU is available
GPU is available.
GPU name: Tesla V100-PCIE-32GB
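The cell that produced the output above is not shown. A minimal sketch of such a check, assuming only the standard `torch.cuda` calls, could look like the following; the helper name `pick_device` is hypothetical, and the fallback to CPU when torch is missing is added purely so the sketch runs anywhere.

```python
def pick_device():
    """Return ("cuda", gpu_name) when a GPU is usable, else ("cpu", None)."""
    try:
        import torch
    except ImportError:
        # torch is not installed: nothing to query, fall back to CPU
        return "cpu", None
    if torch.cuda.is_available():
        return "cuda", torch.cuda.get_device_name(0)
    return "cpu", None


device, gpu_name = pick_device()
if gpu_name is not None:
    print("GPU is available.")
    print(f"GPU name: {gpu_name}")
else:
    print("No GPU detected; falling back to CPU.")
```

The resulting `device` string is the kind of value passed as the `device` argument when the model is constructed below.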
Load the sample data (for a detailed analysis and visualisation of the contents of this dataset see data_demo.ipynb)
Load the model to generate predictions at one day lead time. First select which variable to generate station forecasts for.
local_forecast_var = "tas" # Model weights are included for wind speed (ws) and 2 m temperature (tas)
model = ConvCNPWeatherE2E(
device=device,
lead_time=1,
se_model_path=f"{ddir}/trained_model/encoder",
forecast_model_path=f"{ddir}/trained_model/processor",
sf_model_path=f"{ddir}/trained_model/decoder/{local_forecast_var}/",
return_gridded=True,
aux_data_path=f"{ddir}/sample_data/",
)
Run the model to generate a forecast for the sample data. This outputs the station forecast, the gridded forecast and the initial state.
First look at the gridded forecasts. Visualise several variables
Select the variable and data to plot. To plot another variable simply change the variable argument below.
Plot the initial state
fig = plt.figure(figsize=(10, 5))
plot_channel = 10
p = plt.contourf(
lon,
lat,
initial_state_var,
levels=100,
vmax=colorscale_mag,
vmin=-colorscale_mag,
cmap="RdBu_r",
)
cbar = fig.colorbar(p)
# cbar.set_label('(m/s)')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"{variable} initial state")
plt.show()
Plot the prediction at one day leadtime
fig = plt.figure(figsize=(10, 5))
plot_channel = 10
p = plt.contourf(
lon,
lat,
global_forecast_var,
levels=100,
vmax=colorscale_mag,
vmin=-colorscale_mag,
cmap="RdBu_r",
)
cbar = fig.colorbar(p)
# cbar.set_label('(m/s)')
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"{variable} forecast")
plt.show()
The model also returns the station forecasts for T2M
/jobfs/151106967.gadi-pbs/ipykernel_405486/3168040627.py:2: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
station_forecast.detach().cpu() * std + mean
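The expression in the warning above, `station_forecast.detach().cpu() * std + mean`, converts the model output from normalised units back to physical units. A minimal NumPy sketch of that transformation and its inverse follows; the statistics and forecast values are hypothetical, not the ones shipped with the model.

```python
import numpy as np


def normalise(values, mean, std):
    """Map physical values to zero-mean, unit-variance units."""
    return (values - mean) / std


def unnormalise(values, mean, std):
    """Invert normalise(): map model-space values back to physical units."""
    return values * std + mean


# Hypothetical normalisation statistics for 2 m temperature in degrees C
mean, std = 12.0, 9.0

# Hypothetical normalised station forecast with shape (batch, station)
normalised_forecast = np.array([[-1.0, 0.0, 0.5]])
station_forecast_unnorm = unnormalise(normalised_forecast, mean, std)
print(station_forecast_unnorm)
```

Plotting the un-normalised values is what makes a colour range such as -30 to 30 with a "(C)" label meaningful.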
fig = plt.figure(figsize=(10, 5))
plot_channel = 10
p = plt.scatter(
data["downscaling"]["x_target"][0, 0, :].detach().cpu() * STATION_LON_LAT_SF,
data["downscaling"]["x_target"][0, 1, :].detach().cpu() * STATION_LON_LAT_SF,
c=station_forecast[0, :],
vmax=30,
vmin=-30,
cmap="RdBu_r",
)
cbar = fig.colorbar(p)
cbar.set_label("(C)")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"{local_forecast_var}")
plt.show()
Notes from NCI: This notebook is adapted from the aardvark-weather-public repository so that NCI users can run it on Gadi. We fixed several issues in the source code so that it runs under torch 2.7.0. Note that this notebook does not demonstrate how to fine-tune the end-to-end model; it shows how to use the fine-tuned end-to-end model for prediction.
Here we load the trained Aardvark Weather model and produce end-to-end finetuned forecasts at one day lead time for temperature and windspeed.
import sys
wdir = "/g/data/dk92/notebooks/examples-aiml/aardvark-weather/aardvark-weather-public"
ddir = "/g/data/dk92/data/aardvark-weather"
sys.path.append(f"{wdir}/aardvark")
import numpy as np
import pickle
from e2e_model import *
from matplotlib import pyplot as plt
import matplotlib
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
Check a GPU is available
GPU is available.
GPU name: Tesla V100-PCIE-32GB
Load the sample data (for a detailed analysis and visualisation of the contents of this dataset see data_demo.ipynb)
Load the end-to-end model. First select which variable to generate station forecasts for.
local_forecast_var = "tas" # Model weights are included for wind speed (ws) and 2 m temperature (tas)
model = ConvCNPWeatherE2E(
device="cuda",
lead_time=1,
se_model_path=f"{ddir}/trained_model/encoder",
forecast_model_path=f"{ddir}/trained_model/processor",
sf_model_path=f"{ddir}/trained_model/decoder/{local_forecast_var}/",
return_gridded=True,
aux_data_path=f"{ddir}/sample_data/",
)
Load the trained weights
weights_path = f"{ddir}/trained_model/e2e_finetuned/{local_forecast_var}/"
# Pick the epoch with the lowest recorded loss
best_epoch = np.argmin(np.load(weights_path + "losses_0.npy"))
state_dict = torch.load(
    f"{weights_path}/epoch_{best_epoch}", map_location="cuda", weights_only=False
)["model_state_dict"]
# Drop the first 7 characters of every key (presumably a "module." prefix
# left by a parallel wrapper) so the keys match the unwrapped model
state_dict = {k[7:]: v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.to("cuda")
Plot the station forecasts
/jobfs/151106967.gadi-pbs/ipykernel_405681/444701172.py:2: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
station_forecast.detach().cpu() * std + mean
fig = plt.figure(figsize=(10, 5))
plot_channel = 10
p = plt.scatter(
data["downscaling"]["x_target"][0, 0, :].detach().cpu() * STATION_LON_LAT_SF,
data["downscaling"]["x_target"][0, 1, :].detach().cpu() * STATION_LON_LAT_SF,
c=station_forecast_unnorm[0, :],
vmax=30,
vmin=-30,
cmap="RdBu_r",
)
cbar = fig.colorbar(p)
cbar.set_label("(C)")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title(f"{local_forecast_var}")
plt.show()
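The weight-loading step in this notebook selects the epoch with the lowest loss and removes the first seven characters of every checkpoint key. The sketch below illustrates both ideas on plain Python data. It assumes the seven characters are a "module." prefix, which the source does not state, and the helper `strip_prefix` and the example values are hypothetical.

```python
import numpy as np


def strip_prefix(state_dict, prefix="module."):
    """Return a copy with `prefix` removed from the keys that start with it.

    Unlike slicing every key with k[7:], keys without the prefix are left
    unchanged, so the helper is safe on checkpoints saved without a wrapper.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical per-epoch losses; the best epoch is the index of the minimum
losses = np.array([0.31, 0.27, 0.29])
best_epoch = int(np.argmin(losses))

# Hypothetical checkpoint keys as they might look when saved from a wrapped model
checkpoint = {"module.encoder.weight": 1, "module.decoder.bias": 2}
cleaned = strip_prefix(checkpoint)

print(best_epoch, sorted(cleaned))
```

Checking for the prefix before removing it avoids silently truncating keys if a checkpoint was saved from an unwrapped model.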
This notebook is largely adapted from the script used to train the individual blocks of the Aardvark Weather model. We tailored it specifically to train the data assimilation block on the 2007 data from the training datasets released by the author in the Hugging Face repository av555/aardvark-weather.
Please note that this is only a demonstration of the protocol used to train the data assimilation block of the Aardvark Weather model.
To run this notebook, rewrite the training data in the mmap format required by the original data loader and pass its path to the WeatherDatasetAssimilation class in place of aardvark/datasets/. If you need help generating the mmap files, please get in touch.
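As a rough illustration of what a memory-mapped file is, the sketch below writes an array to disk with `numpy.memmap` and reads it back. It is not the layout expected by the Aardvark data loader: the file name, dtype and array shape are assumptions chosen for the example, and the real file names and shapes must match what the loader reads.

```python
import os
import tempfile

import numpy as np

# Hypothetical array standing in for one variable read from a netCDF file,
# laid out as (time, latitude, longitude)
shape = (4, 121, 240)
source = np.random.default_rng(0).normal(size=shape).astype(np.float32)

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "example_variable.mmap")  # hypothetical file name

# Write: allocate the file on disk, copy the data in, and flush it
out = np.memmap(path, dtype="float32", mode="w+", shape=shape)
out[:] = source
out.flush()
del out

# Read: a raw memmap file has no header, so dtype and shape must be supplied
view = np.memmap(path, dtype="float32", mode="r", shape=shape)
first_step = np.array(view[0])  # only this slice is copied into memory

print(first_step.shape, os.path.getsize(path))
```

Because only the requested slices are read from disk, a loader can index a multi-year dataset without holding it all in memory, which is presumably why the transformed data is stored uncompressed and takes considerable disk space.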
import os, sys
import torch
#from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
wdir="/g/data/dk92/notebooks/examples-aiml/aardvark-weather/aardvark-weather-public/aardvark"
sys.path.append(wdir)
from loader import *
from models import *
from loss_functions import WeightedRmseLoss
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.set_device(0)
# set up dataset
train_dataset = WeatherDatasetAssimilation(
device=device,
dpath="<replace with the path to the training data>",
hadisd_mode="train",
start_date="2007-01-02",
end_date="2007-12-31",
lead_time=0,
era5_mode="4u",
res=1,
var_start=0,
var_end=24,
diff=False,
two_frames=False,
)
print(len(train_dataset))
Loading IGRA
Loading AMSU-A
Loading AMSU-B
Loading ICOADS
Loading IASI
Loading GEO
Loading HADISD
Loading ASCAT
Loading ERA5
1451
/opt/conda/envs/mlenv/lib/python3.10/site-packages/torch/functional.py:554: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /home/conda/feedstock_root/build_artifacts/libtorch_1746251337391/work/aten/src/ATen/native/TensorShape.cpp:4314.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
# Dive into the data structure of one training sample.
task = train_dataset[0]
for k in task.keys():
    v = task[k]
    try:
        print(f"{k}: {v.shape}")
    except AttributeError:
        # Not a single tensor: coordinate pairs and HadISD entries are sequences of tensors
        if "_x_" in k:
            print(f"{k}: lat={v[1].shape}, lon={v[0].shape}")
        elif "hadisd" in k:
            print(f"{k}: ch1={v[0].shape}, ch2={v[1].shape}, ch3={v[2].shape}, ch4={v[3].shape}, ch5={v[4].shape}")
x_context_hadisd_current: ch1=torch.Size([2, 9549]), ch2=torch.Size([2, 9447]), ch3=torch.Size([2, 8846]), ch4=torch.Size([2, 9551]), ch5=torch.Size([2, 9551])
y_context_hadisd_current: ch1=torch.Size([9549]), ch2=torch.Size([9447]), ch3=torch.Size([8846]), ch4=torch.Size([9551]), ch5=torch.Size([9551])
climatology_current: torch.Size([24, 240, 121])
sat_x_current: lat=torch.Size([200]), lon=torch.Size([514])
sat_current: torch.Size([2, 514, 200])
icoads_x_current: lat=torch.Size([12000]), lon=torch.Size([12000])
icoads_current: torch.Size([5, 12000])
igra_x_current: lat=torch.Size([1375]), lon=torch.Size([1375])
igra_current: torch.Size([24, 1375])
amsua_current: torch.Size([180, 360, 13])
amsua_x_current: lat=torch.Size([180]), lon=torch.Size([360])
amsub_current: torch.Size([360, 180, 12])
amsub_x_current: lat=torch.Size([180]), lon=torch.Size([360])
iasi_current: torch.Size([360, 181, 52])
iasi_x_current: lat=torch.Size([181]), lon=torch.Size([360])
ascat_current: torch.Size([360, 181, 17])
ascat_x_current: lat=torch.Size([181]), lon=torch.Size([360])
hirs_current: torch.Size([360, 180, 26])
hirs_x_current: lat=torch.Size([180]), lon=torch.Size([360])
y_target_current: torch.Size([121, 240, 24])
era5_x_current: lat=torch.Size([121]), lon=torch.Size([240])
era5_elev_current: torch.Size([7, 121, 240])
era5_lonlat_current: torch.Size([2, 240, 121])
aux_time_current: torch.Size([5])
lt: torch.Size([1])
y_target: torch.Size([121, 240, 24])
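The listing above mixes single tensors with sequences of tensors. A more general way to inspect such a nested sample is a small recursive helper, sketched here with NumPy arrays standing in for torch tensors; the helper name `describe` and the toy `task` dictionary are hypothetical, with shapes borrowed from the output above.

```python
import numpy as np


def describe(value):
    """Return a printable shape summary for an array or a nested sequence of arrays."""
    if hasattr(value, "shape"):
        return str(tuple(value.shape))
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(describe(item) for item in value) + "]"
    return type(value).__name__


# Toy sample: one gridded input, its (lon, lat) coordinates, and two
# station channels with different numbers of stations
task = {
    "amsua_current": np.zeros((180, 360, 13)),
    "amsua_x_current": [np.zeros(360), np.zeros(180)],
    "y_context_hadisd_current": [np.zeros(9549), np.zeros(9447)],
}

for key, value in task.items():
    print(f"{key}: {describe(value)}")
```

The varying lengths within the HadISD entries are the reason they are stored as a list of tensors, one per channel, since each channel can have a different number of reporting stations.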
/g/data/dk92/notebooks/examples-aiml/aardvark-weather/aardvark-weather-public/aardvark/loader.py:488: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /home/conda/feedstock_root/build_artifacts/libtorch_1746251337391/work/torch/csrc/utils/tensor_numpy.cpp:203.)
return torch.from_numpy(arr).float().to(self.device)
1451
# training loop
# model, train_loader, lf (loss function) and opt (optimiser) are defined in cells not shown here
n_epochs = 3
for epoch in range(n_epochs):
    acc_loss = 0
    model.train()
    prev_step = None
    for ii, task in enumerate(train_loader):
        out = model(task, film_index=0)
        loss = lf(task["y_target"], out, prev_step, fix_sigma=False)
        loss.backward()
        prev_step = out
        opt.step()
        opt.zero_grad()
        acc_loss += loss.item()
        # print(f"epoch {epoch} step {ii}: loss={loss}")
        # if ii >= 8:
        #     break
    print(f"epoch {epoch}: average training loss={acc_loss/len(train_loader)}")
epoch 0: average training loss=0.07086729208491409
epoch 1: average training loss=0.058982408160714754
epoch 2: average training loss=0.05863337186925745