Source code for thetis.inversion_tools

import firedrake as fd
from firedrake.adjoint import *
import ufl
from .configuration import FrozenHasTraits
from .solver2d import FlowSolver2d
from .utility import create_directory, print_function_value_range, get_functionspace, unfrozen
from .log import print_output
from .diagnostics import HessianRecoverer2D
from .exporter import HDF5Exporter
import abc
import numpy
import h5py
from scipy.interpolate import interp1d
import time as time_mod
import os


class InversionManager(FrozenHasTraits):
    """
    Class for handling inversion problems and stashing the progress of the
    associated optimization routines.
    """

    @unfrozen
    def __init__(self, sta_manager, output_dir='outputs', no_exports=False, real=False,
                 penalty_parameters=None, cost_function_scaling=None,
                 test_consistency=True, test_gradient=True):
        """
        :arg sta_manager: the :class:`StationObservationManager` instance
        :kwarg output_dir: model output directory
        :kwarg no_exports: if True, nothing will be written to disk
        :kwarg real: is the inversion in the Real space? (Real mode also
            disables all exports)
        :kwarg penalty_parameters: a list of penalty parameters to pass to the
            :class:`ControlRegularizationManager` (defaults to an empty list)
        :kwarg cost_function_scaling: global scaling for the cost function.
            As rule of thumb, it's good to scale the functional to J < 1.
            Defaults to ``fd.Constant(1.0)``.
        :kwarg test_consistency: toggle testing the correctness with which
            the :class:`ReducedFunctional` can recompute values
        :kwarg test_gradient: toggle testing the correctness with which
            the :class:`ReducedFunctional` can recompute gradients
        """
        assert isinstance(sta_manager, StationObservationManager)
        self.sta_manager = sta_manager
        self.reg_manager = None
        self.output_dir = output_dir
        # exports are not supported in Real mode, so Real mode disables them
        self.no_exports = no_exports or real
        self.real = real
        # never share a mutable default (or the caller's list) between instances
        self.penalty_parameters = [] if penalty_parameters is None else list(penalty_parameters)
        # explicit ``is None`` test: truth-testing a user-supplied scaling
        # (e.g. 0.0) must not silently replace it with 1.0
        if cost_function_scaling is None:
            cost_function_scaling = fd.Constant(1.0)
        self.cost_function_scaling = cost_function_scaling
        self.sta_manager.cost_function_scaling = self.cost_function_scaling
        self.test_consistency = test_consistency
        self.test_gradient = test_gradient
        self.outfiles_m = []
        self.outfiles_dJdm = []
        # one entry per control; None for controls without an HDF5 exporter
        self.control_exporters = []
        self.initialized = False

        self.J = 0  # cost function value (float)
        self.J_reg = 0  # regularization term value (float)
        self.J_misfit = 0  # misfit term value (float)
        self.dJdm_list = None  # cost function gradient (Function)
        self.m_list = None  # control (Function)
        self.Jhat = None
        self.m_progress = []
        self.J_progress = []
        self.J_reg_progress = []
        self.J_misfit_progress = []
        self.dJdm_progress = []
        self.i = 0
        self.tic = None
        self.nb_grad_evals = 0
        self.control_coeff_list = []
        self.control_list = []

    def initialize(self):
        """
        Create the output directories and one VTK control/gradient output
        file per control, unless exports are disabled.

        :raises ValueError: if exports are requested in Real mode
        """
        if not self.no_exports:
            if self.real:
                raise ValueError("Exports are not supported in Real mode.")
            create_directory(self.output_dir)
            create_directory(self.output_dir + '/hdf5')
            for i, _ in enumerate(self.control_coeff_list):
                self.outfiles_m.append(
                    fd.File(f'{self.output_dir}/control_progress_{i:02d}.pvd'))
                self.outfiles_dJdm.append(
                    fd.File(f'{self.output_dir}/gradient_progress_{i:02d}.pvd'))
        self.initialized = True

    def add_control(self, f):
        """
        Add a control field.

        Can be called multiple times in case of multiparameter optimization.

        :arg f: Function or Constant to be used as a control variable.
        """
        self.control_coeff_list.append(f)
        self.control_list.append(Control(f))
        exporter = None
        if isinstance(f, fd.Function) and not self.no_exports:
            j = len(self.control_coeff_list) - 1
            prefix = f'control_{j:02d}'
            exporter = HDF5Exporter(f.function_space(), self.output_dir + '/hdf5', prefix)
        # always append so that control_exporters[j] matches control j
        self.control_exporters.append(exporter)

    def reset_counters(self):
        """Reset the gradient evaluation counter."""
        self.nb_grad_evals = 0

    def set_control_state(self, j, djdm_list, m_list):
        """
        Stores optimization state.

        To call whenever variables are updated.

        :arg j: error functional value
        :arg djdm_list: list of gradient functions
        :arg m_list: list of control coefficients
        """
        self.J = j
        self.dJdm_list = djdm_list
        self.m_list = m_list

        # recover the regularization and misfit contributions from the tagged
        # assembly blocks on the adjoint tape
        tape = get_working_tape()
        reg_blocks = tape.get_blocks(tag="reg_eval")
        self.J_reg = sum([b.get_outputs()[0].saved_output for b in reg_blocks])
        misfit_blocks = tape.get_blocks(tag="misfit_eval")
        self.J_misfit = sum([b.get_outputs()[0].saved_output for b in misfit_blocks])

    def start_clock(self):
        """Record the current wall-clock time as the timer start."""
        self.tic = time_mod.perf_counter()

    def stop_clock(self):
        """Return the current wall-clock time (``time.perf_counter``)."""
        toc = time_mod.perf_counter()
        return toc

    def set_initial_state(self, *state):
        """Store the initial optimization state and record it as progress."""
        self.set_control_state(*state)
        self.update_progress()

    def update_progress(self):
        """
        Updates optimization progress and stores variables to disk.

        To call after successful line searches.
        """
        toc = self.stop_clock()
        if self.i == 0:
            for f in self.control_coeff_list:
                print_function_value_range(f, prefix='Initial')

        elapsed = '-' if self.tic is None else f'{toc - self.tic:.1f} s'
        self.tic = toc

        if not self.initialized:
            self.initialize()

        # cost function and gradient norm output
        djdm = [fd.norm(f) for f in self.dJdm_list]
        if self.real:
            controls = [m.dat.data[0] for m in self.m_list]
            self.m_progress.append(controls)
        self.J_progress.append(self.J)
        self.J_reg_progress.append(self.J_reg)
        self.J_misfit_progress.append(self.J_misfit)
        self.dJdm_progress.append(djdm)
        comm = self.control_coeff_list[0].comm
        if comm.rank == 0 and not self.no_exports:
            if self.real:
                numpy.save(f'{self.output_dir}/m_progress', self.m_progress)
            numpy.save(f'{self.output_dir}/J_progress', self.J_progress)
            numpy.save(f'{self.output_dir}/J_reg_progress', self.J_reg_progress)
            numpy.save(f'{self.output_dir}/J_misfit_progress', self.J_misfit_progress)
            numpy.save(f'{self.output_dir}/dJdm_progress', self.dJdm_progress)
        # summarise many gradient norms as a range to keep the log readable
        if len(djdm) > 10:
            djdm_str = f"[{numpy.min(djdm):.4e} .. {numpy.max(djdm):.4e}]"
        else:
            djdm_str = "[" + ", ".join([f"{dj:.4e}" for dj in djdm]) + "]"
        print_output(f'line search {self.i:2d}: '
                     f'J={self.J:.3e}, dJdm={djdm_str}, '
                     f'grad_ev={self.nb_grad_evals}, duration {elapsed}')

        if not self.no_exports:
            # control output
            for j, control in enumerate(self.control_coeff_list):
                m = self.m_list[j]
                # vtk format
                m.rename(control.name())
                self.outfiles_m[j].write(m)
                # hdf5 format (only controls that got an exporter in add_control)
                exporter = self.control_exporters[j]
                if exporter is not None:
                    exporter.export(m)
            # gradient output
            for f, o in zip(self.dJdm_list, self.outfiles_dJdm):
                # store gradient in vtk format
                f.rename('Gradient')
                o.write(f)

        self.i += 1
        self.reset_counters()

    @property
    def rf_kwargs(self):
        """
        Default keyword arguments to pass to the
        :class:`ReducedFunctional` class.
        """
        def gradient_eval_cb(j, djdm, m):
            # stash state and count gradient evaluations after each derivative
            self.set_control_state(j, djdm, m)
            self.nb_grad_evals += 1
            return djdm

        params = {
            'derivative_cb_post': gradient_eval_cb,
        }
        return params

    def get_cost_function(self, solver_obj, weight_by_variance=False):
        r"""
        Get a sum of square errors cost function for the problem:

        .. math::
            J(u) = \sum_{i=1}^{n_{ts}} \sum_{j=1}^{n_{sta}} (u_j^{(i)} - u_{j,o}^{(i)})^2,

        where :math:`u_{j,o}^{(i)}` and :math:`u_j^{(i)}` denote the
        observed and computed values at timestep :math:`i`, and
        :math:`n_{ts}` and :math:`n_{sta}` are the numbers of timesteps
        and stations, respectively.

        Regularization terms are included if penalty parameters were given,
        in which case a :class:`ControlRegularizationManager` is created.

        :arg solver_obj: the :class:`FlowSolver2d` instance
        :kwarg weight_by_variance: should the observation data be weighted by
            the variance at each station?
        :returns: a callable ``cost_fn(t)`` that adds the misfit at time ``t``
            to :attr:`J` and :attr:`J_misfit`
        """
        assert isinstance(solver_obj, FlowSolver2d)
        if len(self.penalty_parameters) > 0:
            self.reg_manager = ControlRegularizationManager(
                self.control_coeff_list,
                self.penalty_parameters,
                self.cost_function_scaling,
                RSpaceRegularizationCalculator if self.real else HessianRegularizationCalculator)
        self.J_reg = 0
        self.J_misfit = 0
        if self.reg_manager is not None:
            self.J_reg = self.reg_manager.eval_cost_function()
        self.J = self.J_reg

        if weight_by_variance:
            var = fd.Function(self.sta_manager.fs_points_0d)
            for i, j in enumerate(self.sta_manager.local_station_index):
                var.dat.data[i] = numpy.var(self.sta_manager.observation_values[j])
            self.sta_manager.station_weight_0d.interpolate(1/var)

        def cost_fn(t):
            misfit = self.sta_manager.eval_cost_function(t)
            self.J_misfit += misfit
            self.J += misfit

        return cost_fn

    @property
    def reduced_functional(self):
        """
        Create a Pyadjoint :class:`ReducedFunctional` for the optimization.

        The functional is created lazily on first access and cached.
        """
        if self.Jhat is None:
            self.Jhat = ReducedFunctional(self.J, self.control_list, **self.rf_kwargs)
        return self.Jhat

    def stop_annotating(self):
        """
        Stop recording operations for the adjoint solver.

        This method should be called after the :meth:`iterate` method of
        :class:`FlowSolver2d`. Optionally runs the consistency and Taylor
        tests first.
        """
        assert self.reduced_functional is not None
        if self.test_consistency:
            self.consistency_test()
        if self.test_gradient:
            self.taylor_test()
        pause_annotation()

    def get_optimization_callback(self):
        """
        Get a callback for stashing optimization progress after successful
        line search.
        """

        def optimization_callback(m):
            self.update_progress()
            if not self.no_exports:
                self.sta_manager.dump_time_series()

        return optimization_callback

    def minimize(self, opt_method="BFGS", bounds=None, **opt_options):
        """
        Minimize the reduced functional using a given optimization routine.

        :kwarg opt_method: the optimization routine
        :kwarg bounds: a list of bounds to pass to the optimization routine
        :kwarg opt_options: other optimization parameters to pass
        :returns: the result of the Pyadjoint ``minimize`` call
        """
        print_output(f'Running {opt_method} optimization')
        self.reset_counters()
        self.start_clock()
        J = float(self.reduced_functional(self.control_coeff_list))
        self.set_initial_state(J, self.reduced_functional.derivative(), self.control_coeff_list)
        if not self.no_exports:
            self.sta_manager.dump_time_series()
        return minimize(
            self.reduced_functional, method=opt_method, bounds=bounds,
            callback=self.get_optimization_callback(), options=opt_options)

    def consistency_test(self):
        """
        Test that :attr:`reduced_functional` can correctly recompute the
        objective value, assuming that none of the controls have changed
        since it was created.

        :raises ValueError: if the recomputed value differs from :attr:`J`
        """
        print_output("Running consistency test")
        J = self.reduced_functional(self.control_coeff_list)
        if not numpy.isclose(J, self.J):
            raise ValueError(f"Consistency test failed (expected {self.J}, got {J})")
        print_output("Consistency test passed!")

    def taylor_test(self):
        """
        Run a Taylor test to check that the :attr:`reduced_functional` can
        correctly compute consistent gradients.

        Note that the Taylor test is applied on the current control values.

        :raises ValueError: if the minimum convergence rate is below 1.9
        """
        # perturbation directions: deep copies of the current controls
        func_list = [f.copy(deepcopy=True) for f in self.control_coeff_list]
        minconv = taylor_test(self.reduced_functional, self.control_coeff_list, func_list)
        if minconv < 1.9:
            raise ValueError("Taylor test failed")  # NOTE: Pyadjoint already prints the testing
        print_output("Taylor test passed!")
class StationObservationManager:
    """
    Implements error functional based on observation time series.

    The functional is the squared sum of error between the model and
    observations.

    This object evaluates the model fields at the station locations,
    interpolates the observations time series to the model time, computes the
    error functional, and also stores the model's time series data to disk.
    """

    def __init__(self, mesh, output_directory='outputs'):
        """
        :arg mesh: the 2D mesh object.
        :kwarg output_directory: directory where model time series are stored.
        :raises NotImplementedError: if the mesh geometric dimension is 3
            (sphere meshes)
        """
        self.mesh = mesh
        on_sphere = self.mesh.geometric_dimension() == 3
        if on_sphere:
            raise NotImplementedError('Sphere meshes are not supported yet.')
        self.cost_function_scaling = fd.Constant(1.0)
        self.output_directory = output_directory
        # station metadata; populated by register_observation_data
        self.station_names = None
        # keep observation time series in memory
        self.obs_func_list = []
        # keep model time series in memory during optimization progress
        self.station_value_progress = []
        # model time when cost function was evaluated
        self.simulation_time = []
        self.model_observation_field = None
        self.initialized = False

    def register_observation_data(self, station_names, variable, time, values, x, y,
                                  start_times=None, end_times=None):
        """
        Add station time series data to the object.

        The `x`, and `y` coordinates must be such that
        they allow extraction of model data at the same coordinates.

        :arg list station_names: list of station names
        :arg str variable: canonical variable name, e.g. 'elev'
        :arg list time: array of time stamps, one for each station
        :arg list values: array of observations, one for each station
        :arg list x: list of station x coordinates
        :arg list y: list of station y coordinates
        :kwarg list start_times: optional start times for the observation
            periods (default: -inf for every station)
        :kwarg list end_times: optional end times for the observation
            periods (default: +inf for every station)
        """
        self.station_names = station_names
        self.variable = variable
        self.observation_time = time
        self.observation_values = values
        self.observation_x = x
        self.observation_y = y
        num_stations = len(station_names)
        # NOTE: explicit ``is None``/length checks; ``start_times or ...``
        # raises for numpy arrays with more than one entry
        if start_times is None or len(start_times) == 0:
            start_times = -numpy.ones(num_stations)*numpy.inf
        if end_times is None or len(end_times) == 0:
            end_times = numpy.ones(num_stations)*numpy.inf
        self._start_times = numpy.asarray(start_times, dtype=float)
        self._end_times = numpy.asarray(end_times, dtype=float)

    def set_model_field(self, function):
        """
        Set the model field that will be evaluated.
        """
        self.model_observation_field = function

    def load_observation_data(self, observation_data_dir, station_names, variable,
                              start_times=None, end_times=None):
        """
        Load observation data from disk.

        Assumes that observation data were stored with
        `TimeSeriesCallback2D` during the forward run. For generic case, use
        `register_observation_data` instead.

        :arg str observation_data_dir: directory where observation data is stored
        :arg list station_names: list of station names
        :arg str variable: canonical variable name, e.g. 'elev'
        :kwarg list start_times: optional start times for the observation periods
        :kwarg list end_times: optional end times for the observation periods
        """
        file_list = [
            f'{observation_data_dir}/'
            f'diagnostic_timeseries_{s}_{variable}.hdf5' for s in station_names
        ]
        # arrays of time stamps and values for each station
        observation_time = []
        observation_values = []
        observation_coords = []
        for fname in file_list:
            with h5py.File(fname, 'r') as h5file:
                t = h5file['time'][:].flatten()
                v = h5file[variable][:].flatten()
                x = h5file.attrs['x']
                y = h5file.attrs['y']
                observation_coords.append((x, y))
                observation_time.append(t)
                observation_values.append(v)
        # list of station coordinates
        observation_x, observation_y = numpy.array(observation_coords).T
        self.register_observation_data(
            station_names, variable, observation_time,
            observation_values, observation_x, observation_y,
            start_times=start_times, end_times=end_times,
        )
        self.construct_evaluator()

    def update_stations_in_use(self, t):
        """
        Indicate which stations are in use at the current time.

        An entry of unity indicates use, whereas zero indicates disuse.

        :arg t: model simulation time
        """
        if not hasattr(self, 'obs_start_times'):
            self.construct_evaluator()
        in_use = fd.Function(self.fs_points_0d)
        # a station is in use when start_time <= t <= end_time
        in_use.dat.data[:] = numpy.logical_and(
            self.obs_start_times <= t, t <= self.obs_end_times
        ).astype(float)
        self.indicator_0d.assign(in_use)

    def construct_evaluator(self):
        """
        Builds evaluators needed to compute the error functional.
        """
        # Create 0D mesh for station evaluation
        xy = numpy.array((self.observation_x, self.observation_y)).T
        mesh0d = fd.VertexOnlyMesh(self.mesh, xy)
        self.fs_points_0d = fd.FunctionSpace(mesh0d, 'DG', 0)
        self.obs_values_0d = fd.Function(self.fs_points_0d, name='observations')
        self.mod_values_0d = fd.Function(self.fs_points_0d, name='model values')
        self.indicator_0d = fd.Function(self.fs_points_0d, name='station use indicator')
        self.indicator_0d.assign(1.0)
        self.cost_function_scaling_0d = fd.Constant(0.0, domain=mesh0d)
        self.cost_function_scaling_0d.assign(self.cost_function_scaling)
        self.station_weight_0d = fd.Function(self.fs_points_0d, name='station-wise weighting')
        self.station_weight_0d.assign(1.0)
        interp_kw = {}
        if numpy.isfinite(self._start_times).any() or numpy.isfinite(self._end_times).any():
            # with finite observation windows, out-of-range times evaluate to
            # 0.0 instead of raising (such stations are presumably masked by
            # indicator_0d)
            interp_kw.update({'bounds_error': False, 'fill_value': 0.0})

        # Construct timeseries interpolator
        self.station_interpolators = []
        self.local_station_index = []
        for i in range(self.fs_points_0d.dof_dset.size):
            # loop over local DOFs and match coordinates to observations
            # NOTE this must be done manually as VertexOnlyMesh reorders points
            x_mesh, y_mesh = mesh0d.coordinates.dat.data[i, :]
            xy_diff = xy - numpy.array([x_mesh, y_mesh])
            xy_dist = numpy.sqrt(xy_diff[:, 0]**2 + xy_diff[:, 1]**2)
            j = numpy.argmin(xy_dist)
            self.local_station_index.append(j)

            x, y = xy[j, :]
            t = self.observation_time[j]
            v = self.observation_values[j]
            msg = 'bad station location ' \
                f'{j} {i} {x} {x_mesh} {y} {y_mesh} {x-x_mesh} {y-y_mesh}'
            assert numpy.allclose([x, y], [x_mesh, y_mesh]), msg
            # create temporal interpolator
            ip = interp1d(t, v, **interp_kw)
            self.station_interpolators.append(ip)

        # Process start and end times for observations (in local DOF order)
        self.obs_start_times = numpy.array([
            self._start_times[j] for j in self.local_station_index
        ])
        self.obs_end_times = numpy.array([
            self._end_times[j] for j in self.local_station_index
        ])

        # expressions for cost function
        self.misfit_expr = self.obs_values_0d - self.mod_values_0d
        self.initialized = True

    def eval_observation_at_time(self, t):
        """
        Evaluate observation time series at the given time.

        Also updates the station-use indicator for time `t`.

        :arg t: model simulation time
        :returns: list of observation time series values at time `t`
        """
        self.update_stations_in_use(t)
        return [float(ip(t)) for ip in self.station_interpolators]

    def eval_cost_function(self, t):
        """
        Evaluate the cost function.

        Should be called at every export of the forward model.

        :arg t: model simulation time
        :returns: the assembled, weighted squared misfit at time `t`
        """
        assert self.initialized, 'Not initialized, call construct_evaluator first.'
        assert self.model_observation_field is not None, 'Model field not set.'
        self.simulation_time.append(t)
        # evaluate observations at simulation time and stash the result
        obs_func = fd.Function(self.fs_points_0d)
        obs_func.dat.data[:] = self.eval_observation_at_time(t)
        self.obs_func_list.append(obs_func)

        # compute square error
        self.obs_values_0d.assign(obs_func)
        self.mod_values_0d.interpolate(self.model_observation_field, ad_block_tag='observation')
        s = self.cost_function_scaling_0d * self.indicator_0d * self.station_weight_0d
        self.J_misfit = fd.assemble(s * self.misfit_expr ** 2 * fd.dx, ad_block_tag='misfit_eval')
        return self.J_misfit

    def dump_time_series(self):
        """
        Stores model time series to disk.

        Obtains station time series from the last optimization iteration,
        and stores the data to disk.

        The output files have the format
        `{odir}/diagnostic_timeseries_progress_{station_name}_{variable}.hdf5`

        The file contains the simulation time in the `time` array, and the
        station name and coordinates as attributes. The time series data is
        stored as a 2D (n_iterations, n_time_steps) array.
        """
        assert self.station_names is not None
        create_directory(self.output_directory)
        tape = get_working_tape()
        blocks = tape.get_blocks(tag='observation')
        ts_data = [b.get_outputs()[0].saved_output.dat.data for b in blocks]
        # shape (ntimesteps, nstations)
        ts_data = numpy.array(ts_data)
        # append
        self.station_value_progress.append(ts_data)
        var = self.variable
        for ilocal, iglobal in enumerate(self.local_station_index):
            name = self.station_names[iglobal]
            # collect time series data, shape (niters, ntimesteps)
            ts = numpy.array([s[:, ilocal] for s in self.station_value_progress])
            fn = f'diagnostic_timeseries_progress_{name}_{var}.hdf5'
            fn = os.path.join(self.output_directory, fn)
            with h5py.File(fn, 'w') as hdf5file:
                hdf5file.create_dataset(var, data=ts)
                hdf5file.create_dataset('time', data=numpy.array(self.simulation_time))
                attrs = {
                    'x': self.observation_x[iglobal],
                    'y': self.observation_y[iglobal],
                    'location_name': name,
                }
                hdf5file.attrs.update(attrs)
class RegularizationCalculator(abc.ABC):
    """
    Abstract base for regularization-term calculators.

    Subclasses are expected to assign :attr:`regularization_expr` inside
    their :meth:`__init__`. On every cost function evaluation the scaled
    expression, divided by the total mesh area, is integrated over the mesh.
    """

    @abc.abstractmethod
    def __init__(self, function, scaling=1.0):
        """
        :arg function: Control :class:`Function`
        :kwarg scaling: multiplicative factor applied to the regularization term
        """
        mesh = function.function_space().mesh()
        self.scaling = scaling
        self.regularization_expr = 0
        self.mesh = mesh
        # total mesh area, used to normalise the regularization integral
        unit = fd.Constant(1.0, domain=mesh)
        self.mesh_area = fd.assemble(unit * fd.dx)
        self.name = function.name()

    def eval_cost_function(self):
        """
        Assemble the scaled, area-normalised regularization term.

        The assembly is tagged ``"reg_eval"`` on the adjoint tape.
        """
        integrand = self.scaling * self.regularization_expr / self.mesh_area
        return fd.assemble(integrand * fd.dx, ad_block_tag="reg_eval")
class HessianRegularizationCalculator(RegularizationCalculator):
    r"""
    Regularization term penalising the curvature of a control
    :class:`Function` :math:`f`:

    .. math::
        J = \gamma \| (\Delta x)^2 H(f) \|^2,

    where :math:`H` is the Hessian of field :math:`f` and :math:`\Delta x`
    is the local cell size.
    """

    def __init__(self, function, gamma, scaling=1.0):
        """
        :arg function: Control :class:`Function`
        :arg gamma: Hessian penalty coefficient
        :kwarg scaling: multiplicative factor applied to the regularization term
        """
        super().__init__(function, scaling=scaling)
        mesh = self.mesh
        label = self.name

        # spaces and fields for the recovered gradient and Hessian
        vector_space = get_functionspace(mesh, "CG", 1, vector=True)
        tensor_space = get_functionspace(mesh, "CG", 1, tensor=True)
        gradient_2d = fd.Function(vector_space, name=f"{label} gradient")
        hessian_2d = fd.Function(tensor_space, name=f"{label} hessian")
        self.hessian_calculator = HessianRecoverer2D(function, hessian_2d, gradient_2d)

        # |hessian|^2 weighted by cell size^4, so that the term is
        # dimensionally comparable to the field itself: d^2u/dx^2 * dx^2 ~ du^2
        cell_size = fd.CellSize(mesh)
        hessian_sq = fd.inner(hessian_2d, hessian_2d)
        self.regularization_expr = gamma * hessian_sq * cell_size**4

    def eval_cost_function(self):
        """Recover the Hessian of the current control, then assemble the term."""
        self.hessian_calculator.solve()
        return super().eval_cost_function()
class RSpaceRegularizationCalculator(RegularizationCalculator):
    r"""
    Regularization term for a control :class:`Function` :math:`f` living in
    an R-space:

    .. math::
        J = \gamma \frac{(f - f_0)^2}{\max(|f_0|, \epsilon)},

    where the prior :math:`f_0` is the initial value of :math:`f`.
    """

    def __init__(self, function, gamma, eps=1.0e-03, scaling=1.0):
        """
        :arg function: Control :class:`Function`
        :arg gamma: penalty coefficient
        :kwarg eps: tolerance for normalising by near-zero priors
        :kwarg scaling: multiplicative factor applied to the regularization term
        :raises ValueError: if `function` is not in an R-space
        """
        super().__init__(function, scaling=scaling)
        space = function.function_space()
        family = space.ufl_element().family()
        if family != "Real":
            raise ValueError("function must live in R-space")

        # snapshot of the initial control value, taken without annotation
        prior = fd.Function(space, name=f"{self.name} prior")
        prior.assign(function, annotate=False)

        # NOTE: If the prior is small then dividing by prior**2 puts too much
        # emphasis on the regularization. Therefore, we divide by abs(prior)
        # instead, bounded below by eps.
        denominator = ufl.max_value(abs(prior), eps)
        self.regularization_expr = gamma * (function - prior) ** 2 / denominator
class ControlRegularizationManager:
    """
    Handles regularization of multiple control fields
    """

    def __init__(self, function_list, gamma_list, penalty_term_scaling=None,
                 calculator=HessianRegularizationCalculator):
        """
        :arg function_list: list of control functions
        :arg gamma_list: list of penalty parameters
        :kwarg penalty_term_scaling: Penalty term scaling factor
            (``None`` means 1.0)
        :kwarg calculator: class used for obtaining regularization
        """
        assert len(function_list) == len(gamma_list), \
            'Number of control functions and parameters must match'
        # the calculators multiply by ``scaling`` directly, so None must not
        # be forwarded (it would raise TypeError at evaluation time)
        scaling = 1.0 if penalty_term_scaling is None else penalty_term_scaling
        self.reg_calculators = [
            calculator(f, g, scaling=scaling)
            for f, g in zip(function_list, gamma_list)
        ]

    def eval_cost_function(self):
        """Return the sum of all regularization terms (0 if there are none)."""
        return sum(r.eval_cost_function() for r in self.reg_calculators)