
Curvigrids with xugrid from structured2d #1145


Merged
merged 12 commits into from
Mar 5, 2025
341 changes: 160 additions & 181 deletions dfm_tools/xugrid_helpers.py
@@ -279,220 +279,197 @@ def open_partitioned_dataset(file_nc:str, decode_fillvals:bool = False, remove_e


def open_dataset_curvilinear(file_nc,
varn_lon='longitude',
varn_lat='latitude',
varn_vert_lon='vertices_longitude', #'grid_x'
varn_vert_lat='vertices_latitude', #'grid_y'
ij_dims=['i','j'], #['N','M']
convert_360to180=False,
**kwargs):
x_dim:str,
y_dim:str,
x_bounds:str,
y_bounds:str,
convert_360to180:bool = False,
**kwargs) -> xu.UgridDataset:
"""
This is a first version of a function that creates a xugrid UgridDataset from a curvilinear dataset like CMCC. Curvilinear means in this case 2D lat/lon variables and i/j indexing. The CMCC dataset does contain vertices, which is essential for conversion to ugrid.
It also works for WAQUA files that are converted with getdata
Construct a UgridDataset from a curvilinear grid with 2D lat/lon variables
with i/j indices/dims, including vertices variables. Works for curvilinear
datasets like CMCC and also for WAQUA files that are converted with getdata.

Parameters
----------
file_nc : str or path
Path to the netcdf file(s), passed on to xr.open_mfdataset().
x_dim : str
The x-dimension, like lon, i or N.
y_dim : str
The y-dimension, like lat, j or M.
x_bounds : str
The variable with the x-bounds, like vertices_longitude or grid_x.
y_bounds : str
The variable with the y-bounds, like vertices_latitude or grid_y.
convert_360to180 : bool, optional
Whether to convert from a 0 to 360 degree global model to a -180 to 180
degree global model. The default is False.
**kwargs : dict, optional
Additional arguments, passed on to xr.open_mfdataset().

Returns
-------
uds : xu.UgridDataset
The resulting ugrid dataset.

"""
# TODO: maybe get varn_lon/varn_lat automatically with cf-xarray (https://github.com/xarray-contrib/cf-xarray)

if 'chunks' not in kwargs:
kwargs['chunks'] = {'time':1}

# data_vars='minimal' to avoid time dimension on vertices_latitude and others when opening multiple files at once
ds = xr.open_mfdataset(file_nc, data_vars="minimal", **kwargs)

print('>> getting vertices from ds: ',end='')
# data_vars='minimal' to avoid time dimension on vertices_latitude and
# others when opening multiple files at once
print('>> open_mfdataset: ',end='')
dtstart = dt.datetime.now()
vertices_longitude = ds.variables[varn_vert_lon].to_numpy()
vertices_longitude = vertices_longitude.reshape(-1,vertices_longitude.shape[-1])
vertices_latitude = ds.variables[varn_vert_lat].to_numpy()
vertices_latitude = vertices_latitude.reshape(-1,vertices_latitude.shape[-1])
ds = xr.open_mfdataset(file_nc, data_vars="minimal", **kwargs)
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

# convert from 0to360 to -180 to 180
if convert_360to180:
vertices_longitude = (vertices_longitude+180) % 360 - 180

# face_xy = np.stack([longitude,latitude],axis=-1)
# face_coords_x, face_coords_y = face_xy.T
#a,b = np.unique(face_xy,axis=0,return_index=True) #TODO: there are non_unique face_xy values, inconvenient
face_xy_vertices = np.stack([vertices_longitude,vertices_latitude],axis=-1)
face_xy_vertices_flat = face_xy_vertices.reshape(-1,2)
uniq,inv = np.unique(face_xy_vertices_flat, axis=0, return_inverse=True)
#len(uniq) = 104926 >> amount of unique node coords
#uniq.max() = 359.9654541015625 >> node_coords_xy
#len(inv) = 422816 >> is length of face_xy_vertices.reshape(-1,2)
#inv.max() = 104925 >> node numbers
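# tiny illustration of return_inverse with made-up values:
# np.unique([[0,0],[1,1],[0,0]], axis=0, return_inverse=True)
# -> uniq=[[0,0],[1,1]], inv=[0,1,0]: inv renumbers each row to its unique-node id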
node_coords_x, node_coords_y = uniq.T

face_node_connectivity = inv.reshape(face_xy_vertices.shape[:2]) #fnc.max() = 104925

#remove all faces that have only 1 unique node (does not result in a valid grid) #TODO: not used yet except for print
fnc_all_duplicates = (face_node_connectivity.T==face_node_connectivity[:,0]).all(axis=0)

#create bool of cells with duplicate nodes (some have 1 unique node, some 3, all these are dropped) #TODO: support also triangles?
fnc_closed = np.c_[face_node_connectivity,face_node_connectivity[:,0]]
fnc_has_duplicates = (np.diff(fnc_closed,axis=1)==0).any(axis=1)

#only keep cells that have 4 unique nodes
bool_combined = ~fnc_has_duplicates
print(f'WARNING: dropping {fnc_has_duplicates.sum()} faces with duplicate nodes ({fnc_all_duplicates.sum()} with one unique node)')
face_node_connectivity = face_node_connectivity[bool_combined]

grid = xu.Ugrid2d(node_x=node_coords_x,
node_y=node_coords_y,
face_node_connectivity=face_node_connectivity,
fill_value=-1,
)

print('>> stacking ds i/j coordinates: ',end='') #fast
dtstart = dt.datetime.now()
face_dim = grid.face_dimension
# TODO: lev/time bnds are dropped, avoid this. maybe stack initial dataset since it would also simplify the rest of the function a bit
ds_stacked = ds.stack({face_dim:ij_dims}).sel({face_dim:bool_combined})
latlon_vars = [varn_lon, varn_lat, varn_vert_lon, varn_vert_lat]
ds_stacked = ds_stacked.drop_vars(ij_dims + latlon_vars + [face_dim])
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')
ds[x_bounds] = (ds[x_bounds]+180) % 360 - 180
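# e.g. 359.75 maps to -0.25 and 180.0 maps to -180.0; values already in
# [-180, 180) are unchanged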

print('>> init uds: ',end='') #long
topology = {"mesh2d":{"x":x_dim,
"y":y_dim,
"x_bounds":x_bounds,
"y_bounds":y_bounds,
},
}

print('>> convert to xugrid.UgridDataset: ',end='')
dtstart = dt.datetime.now()
uds = xu.UgridDataset(ds_stacked,grids=[grid])
uds = xu.UgridDataset.from_structured2d(ds, topology=topology)
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

# drop 0-area cells (relevant for CMCC global datasets)
bool_zero_cell_size = uds.grid.area==0
if bool_zero_cell_size.any():
print(f"WARNING: dropping {bool_zero_cell_size.sum()} 0-sized cells from dataset")
uds = uds.isel({uds.grid.face_dimension: ~bool_zero_cell_size})

#remove faces that link to node coordinates that are nan (occurs in waqua models)
bool_faces_wnannodes = np.isnan(uds.grid.face_node_coordinates[:,:,0]).any(axis=1)
if bool_faces_wnannodes.any():
print(f'>> drop {bool_faces_wnannodes.sum()} faces with nan nodecoordinates from uds: ',end='') #long
dtstart = dt.datetime.now()
uds = uds.sel({face_dim:~bool_faces_wnannodes})
print(f'{(dt.datetime.now()-dtstart).total_seconds():.2f} sec')

return uds
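A minimal usage sketch of the refactored signature (the file path below is hypothetical and the variable names follow CMCC conventions; adjust x_dim/y_dim and the bounds variables to the dataset at hand):

import dfm_tools as dfmt

uds = dfmt.open_dataset_curvilinear(
    file_nc="tos_Omon_CMCC-ESM2_ssp585_*.nc",  # hypothetical path
    x_dim="i",
    y_dim="j",
    x_bounds="vertices_longitude",
    y_bounds="vertices_latitude",
    convert_360to180=True,
)
print(uds.grid)  # the xu.Ugrid2d topology derived from the bounds variables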


def get_delft3d4_nanmask(x,y):
def delft3d4_get_nanmask(x,y):
# -999.999 in kivu and 0.0 in curvedbend, both in westernscheldt
bool_0 = (x==0) & (y==0)
bool_1 = (x==-999) & (y==-999)
bool_2 = (x==-999.999) & (y==-999.999)
bool_mask = bool_0 | bool_1 | bool_2
return bool_mask
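A quick spot check of the mask logic with made-up coordinate values:

import numpy as np

x = np.array([0.0, -999.0, -999.999, 5.2])
y = np.array([0.0, -999.0, -999.999, 3.1])
print(delft3d4_get_nanmask(x, y))  # [ True  True  True False]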


def open_dataset_delft3d4(file_nc, **kwargs):
def delft3d4_stack_shifted_coords(da):
shift = 1
np_stacked = np.stack([
da, #ll
da.shift(MC=shift), #lr
da.shift(MC=shift, NC=shift), #ur
da.shift(NC=shift), #ul
],axis=-1)
da_stacked = xr.DataArray(np_stacked, dims=("M","N","four"))
return da_stacked
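The shift pattern collects the four corners surrounding each cell center into a trailing "four" dimension, which is the bounds layout used for from_structured2d below. A small sketch with a hypothetical 3x3 corner array:

import numpy as np
import xarray as xr

xcor = xr.DataArray(np.arange(9.0).reshape(3, 3), dims=("MC", "NC"))
stacked = delft3d4_stack_shifted_coords(xcor)
print(stacked.dims, stacked.shape)  # ('M', 'N', 'four') (3, 3, 4)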


def delft3d4_convert_uv(ds):
# replace invalid values not with nan but with zero
# otherwise the spatial coverage is affected
mask_u1 = (ds.U1==-999) | (ds.U1==-999.999)
mask_v1 = (ds.V1==-999) | (ds.V1==-999.999)
u1_mn = ds.U1.where(~mask_u1, 0)
v1_mn = ds.V1.where(~mask_v1, 0)

# subtract 0.5 since padding=low, so each corner value is representative of
# the previous face according to that logic; method='nearest' might be better
# (or just rename the dims)
u1_mn_cen = u1_mn.interp(MC=u1_mn.MC-0.5, method='linear')
v1_mn_cen = v1_mn.interp(NC=v1_mn.NC-0.5, method='linear')
# rename corner dims to center dims since we shifted them with 0.5
u1_mn_cen = u1_mn_cen.rename({'MC':'M'})
v1_mn_cen = v1_mn_cen.rename({'NC':'N'})
# TODO: since padding=low, just renaming the dims might even be better?
# u1_mn_cen = u1_mn.rename({'MC':'M'})
# v1_mn_cen = v1_mn.rename({'NC':'N'})
# >> could also be done with `ds = ds.swap_dims({"M":"MC","N":"NC"})`

# create combined uv mask (have to rename dimensions)
mask_u1_mn = mask_u1.rename({'MC':'M'})
mask_v1_mn = mask_v1.rename({'NC':'N'})
mask_uv1_mn = mask_u1_mn & mask_v1_mn
# drop all actual missing cells
u1_mn_cen = u1_mn_cen.where(~mask_uv1_mn)
v1_mn_cen = v1_mn_cen.where(~mask_uv1_mn)

# drop to avoid creating large chunks; an alternative is to overwrite the vars
# with the MN-averaged vars, but that requires passing and updating attrs
ds = ds.drop_vars(['U1','V1'])

# compute ux/uy/umag/udir
# TODO: add attrs to variables
alfas_rad = np.deg2rad(ds.ALFAS)
vel_x = u1_mn_cen*np.cos(alfas_rad) - v1_mn_cen*np.sin(alfas_rad)
vel_y = u1_mn_cen*np.sin(alfas_rad) + v1_mn_cen*np.cos(alfas_rad)
ds['ux'] = vel_x
ds['uy'] = vel_y
ds['umag'] = np.sqrt(vel_x**2 + vel_y**2)
ds['udir'] = np.rad2deg(np.arctan2(vel_y, vel_x))%360
return ds
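The ALFAS rotation turns grid-aligned (m,n) velocities into east/north components. A spot check of the formulas with made-up values, on a grid rotated 90 degrees where a pure u flow should point due north:

import numpy as np

alfas_rad = np.deg2rad(90.0)
u1, v1 = 1.0, 0.0
vel_x = u1 * np.cos(alfas_rad) - v1 * np.sin(alfas_rad)
vel_y = u1 * np.sin(alfas_rad) + v1 * np.cos(alfas_rad)
print(round(vel_x, 9), round(vel_y, 9))  # 0.0 1.0
print(round(float(np.rad2deg(np.arctan2(vel_y, vel_x)) % 360), 6))  # 90.0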


def open_dataset_delft3d4(file_nc, **kwargs) -> xu.UgridDataset:
"""
Reads in a Delft3D4 netcdf outputfile (curvilinear/staggered) as a ugrid
(xugrid.UgridDataset) dataset. This is a Delft3D4 specific version of
dfmt.open_dataset_curvilinear().

To get Delft3D4 to write netCDF output instead of .dat files, add these
lines to your model settings file (.mdf):

- FlNcdf=#maphis#
- ncFormat=4

Parameters
----------
file_nc : str or path
Path to the Delft3D4 netcdf output file.
**kwargs : dict, optional
Additional arguments, passed on to xr.open_dataset().

Returns
-------
uds : xu.UgridDataset
The resulting ugrid dataset.

"""

if 'chunks' not in kwargs:
kwargs['chunks'] = {'time':1}

ds = xr.open_dataset(file_nc, **kwargs)

# rename the grid variable, since it might be confused with the uds.grid accessor
if 'grid' in ds.data_vars:
ds = ds.rename_vars({'grid':'grid_original'})

xcor_stacked = delft3d4_stack_shifted_coords(ds.XCOR)
ycor_stacked = delft3d4_stack_shifted_coords(ds.YCOR)
mask_xy = delft3d4_get_nanmask(xcor_stacked,ycor_stacked)
ds['xcor_stacked'] = xcor_stacked.where(~mask_xy)
ds['ycor_stacked'] = ycor_stacked.where(~mask_xy)

if ('U1' in ds.data_vars) and ('V1' in ds.data_vars):
#mask u and v separately with 0 to avoid high velocities (cannot be nan, since (nan+value)/2 = nan instead of value/2)
mask_u1 = (ds.U1==-999) | (ds.U1==-999.999)
mask_v1 = (ds.V1==-999) | (ds.V1==-999.999)
u1_mn = ds.U1.where(~mask_u1,0)
v1_mn = ds.V1.where(~mask_v1,0)

#create combined uv mask (have to rename dimensions)
mask_u1_mn = mask_u1.rename({'MC':'M'})
mask_v1_mn = mask_v1.rename({'NC':'N'})
mask_uv1_mn = mask_u1_mn & mask_v1_mn

#average U1/V1 values to M/N
u1_mn = (u1_mn + u1_mn.shift(MC=1))/2 #TODO: or MC=-1
u1_mn = u1_mn.rename({'MC':'M'})
u1_mn = u1_mn.where(~mask_uv1_mn,np.nan) #replace temporary zeros with nan
v1_mn = (v1_mn + v1_mn.shift(NC=1))/2 #TODO: or NC=-1
v1_mn = v1_mn.rename({'NC':'N'})
v1_mn = v1_mn.where(~mask_uv1_mn,np.nan) #replace temporary zeros with nan
ds = ds.drop_vars(['U1','V1']) #to avoid creating large chunks, alternative is to overwrite the vars with the MN-averaged vars, but it requires passing and updating of attrs

#compute ux/uy/umag/udir #TODO: add attrs to variables
alfas_rad = np.deg2rad(ds.ALFAS)
vel_x = u1_mn*np.cos(alfas_rad) - v1_mn*np.sin(alfas_rad)
vel_y = u1_mn*np.sin(alfas_rad) + v1_mn*np.cos(alfas_rad)
ds['ux'] = vel_x
ds['uy'] = vel_y
ds['umag'] = np.sqrt(vel_x**2 + vel_y**2)
ds['udir'] = np.rad2deg(np.arctan2(vel_y, vel_x))%360

mn_slice = slice(1,None)
ds = ds.isel(M=mn_slice,N=mn_slice) #cut off the first values of M/N (centers), since they are fillvalues; the center dims should be one smaller than MC/NC (corners)

#find and set nans in XCOR/YCOR arrays
mask_xy = get_delft3d4_nanmask(ds.XCOR,ds.YCOR) #-999.999 in kivu and 0.0 in curvedbend, both in westernscheldt
ds['XCOR'] = ds.XCOR.where(~mask_xy)
ds['YCOR'] = ds.YCOR.where(~mask_xy)

#convert to ugrid
node_coords_x = ds.XCOR.to_numpy().ravel()
node_coords_y = ds.YCOR.to_numpy().ravel()
xcor_shape = ds.XCOR.shape
xcor_nvals = xcor_shape[0] * xcor_shape[1]

#remove weird outlier values in kivu model
node_coords_x[node_coords_x<-1000] = np.nan
node_coords_y[node_coords_y<-1000] = np.nan

#find nodes with nan coords
if not (np.isnan(node_coords_x) == np.isnan(node_coords_y)).all():
raise Exception('node_coords_xy do not have nans in same location')
nan_nodes_bool = np.isnan(node_coords_x)
node_coords_x = node_coords_x[~nan_nodes_bool]
node_coords_y = node_coords_y[~nan_nodes_bool]

node_idx_square = -np.ones(xcor_nvals,dtype=int)
node_idx_nonans = np.arange((~nan_nodes_bool).sum())
node_idx_square[~nan_nodes_bool] = node_idx_nonans
node_idx = node_idx_square.reshape(xcor_shape)
face_node_connectivity = np.stack([node_idx[1:,:-1].ravel(), #ll
node_idx[1:,1:].ravel(), #lr
node_idx[:-1,1:].ravel(), #ur
node_idx[:-1,:-1].ravel(), #ul
],axis=1)

keep_faces_bool = (face_node_connectivity!=-1).sum(axis=1)==4

face_node_connectivity = face_node_connectivity[keep_faces_bool]

grid = xu.Ugrid2d(node_x=node_coords_x,
node_y=node_coords_y,
face_node_connectivity=face_node_connectivity,
fill_value=-1,
)

face_dim = grid.face_dimension
ds_stacked = ds.stack({face_dim:('M','N')}).sel({face_dim:keep_faces_bool})
ds_stacked = ds_stacked.drop_vars(['M','N','mesh2d_nFaces'])
uds = xu.UgridDataset(ds_stacked,grids=[grid])

uds = uds.drop_vars(['XCOR','YCOR','grid'])
uds = uds.drop_dims(['MC','NC']) #clean up dataset by dropping corner dims (also drops variables with U/V masks and U/V/C bedlevel)

# convert to xarray.Dataset to update/remove attrs
ds_temp = uds.ugrid.to_dataset()

# set vertical dimensions attr
# TODO: would be more convenient to do within xu.Ugrid2d(): https://github.com/Deltares/xugrid/issues/195#issuecomment-2111841390
grid_attrs = {"vertical_dimensions": ds.grid.attrs["vertical_dimensions"]}
ds_temp["mesh2d"] = ds_temp["mesh2d"].assign_attrs(grid_attrs)

# drop attrs pointing to the removed grid variable (topology is now in mesh2d)
# TODO: this is not possible on the xu.UgridDataset directly
for varn in ds_temp.data_vars:
if "grid" in ds_temp[varn].attrs.keys():
del ds_temp[varn].attrs["grid"]

uds = xu.UgridDataset(ds_temp)
ds = delft3d4_convert_uv(ds)

# TODO: consider using same dims for variables on cell corners and faces
# ds = ds.swap_dims({"M":"MC","N":"NC"})
topology = {"mesh2d":{"x":"M",
"y":"N",
"x_bounds":"xcor_stacked",
"y_bounds":"ycor_stacked",
}
}
uds = xu.UgridDataset.from_structured2d(ds, topology=topology)

return uds
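A usage sketch (the trim-file path is hypothetical; the model must have been run with the FlNcdf=#maphis# settings from the docstring):

import dfm_tools as dfmt

uds = dfmt.open_dataset_delft3d4("trim-mymodel.nc")  # hypothetical path
print(uds.grid.face_dimension)
uds["umag"].isel(time=-1).ugrid.plot()  # for a 3D model, also select a layer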


def uda_to_faces(uda : xu.UgridDataArray) -> xu.UgridDataArray:
"""
Interpolates a ugrid variable (xu.DataArray) with a node or edge dimension to the faces by averaging the 3/4 nodes/edges around each face.
Interpolates a ugrid variable (xu.UgridDataArray) with a node or edge dimension
to the faces by averaging the 3/4 nodes/edges around each face.

Parameters
----------
@@ -530,8 +507,9 @@ def uda_to_faces(uda : xu.UgridDataArray) -> xu.UgridDataArray:
print(f'provided uda/variable "{uda.name}" does not have a node or edge dimension, returning unchanged uda')
return uda

# rechunk to make sure the node/edge dimension is not chunked, otherwise we will
# get "PerformanceWarning: Slicing with an out-of-order index is generating 384539 times more chunks."
# rechunk to make sure the node/edge dimension is not chunked, otherwise we
# will get "PerformanceWarning: Slicing with an out-of-order index is
# generating 384539 times more chunks."
chunks = {dimn_notfaces:-1}
uda = uda.chunk(chunks)

@@ -551,7 +529,8 @@ def uda_to_faces(uda : xu.UgridDataArray) -> xu.UgridDataArray:
uda_face_allnodes = xu.UgridDataArray(uda_face_allnodes_ds,grid=grid)

# replace nonexistent nodes/edges with nan
uda_face_allnodes = uda_face_allnodes.where(indexer_validbool) #replace all values for fillvalue nodes/edges (-1) with nan
# replace all values for fillvalue nodes/edges (-1) with nan
uda_face_allnodes = uda_face_allnodes.where(indexer_validbool)
# average node/edge values per face
uda_face = uda_face_allnodes.mean(dim=reduce_dim,keep_attrs=True)
#update attrs from node/edge to face
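And a usage sketch for uda_to_faces (dataset path and variable name are hypothetical; assumes the function is exposed on the dfmt namespace):

import dfm_tools as dfmt

uds = dfmt.open_partitioned_dataset("mymodel_map.nc")  # hypothetical path
uda_node = uds["mesh2d_node_z"]  # a node-based variable; name is an assumption
uda_face = dfmt.uda_to_faces(uda_node)
print(uda_face.dims)  # the node dimension is replaced by the face dimension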