Source code for thetis.coupled_timeintegrator

"""
Time integrators for solving coupled 2D-3D system of equations.
"""
from .utility import *
from . import timeintegrator
from .log import *
from . import rungekutta
from abc import ABC, abstractmethod, abstractproperty
import numpy


class CoupledTimeIntegratorBase(timeintegrator.TimeIntegratorBase):
    """
    Common base for time integrators of the coupled 2D-3D system

    Keeps references to the solver, its options and its fields, owns the
    :attr:`timesteppers` container and gathers the helper routines that
    refresh fields after the prognostic variables have changed.
    """
    def __init__(self, solver):
        """
        :arg solver: :class:`.FlowSolver` object
        """
        # container filled later with the individual time integrators
        self.timesteppers = AttrDict()
        self.solver = solver
        self.options = solver.options
        self.fields = solver.fields

    def _update_3d_elevation(self):
        """Assigns the 2D elevation ``elev_2d`` to ``elev_domain_2d``"""
        fields = self.solver.fields
        with timed_stage('aux_elev_3d'):
            fields.elev_domain_2d.assign(fields.elev_2d)

    def _update_vertical_velocity(self):
        """Runs the vertical velocity solver (``w_solver``)"""
        with timed_stage('continuity_eq'):
            w_solver = self.solver.w_solver
            w_solver.solve()

    def _update_moving_mesh(self):
        """Updates mesh coordinates; no-op unless ALE moving mesh is enabled"""
        if not self.options.use_ale_moving_mesh:
            return
        with timed_stage('aux_mesh_ale'):
            self.solver.mesh_updater.update_mesh_coordinates()

    def _update_2d_coupling(self):
        """
        Runs the full 2D-3D velocity coupling sequence

        Removes the depth average from ``uv_3d``, refreshes the 2D coupling
        term and finally copies ``uv_2d`` to ``uv_dav_3d``, in this order.
        """
        coupling_steps = (
            self._remove_depth_average_from_uv_3d,
            self._update_2d_coupling_term,
            self._copy_uv_2d_to_3d,
        )
        with timed_stage('aux_uv_coupling'):
            for step in coupling_steps:
                step()

    def _remove_depth_average_from_uv_3d(self):
        """Solves for the depth averaged velocity and subtracts it from ``uv_3d``"""
        solver = self.solver
        with timed_stage('aux_uv_coupling'):
            # depth average of the 3D velocity, routed through the 2D field
            solver.uv_averager.solve()  # uv -> uv_dav_3d
            solver.extract_surf_dav_uv.solve()  # uv_dav_3d -> uv_dav_2d
            solver.copy_uv_dav_to_uv_dav_3d.solve()  # uv_dav_2d -> uv_dav_3d
            # what remains in uv_3d is the deviation from the depth average
            self.fields.uv_3d -= self.fields.uv_dav_3d

    def _copy_uv_2d_to_3d(self):
        """Copies uv_2d to uv_dav_3d"""
        with timed_stage('aux_uv_coupling'):
            copier = self.solver.copy_uv_to_uv_dav_3d
            copier.solve()

    def _update_2d_coupling_term(self):
        """
        Refreshes ``split_residual_2d`` used in the 2D-3D coupling

        The field is set to ``uv_dav_2d`` divided by the ``dt_const``
        attribute of the explicit momentum time integrator.
        """
        dt_const = self.timesteppers.mom_expl.dt_const
        with timed_stage('aux_uv_coupling'):
            # scaled depth averaged velocity acts as forcing of the 2D mom eq.
            self.fields.split_residual_2d.assign(self.fields.uv_dav_2d)
            self.fields.split_residual_2d /= dt_const

    def _update_baroclinicity(self):
        """Computes baroclinic head if the baroclinic formulation is enabled"""
        if not self.options.use_baroclinic_formulation:
            return
        compute_baroclinic_head(self.solver)

    def _update_turbulence(self, t):
        """
        Advances the turbulence closure; no-op if turbulence is disabled

        :arg t: simulation time
        """
        if not self.options.use_turbulence:
            return
        turbulence_model = self.solver.turbulence_model
        steppers = self.timesteppers
        with timed_stage('turbulence'):
            turbulence_model.preprocess()
            # NOTE psi must be solved first as it depends on tke
            for name in ('psi_impl', 'tke_impl'):
                if name in steppers:
                    steppers[name].advance(t)
            turbulence_model.postprocess()

    def _update_stabilization_params(self):
        """
        Updates stabilization related fields

        Currently solves the Smagorinsky viscosity when
        ``use_smagorinsky_viscosity`` is set.
        """
        with timed_stage('aux_stability'):
            if self.options.use_smagorinsky_viscosity:
                smagorinsky = self.solver.smagorinsky_diff_solver
                smagorinsky.solve()

    def _update_all_dependencies(self, t,
                                 do_2d_coupling=False,
                                 do_vert_diffusion=False,
                                 do_ale_update=False,
                                 do_stab_params=False,
                                 do_turbulence=False):
        """
        Default sequence for refreshing all dependent fields after a time step

        :arg t: simulation time
        :kwarg do_2d_coupling: also run the 2D-3D velocity coupling
        :kwarg do_vert_diffusion: also advance the implicit vertical
            diffusion integrators (only if implicit vertical diffusion is
            enabled in the options)
        :kwarg do_ale_update: also update the moving mesh
        :kwarg do_stab_params: also update stabilization parameters
        :kwarg do_turbulence: also update the turbulence fields
        """
        self._update_3d_elevation()
        if do_ale_update:
            self._update_moving_mesh()
        if do_2d_coupling:
            self._update_2d_coupling()
        self._update_vertical_velocity()
        self._update_baroclinicity()
        if do_turbulence:
            self._update_turbulence(t)

        implicit_vdiff = (do_vert_diffusion
                          and self.options.use_implicit_vertical_diffusion)
        if implicit_vdiff:
            steppers = self.timesteppers
            with timed_stage('impl_mom_vvisc'):
                steppers.mom_impl.advance(t)
            if self.options.solve_salinity:
                with timed_stage('impl_salt_vdiff'):
                    steppers.salt_impl.advance(t)
            if self.options.solve_temperature:
                with timed_stage('impl_temp_vdiff'):
                    steppers.temp_impl.advance(t)

        if do_stab_params:
            self._update_stabilization_params()
class CoupledTimeIntegrator(CoupledTimeIntegratorBase, ABC):
    """
    Base class of mode-split time integrators that use 2D, 3D and implicit 3D
    time integrators.

    Subclasses choose the schemes by overriding :attr:`integrator_2d`,
    :attr:`integrator_3d` and :attr:`integrator_vert_3d` with time
    integrator classes. The constructor instantiates one integrator per
    equation and stores them in :attr:`timesteppers`.
    """
    # NOTE property + abstractmethod replaces the deprecated abstractproperty
    @property
    @abstractmethod
    def integrator_2d(self):
        """time integrator for 2D equations"""

    @property
    @abstractmethod
    def integrator_3d(self):
        """time integrator for explicit 3D equations (momentum, tracers)"""

    @property
    @abstractmethod
    def integrator_vert_3d(self):
        """time integrator for implicit 3D equations (vertical diffusion)"""

    def __init__(self, solver):
        """
        :arg solver: :class:`.FlowSolver` object
        """
        super().__init__(solver)

        print_output('Coupled time integrator: {:}'.format(self.__class__.__name__))
        print_output(' 2D time integrator: {:}'.format(self.integrator_2d.__name__))
        print_output(' 3D time integrator: {:}'.format(self.integrator_3d.__name__))
        print_output(' 3D implicit time integrator: {:}'.format(self.integrator_vert_3d.__name__))
        # initial conditions are assigned lazily, see initialize()
        self._initialized = False

        self._create_integrators()
        # number of stages is taken from the 2D integrator
        self.n_stages = self.timesteppers.swe2d.n_stages

    def _get_vert_diffusivity_functions(self):
        """
        Assign vertical viscosity/diffusivity to implicit/explicit parts

        With implicit vertical diffusion enabled the total vertical
        viscosity and diffusivity go to the implicit entries and the
        explicit entries are ``None``; otherwise the roles are swapped.

        :returns: tuple ``(impl_v_visc, expl_v_visc, impl_v_diff, expl_v_diff)``
        """
        if self.options.use_implicit_vertical_diffusion:
            impl_v_visc = self.solver.tot_v_visc.get_sum()
            expl_v_visc = None
            impl_v_diff = self.solver.tot_v_diff.get_sum()
            expl_v_diff = None
        else:
            impl_v_visc = None
            expl_v_visc = self.solver.tot_v_visc.get_sum()
            impl_v_diff = None
            expl_v_diff = self.solver.tot_v_diff.get_sum()

        return impl_v_visc, expl_v_visc, impl_v_diff, expl_v_diff

    def _create_swe_integrator(self):
        """
        Create time integrator for 2D system

        The result is stored in ``timesteppers.swe2d``.
        """
        solver = self.solver
        # the coupling term always acts as a momentum source; a user
        # defined source is added on top of it when present
        momentum_source_2d = solver.fields.split_residual_2d
        if self.options.momentum_source_2d is not None:
            momentum_source_2d = solver.fields.split_residual_2d + self.options.momentum_source_2d
        fields = {
            'coriolis': self.options.coriolis_frequency,
            'momentum_source': momentum_source_2d,
            'volume_source': self.options.volume_source_2d,
            'atmospheric_pressure': self.options.atmospheric_pressure,
        }

        self.timesteppers.swe2d = self.integrator_2d(
            solver.equations.sw, self.fields.solution_2d,
            fields, solver.dt,
            self.options.timestepper_options.swe_options,
            solver.bnd_functions['shallow_water'])

    def _create_mom_integrator(self):
        """
        Create time integrator for 3D momentum equation

        Always creates ``timesteppers.mom_expl``. With implicit vertical
        diffusion enabled, ``timesteppers.mom_impl`` is created as well and
        the friction related fields are passed to it instead of the
        explicit integrator.
        """
        solver = self.solver
        impl_v_visc, expl_v_visc, _, _ = self._get_vert_diffusivity_functions()

        fields = {'eta': self.fields.elev_domain_2d.view_3d,  # FIXME rename elev
                  'int_pg': self.fields.get('int_pg_3d'),
                  'uv_depth_av': self.fields.get('uv_dav_3d'),
                  'w': self.fields.w_3d,
                  'w_mesh': self.fields.get('w_mesh_3d'),
                  'viscosity_v': expl_v_visc,
                  'viscosity_h': solver.tot_h_visc.get_sum(),
                  'source': self.options.momentum_source_3d,
                  'lax_friedrichs_velocity_scaling_factor': self.options.lax_friedrichs_velocity_scaling_factor,
                  'coriolis': self.fields.get('coriolis_3d'),
                  }
        friction_fields = {
            'linear_drag_coefficient': self.options.linear_drag_coefficient,
            'quadratic_drag_coefficient': self.options.quadratic_drag_coefficient,
            'wind_stress': self.fields.get('wind_stress_3d'),
            'bottom_roughness': self.options.bottom_roughness,
        }
        if not solver.options.use_implicit_vertical_diffusion:
            fields.update(friction_fields)
        self.timesteppers.mom_expl = self.integrator_3d(
            solver.equations.momentum, solver.fields.uv_3d, fields, solver.dt,
            self.options.timestepper_options.explicit_momentum_options,
            solver.bnd_functions['momentum'])
        if solver.options.use_implicit_vertical_diffusion:
            fields = {'viscosity_v': impl_v_visc}
            fields.update(friction_fields)
            self.timesteppers.mom_impl = self.integrator_vert_3d(
                solver.equations.vertmomentum, solver.fields.uv_3d, fields, solver.dt,
                self.options.timestepper_options.implicit_momentum_options,
                solver.bnd_functions['momentum'])

    def _create_salt_integrator(self):
        """
        Create time integrator for salinity equation

        Does nothing unless ``solve_salinity`` is set. Creates
        ``timesteppers.salt_expl`` and, with implicit vertical diffusion
        enabled, ``timesteppers.salt_impl``.
        """
        solver = self.solver
        _, _, impl_v_diff, expl_v_diff = self._get_vert_diffusivity_functions()

        if solver.options.solve_salinity:
            fields = {'elev_3d': self.fields.elev_domain_2d.view_3d,
                      'uv_3d': self.fields.uv_3d,
                      'uv_depth_av': self.fields.get('uv_dav_3d'),
                      'w': self.fields.w_3d,
                      'w_mesh': self.fields.get('w_mesh_3d'),
                      'diffusivity_h': solver.tot_h_diff.get_sum(),
                      'diffusivity_v': expl_v_diff,
                      'source': self.options.salinity_source_3d,
                      'lax_friedrichs_tracer_scaling_factor': self.options.lax_friedrichs_tracer_scaling_factor,
                      }
            self.timesteppers.salt_expl = self.integrator_3d(
                solver.equations.salt, solver.fields.salt_3d, fields, solver.dt,
                self.options.timestepper_options.explicit_tracer_options,
                solver.bnd_functions['salt'])
            if solver.options.use_implicit_vertical_diffusion:
                fields = {'elev_3d': self.fields.elev_domain_2d.view_3d,
                          'diffusivity_v': impl_v_diff,
                          }
                self.timesteppers.salt_impl = self.integrator_vert_3d(
                    solver.equations.salt_vdff, solver.fields.salt_3d, fields, solver.dt,
                    self.options.timestepper_options.implicit_tracer_options,
                    solver.bnd_functions['salt'])

    def _create_temp_integrator(self):
        """
        Create time integrator for temperature equation

        Does nothing unless ``solve_temperature`` is set. Creates
        ``timesteppers.temp_expl`` and, with implicit vertical diffusion
        enabled, ``timesteppers.temp_impl``.
        """
        solver = self.solver
        _, _, impl_v_diff, expl_v_diff = self._get_vert_diffusivity_functions()

        if solver.options.solve_temperature:
            fields = {'elev_3d': self.fields.elev_domain_2d.view_3d,
                      'uv_3d': self.fields.uv_3d,
                      'uv_depth_av': self.fields.get('uv_dav_3d'),
                      'w': self.fields.w_3d,
                      'w_mesh': self.fields.get('w_mesh_3d'),
                      'diffusivity_h': solver.tot_h_diff.get_sum(),
                      'diffusivity_v': expl_v_diff,
                      'source': self.options.temperature_source_3d,
                      'lax_friedrichs_tracer_scaling_factor': self.options.lax_friedrichs_tracer_scaling_factor,
                      }
            self.timesteppers.temp_expl = self.integrator_3d(
                solver.equations.temp, solver.fields.temp_3d, fields, solver.dt,
                self.options.timestepper_options.explicit_tracer_options,
                solver.bnd_functions['temp'])
            if solver.options.use_implicit_vertical_diffusion:
                fields = {'elev_3d': self.fields.elev_domain_2d.view_3d,
                          'diffusivity_v': impl_v_diff,
                          }
                self.timesteppers.temp_impl = self.integrator_vert_3d(
                    solver.equations.temp_vdff, solver.fields.temp_3d, fields, solver.dt,
                    self.options.timestepper_options.implicit_tracer_options,
                    solver.bnd_functions['temp'])

    def _create_turb_integrator(self):
        """
        Create time integrators for turbulence equations

        Integrators are only created for the equations present in
        ``solver.equations``: ``tke_impl``/``psi_impl`` when both diffusion
        equations exist, ``tke_expl``/``psi_expl`` when both advection
        equations exist.
        """
        solver = self.solver
        impl_v_visc, _, impl_v_diff, _ = self._get_vert_diffusivity_functions()

        # equations are optional, missing ones are reported as None
        eq_tke_diff = getattr(solver.equations, 'tke_diff', None)
        eq_psi_diff = getattr(solver.equations, 'psi_diff', None)
        eq_tke_adv = getattr(solver.equations, 'tke_adv', None)
        eq_psi_adv = getattr(solver.equations, 'psi_adv', None)
        if eq_tke_diff is not None and eq_psi_diff is not None:
            # both implicit integrators share the same fields dictionary
            fields = {'diffusivity_v': impl_v_diff,
                      'viscosity_v': impl_v_visc,
                      'k': solver.fields.tke_3d,
                      'epsilon': solver.turbulence_model.epsilon,
                      'shear_freq2': solver.turbulence_model.m2,
                      'buoy_freq2_neg': solver.turbulence_model.n2_neg,
                      'buoy_freq2_pos': solver.turbulence_model.n2_pos,
                      'bottom_roughness': self.options.bottom_roughness,
                      }
            self.timesteppers.tke_impl = self.integrator_vert_3d(
                eq_tke_diff, solver.fields.tke_3d, fields, solver.dt,
                self.options.timestepper_options.implicit_tracer_options,
                {})
            self.timesteppers.psi_impl = self.integrator_vert_3d(
                eq_psi_diff, solver.fields.psi_3d, fields, solver.dt,
                self.options.timestepper_options.implicit_tracer_options,
                {})
        if eq_tke_adv is not None and eq_psi_adv is not None:
            fields = {'elev_3d': self.fields.elev_domain_2d.view_3d,
                      'uv_3d': self.fields.uv_3d,
                      'uv_depth_av': self.fields.get('uv_dav_3d'),
                      'w': self.fields.w_3d,
                      'w_mesh': self.fields.get('w_mesh_3d'),
                      'lax_friedrichs_tracer_scaling_factor': self.options.lax_friedrichs_tracer_scaling_factor,
                      }
            self.timesteppers.tke_expl = self.integrator_3d(
                eq_tke_adv, solver.fields.tke_3d, fields, solver.dt,
                self.options.timestepper_options.explicit_tracer_options,
                {})
            self.timesteppers.psi_expl = self.integrator_3d(
                eq_psi_adv, solver.fields.psi_3d, fields, solver.dt,
                self.options.timestepper_options.explicit_tracer_options,
                {})

    def _create_integrators(self):
        """
        Creates all time integrators with the correct arguments

        Also stores the CFL coefficients of the explicit 3D momentum and
        the 2D integrators in :attr:`cfl_coeff_3d` and :attr:`cfl_coeff_2d`.
        """
        self._create_swe_integrator()
        self._create_mom_integrator()
        self._create_salt_integrator()
        self._create_temp_integrator()
        self._create_turb_integrator()

        self.cfl_coeff_3d = self.timesteppers.mom_expl.cfl_coeff
        self.cfl_coeff_2d = self.timesteppers.swe2d.cfl_coeff

    def set_dt(self, dt, dt_2d):
        """
        Set time step for the coupled time integrator

        :arg float dt: Time step. This is the master (macro) time step used to
            march the 3D equations.
        :arg float dt_2d: Time step for 2D equations. For consistency
            :attr:`dt_2d` must be an integer fraction of :attr:`dt`.
            If 2D solver is implicit set :attr:`dt_2d` equal to :attr:`dt`.
        :raises AssertionError: if ``dt/dt_2d`` is not close to an integer
        :raises NotImplementedError: if ``dt`` differs from ``dt_2d``
        """
        assert numpy.isclose(dt/dt_2d, numpy.round(dt/dt_2d)), \
            'dt_2d is not integer fraction of dt'

        if dt != dt_2d:
            raise NotImplementedError('Case dt_2d < dt is not implemented yet')

        # all integrators (2D included) are marched with the same step
        for stepper in sorted(self.timesteppers):
            self.timesteppers[stepper].set_dt(dt)

    def initialize(self):
        """
        Assign initial conditions to all necessary fields

        Initial conditions are read from :attr:`fields` dictionary.
        Only the integrators that were created for the current options are
        initialized. Sets the internal ``_initialized`` flag.
        """
        self.timesteppers.swe2d.initialize(self.fields.solution_2d)

        self.timesteppers.mom_expl.initialize(self.fields.uv_3d)
        if self.options.use_implicit_vertical_diffusion:
            self.timesteppers.mom_impl.initialize(self.fields.uv_3d)
        if self.options.solve_salinity:
            self.timesteppers.salt_expl.initialize(self.fields.salt_3d)
            if self.options.use_implicit_vertical_diffusion:
                self.timesteppers.salt_impl.initialize(self.fields.salt_3d)
        if self.options.solve_temperature:
            self.timesteppers.temp_expl.initialize(self.fields.temp_3d)
            if self.options.use_implicit_vertical_diffusion:
                self.timesteppers.temp_impl.initialize(self.fields.temp_3d)
        # turbulence integrators exist only if the equations were available
        if 'psi_impl' in self.timesteppers:
            self.timesteppers.psi_impl.initialize(self.fields.psi_3d)
        if 'tke_impl' in self.timesteppers:
            self.timesteppers.tke_impl.initialize(self.fields.tke_3d)
        if 'psi_expl' in self.timesteppers:
            self.timesteppers.psi_expl.initialize(self.fields.psi_3d)
        if 'tke_expl' in self.timesteppers:
            self.timesteppers.tke_expl.initialize(self.fields.tke_3d)
        self._initialized = True
class CoupledLeapFrogAM3(CoupledTimeIntegrator):
    """
    Leap-Frog Adams-Moulton 3 time integrator for coupled 2D-3D problem

    This is an ALE time integrator.
    Implementation follows the SLIM time integrator by Karna et al (2013)

    Each call to :meth:`advance` performs a prediction step, a 2D update and
    a correction step, optionally followed by the implicit vertical
    diffusion solves.

    Karna, et al. (2013). A baroclinic discontinuous Galerkin finite element
    model for coastal flows. Ocean Modelling, 61(0):1-20.
    http://dx.doi.org/10.1016/j.ocemod.2012.09.009
    """
    # time integrator classes used for the 2D, explicit 3D and implicit 3D parts
    integrator_2d = rungekutta.DIRK22
    integrator_3d = timeintegrator.LeapFrogAM3
    integrator_vert_3d = rungekutta.BackwardEuler

    @PETSc.Log.EventDecorator("thetis.CoupledLeapFrogAM3.__init__")
    def __init__(self, solver):
        """
        :arg solver: :class:`.FlowSolver` object
        """
        super(CoupledLeapFrogAM3, self).__init__(solver)
        # work fields holding old/new 2D states; used in advance() to blend
        # elevation and 2D velocity to the half time level
        self.elev_domain_old_2d = Function(self.fields.elev_domain_2d)
        self.uv_old_2d = Function(self.fields.uv_2d)
        self.uv_new_2d = Function(self.fields.uv_2d)

    @PETSc.Log.EventDecorator("thetis.CoupledLeapFrogAM3.advance")
    def advance(self, t, update_forcings=None, update_forcings3d=None):
        """
        Advances the equations for one time step

        Calls :meth:`initialize` first if it has not been called yet.

        :arg float t: simulation time
        :kwarg update_forcings: Optional user-defined function that takes
            simulation time and updates time-dependent boundary conditions of
            the 2D equations.
        :kwarg update_forcings3d: Optional user defined function that updates
            boundary conditions of the 3D equations.
            NOTE(review): this argument is not referenced anywhere in this
            method - confirm whether that is intended.
        """
        if not self._initialized:
            self.initialize()

        # -------------------------------------------------
        # Prediction step
        # - from t_{n-1/2} to t_{n+1/2}
        # - RHS evaluated at t_{n}
        # - Forward Euler step in fixed domain Omega_n
        # -------------------------------------------------

        # the prediction is done in a fixed domain: zero mesh velocity
        if self.options.use_ale_moving_mesh:
            self.fields.w_mesh_3d.assign(0.0)

        with timed_stage('salt_eq'):
            if self.options.solve_salinity:
                self.timesteppers.salt_expl.predict()
                if self.options.use_limiter_for_tracers:
                    self.solver.tracer_limiter.apply(self.fields.salt_3d)
        with timed_stage('temp_eq'):
            if self.options.solve_temperature:
                self.timesteppers.temp_expl.predict()
                if self.options.use_limiter_for_tracers:
                    self.solver.tracer_limiter.apply(self.fields.temp_3d)
        with timed_stage('turb_advection'):
            if 'psi_expl' in self.timesteppers:
                self.timesteppers.psi_expl.predict()
            if 'tke_expl' in self.timesteppers:
                self.timesteppers.tke_expl.predict()
        with timed_stage('momentum_eq'):
            self.timesteppers.mom_expl.predict()
            if self.options.use_limiter_for_velocity:
                self.solver.uv_limiter.apply(self.fields.uv_3d)

        # dependencies for 2D update
        self._update_2d_coupling()
        self._update_baroclinicity()

        # update 2D
        if self.options.use_ale_moving_mesh:
            self.solver.mesh_updater.compute_mesh_velocity_begin()
        # keep the 2D velocity from before and after the 2D solve
        self.uv_old_2d.assign(self.fields.uv_2d)
        with timed_stage('mode2d'):
            self.timesteppers.swe2d.advance(t, update_forcings)
        if self.options.use_ale_moving_mesh:
            self.solver.mesh_updater.compute_mesh_velocity_finalize()
        self.uv_new_2d.assign(self.fields.uv_2d)

        # set 3D elevation to half step
        # elev_domain_2d <- (0.5 + 2*gamma)*new + (0.5 - 2*gamma)*old
        gamma = self.timesteppers.mom_expl.gamma
        self.elev_domain_old_2d.assign(self.fields.elev_domain_2d)
        self.solver.fields.elev_domain_2d.assign(self.solver.fields.elev_2d)
        self.fields.elev_domain_2d *= (0.5 + 2*gamma)
        self.fields.elev_domain_2d += (0.5 - 2*gamma)*self.elev_domain_old_2d

        # correct uv_3d to uv_2d at t_{n+1/2}
        # uv_2d temporarily holds the same weighted blend of new and old
        self.fields.uv_2d *= (0.5 + 2*gamma)
        self.fields.uv_2d += (0.5 - 2*gamma)*self.uv_old_2d
        self._update_2d_coupling()
        self.fields.uv_2d.assign(self.uv_new_2d)  # restore
        self._update_vertical_velocity()
        self._update_baroclinicity()

        # -------------------------------------------------
        # Correction step
        # - from t_{n} to t_{n+1}
        # - RHS evaluated at t_{n+1/2}
        # - Forward Euler ALE step from Omega_n to Omega_{n+1}
        # -------------------------------------------------

        # evaluate all right hand sides before the mesh is moved
        with timed_stage('salt_eq'):
            if self.options.solve_salinity:
                self.timesteppers.salt_expl.eval_rhs()
        with timed_stage('temp_eq'):
            if self.options.solve_temperature:
                self.timesteppers.temp_expl.eval_rhs()
        with timed_stage('turb_advection'):
            if 'psi_expl' in self.timesteppers:
                self.timesteppers.psi_expl.eval_rhs()
            if 'tke_expl' in self.timesteppers:
                self.timesteppers.tke_expl.eval_rhs()
        with timed_stage('momentum_eq'):
            self.timesteppers.mom_expl.eval_rhs()

        # update elevation and mesh, then apply the correction
        self._update_3d_elevation()
        self._update_moving_mesh()

        with timed_stage('salt_eq'):
            if self.options.solve_salinity:
                self.timesteppers.salt_expl.correct()
                if self.options.use_limiter_for_tracers:
                    self.solver.tracer_limiter.apply(self.fields.salt_3d)
        with timed_stage('temp_eq'):
            if self.options.solve_temperature:
                self.timesteppers.temp_expl.correct()
                if self.options.use_limiter_for_tracers:
                    self.solver.tracer_limiter.apply(self.fields.temp_3d)
        with timed_stage('turb_advection'):
            if 'psi_expl' in self.timesteppers:
                self.timesteppers.psi_expl.correct()
            if 'tke_expl' in self.timesteppers:
                self.timesteppers.tke_expl.correct()
        with timed_stage('momentum_eq'):
            self.timesteppers.mom_expl.correct()
            if self.options.use_limiter_for_velocity:
                self.solver.uv_limiter.apply(self.fields.uv_3d)

        if self.options.use_implicit_vertical_diffusion:
            self._update_2d_coupling()
            self._update_baroclinicity()
            self._update_turbulence(t)
            if self.options.solve_salinity:
                with timed_stage('impl_salt_vdiff'):
                    self.timesteppers.salt_impl.advance(t)
            if self.options.solve_temperature:
                with timed_stage('impl_temp_vdiff'):
                    self.timesteppers.temp_impl.advance(t)
            with timed_stage('impl_mom_vvisc'):
                # the implicit solve acts on the full velocity: add the depth
                # average before the solve and remove it again afterwards
                self.fields.uv_3d += self.fields.uv_dav_3d
                self.timesteppers.mom_impl.advance(t)
                self.fields.uv_3d -= self.fields.uv_dav_3d
            self._update_baroclinicity()
            self._update_vertical_velocity()
            self._update_stabilization_params()
        else:
            self._update_2d_coupling()
            self._update_baroclinicity()
            self._update_vertical_velocity()
            self._update_stabilization_params()
class CoupledTwoStageRK(CoupledTimeIntegrator):
    """
    Coupled time integrator based on SSPRK(2,2) scheme

    This ALE time integration method uses SSPRK(2,2) scheme to advance the 3D
    equations and a compatible implicit Trapezoid method to advance the 2D
    equations. Backward Euler scheme is used for vertical diffusion.
    """
    # time integrator classes used for the 2D, explicit 3D and implicit 3D parts
    integrator_2d = rungekutta.ESDIRKTrapezoid
    integrator_3d = timeintegrator.SSPRK22ALE
    integrator_vert_3d = rungekutta.BackwardEuler

    @PETSc.Log.EventDecorator("thetis.CoupledTwoStageRK.__init__")
    def __init__(self, solver):
        """
        :arg solver: :class:`.FlowSolver` object
        """
        super(CoupledTwoStageRK, self).__init__(solver)
        # allocate CG elevation fields for storing mesh geometry
        # one field per stage, filled by store_elevation()
        self.elev_fields = []
        for i in range(self.n_stages):
            e = Function(self.fields.elev_cg_2d)
            self.elev_fields.append(e)

    @PETSc.Log.EventDecorator("thetis.CoupledTwoStageRK.store_elevation")
    def store_elevation(self, istage):
        """
        Store current elevation field for computing mesh velocity

        Must be called before updating the 2D mode.
        Does nothing unless the ALE moving mesh is enabled.

        :arg istage: stage of the Runge-Kutta iteration
        :type istage: int
        """
        if self.options.use_ale_moving_mesh:
            # NOTE(review): presumably refreshes elev_cg_2d before it is
            # copied below - confirm against the mesh updater
            self.solver.mesh_updater.compute_mesh_velocity_begin()
            self.elev_fields[istage].assign(self.fields.elev_cg_2d)

    @PETSc.Log.EventDecorator("thetis.CoupledTwoStageRK.compute_mesh_velocity")
    def compute_mesh_velocity(self, istage):
        """
        Computes mesh velocity for stage i

        Must be called after updating the 2D mode.
        Does nothing unless the ALE moving mesh is enabled.

        The surface mesh velocity expression is built from the current
        ``elev_cg_2d`` and the elevations saved by :meth:`store_elevation`,
        and passed to the mesh updater.

        :arg istage: stage of the Runge-Kutta iteration
        :type istage: int
        """
        if self.options.use_ale_moving_mesh:
            self.solver.mesh_updater.compute_mesh_velocity_begin()
            current_elev = self.fields.elev_cg_2d
            if istage == 0:
                # compute w_mesh at surface as (elev^{(1)} - elev^{n})/dt
                w_s = (current_elev - self.elev_fields[0])/self.solver.dt
            else:
                # compute w_mesh at surface as (2*elev^{n+1} - elev^{(1)} - elev^{n})/dt
                w_s = (2*current_elev - self.elev_fields[1] - self.elev_fields[0])/self.solver.dt
            self.solver.mesh_updater.compute_mesh_velocity_finalize(
                w_mesh_surf_expr=w_s)

    @PETSc.Log.EventDecorator("thetis.CoupledTwoStageRK.advance")
    def advance(self, t, update_forcings=None, update_forcings3d=None):
        """
        Advances the equations for one time step

        Calls :meth:`initialize` first if it has not been called yet, then
        loops over the stages of the scheme. In every stage the 2D mode is
        solved first, followed by the explicit 3D equations. The implicit
        vertical diffusion solves and the turbulence/stabilization updates
        are only carried out in the last stage.

        :arg float t: simulation time
        :kwarg update_forcings: Optional user-defined function that takes
            simulation time and updates time-dependent boundary conditions of
            the 2D equations.
        :kwarg update_forcings3d: Optional user defined function that updates
            boundary conditions of the 3D equations
        """
        if not self._initialized:
            self.initialize()

        for i_stage in range(self.n_stages):

            # solve 2D mode
            # elevation is stored before and mesh velocity computed after
            self.store_elevation(i_stage)
            with timed_stage('mode2d'):
                self.timesteppers.swe2d.solve_stage(i_stage, t, update_forcings)
            self.compute_mesh_velocity(i_stage)

            # solve 3D mode: preprocess in old mesh
            with timed_stage('salt_eq'):
                if self.options.solve_salinity:
                    self.timesteppers.salt_expl.prepare_stage(i_stage, t, update_forcings3d)
            with timed_stage('temp_eq'):
                if self.options.solve_temperature:
                    self.timesteppers.temp_expl.prepare_stage(i_stage, t, update_forcings3d)
            with timed_stage('turb_advection'):
                if 'psi_expl' in self.timesteppers:
                    self.timesteppers.psi_expl.prepare_stage(i_stage, t, update_forcings3d)
                if 'tke_expl' in self.timesteppers:
                    self.timesteppers.tke_expl.prepare_stage(i_stage, t, update_forcings3d)
            with timed_stage('momentum_eq'):
                self.timesteppers.mom_expl.prepare_stage(i_stage, t, update_forcings3d)

            # update mesh
            self._update_3d_elevation()
            self._update_moving_mesh()

            # solve 3D mode
            with timed_stage('salt_eq'):
                if self.options.solve_salinity:
                    self.timesteppers.salt_expl.solve_stage(i_stage)
                    if self.options.use_limiter_for_tracers:
                        self.solver.tracer_limiter.apply(self.fields.salt_3d)
            with timed_stage('temp_eq'):
                if self.options.solve_temperature:
                    self.timesteppers.temp_expl.solve_stage(i_stage)
                    if self.options.use_limiter_for_tracers:
                        self.solver.tracer_limiter.apply(self.fields.temp_3d)
            with timed_stage('turb_advection'):
                if 'psi_expl' in self.timesteppers:
                    self.timesteppers.psi_expl.solve_stage(i_stage)
                if 'tke_expl' in self.timesteppers:
                    self.timesteppers.tke_expl.solve_stage(i_stage)
            with timed_stage('momentum_eq'):
                self.timesteppers.mom_expl.solve_stage(i_stage)
                if self.options.use_limiter_for_velocity:
                    self.solver.uv_limiter.apply(self.fields.uv_3d)

            last_stage = i_stage == self.n_stages - 1

            if last_stage:
                # compute final prognostic variables
                self._update_2d_coupling()
                if self.options.use_implicit_vertical_diffusion:
                    if self.options.solve_salinity:
                        with timed_stage('impl_salt_vdiff'):
                            self.timesteppers.salt_impl.advance(t)
                    if self.options.solve_temperature:
                        with timed_stage('impl_temp_vdiff'):
                            self.timesteppers.temp_impl.advance(t)
                    with timed_stage('impl_mom_vvisc'):
                        # compute full velocity
                        # the depth average is added for the implicit solve
                        # and removed again afterwards
                        self.fields.uv_3d += self.fields.uv_dav_3d
                        self.timesteppers.mom_impl.advance(t)
                        self.fields.uv_3d -= self.fields.uv_dav_3d
                # compute final diagnostic fields
                self._update_baroclinicity()
                self._update_vertical_velocity()
                # update parametrizations
                self._update_turbulence(t)
                self._update_stabilization_params()
            else:
                # update variables that explict solvers depend on
                self._update_2d_coupling()
                self._update_baroclinicity()
                self._update_vertical_velocity()