"""The ProblemExec provides a uniform set of options for managing the state of the system and its solvables, establishing the selection of combos or de/active attributes to Solvables. Once once created any further entracnces to ProblemExec will return the same instance until finally the last exit is called.
The ProblemExec class allows entrance to a its context to the same instance until finally the last exit is called. The first entrance to the context will create the instance, each subsequent entrance will return the same instance. The ProblemExec arguments are set the first time and remove keyword arguments from the input dictionary (passed as a dict ie stateful) to subsequent methods.
This isn't technically a singleton pattern, but it does provide a similar interface. Instead mutliple problem instances will be clones of the first instance, with the optional difference of input/output/event criteria. The first instance will be returned by each context entry, so for that reason it may always appear to have same instance, however each instance is unique in a recusive setting so it may record its own state and be reverted to its own state as per the options defined.
#TODO: allow update of kwargs on re-entrance
## Example:
.. code-block:: python
#Application code (arguments passed in kw)
    with ProblemExec(sys,combos='default',slv_vars='*',**kw) as pe:
pe._sys_refs #get the references and compiled problem
for i in range(10):
pe.solve_min(pe.Xref,pe.Yref,**other_args)
pe.set_checkpoint() #save the state of the system
pe.save_data()
#Solver Module (can use without knowledge of the runtime system)
with ProblemExec(sys,{},Xnew=Xnext,ctx_fail_new=True) as pe:
        #do revertible math on the state of the system without concern for leaving it changed
...
# Combos Selection
By default (with no arguments) the run selects all active items with combo="default". The `combos` argument selects a specific set of combos as an outer selection. From this set, the `ign_combos` and `only_combos` arguments can be used to ignore or keep specific combos by exclusion or inclusion respectively.
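For example, a minimal sketch (`sys` follows the example above; the combo name is hypothetical):

.. code-block:: python

    #outer-select all combos, then ignore a specific one
    with ProblemExec(sys, combos='*', ign_combos='expensive_calcs') as pe:
        pe.solve_min(pe.Xref, pe.Yref)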
# Parameter Name Selection
The `slv_vars` argument can be used to select a specific set of solvables. From this set, the `ign_vars` and `only_vars` arguments can be used to ignore or select specific solvables based on exclusion or inclusion respectively. The `add_vars` argument can be used to add a specific set of solvables to the solver.
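For example, a sketch (the variable names and patterns are hypothetical):

.. code-block:: python

    #solve only variables under 'wheel.*', except the ignored 'wheel.speed'
    with ProblemExec(sys, slv_vars='wheel.*', ign_vars='wheel.speed') as pe:
        pe.solve_min(pe.Xref, pe.Yref)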
# Active Mode Handling
The `only_active` argument can be used to select only active items. The `activate` and `deactivate` arguments can be used to activate or deactivate specific solvables.
`add_obj` can be used to add an objective to the solver.
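For example, a sketch (the solver item names are hypothetical):

.. code-block:: python

    #only consider active items, but force 'thrust_eq' on and 'mass_obj' off
    with ProblemExec(sys, only_active=True, activate=['thrust_eq'], deactivate=['mass_obj'], add_obj=True) as pe:
        pe.solve_min(pe.Xref, pe.Yref)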
# Root Parameter Determination:
Any session may access the root session's parameters defined in the `root_defined` dictionary, such as `converged` or `data`, by calling `session.<parm>`; this eventually resolves `root._<parm>` via the __getattr__ method.
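For example, a sketch following the module example above:

.. code-block:: python

    #interior contexts read root state (the root stores `_converged`, `_data`, ...)
    with ProblemExec(sys, {}, level_name='sub') as pe:
        pe.solve_min(pe.Xref, pe.Yref)
        if pe.converged:
            rows = pe.data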
# Exit Mode Handling
The ProblemExec supports the following exit mode handling vars:
- `fail_revert`: Whether to revert the system state when an error occurs in the context. Default is True.
- `revert_last`: Whether to revert the state when the outermost (last) context exits. Default is True.
- `revert_every`: Whether to revert the state at the exit of every context. Default is True.
- `exit_on_failure`: Whether to exit on the first failure, rather than continuing. Default is True.
These vars control the behavior of the ProblemExec when an error occurs or when no solvables are selected.
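For example, a sketch (the defaults are shown explicitly for clarity):

.. code-block:: python

    #on an error the state is reverted and the error propagates (exit_on_failure=True)
    with ProblemExec(sys, fail_revert=True, revert_last=True, exit_on_failure=True) as pe:
        pe.solve_min(pe.Xref, pe.Yref)
        pe.exit_and_revert()  #or exit deliberately, reverting to the checkpoint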
"""
# TODO: define the
from engforge.logging import LoggingMixin
from engforge.system_reference import Ref
from engforge.dataframe import DataframeMixin, pandas
from engforge.solver_utils import *
from engforge.env_var import EnvVariable
import weakref
from scipy.integrate import solve_ivp
from collections import OrderedDict
import numpy as np
import pandas as pd
import datetime
class ProbLog(LoggingMixin):
pass
log = ProbLog()
import uuid
# TODO: implement add_vars feature, ie create a solver variable, or activate an existing one, from a key in system hierarchy format
# TODO: define the dataframe / data storage feature
min_opt = {"finite_diff_rel_step": 0.25, "maxiter": 10000}
min_kw_dflt = {
"tol": 1e-10,
"method": "SLSQP",
"jac": "cs",
"hess": "cs",
"options": min_opt,
}
# The KW Defaults for Solver via kw_dict
# IMPORTANT:!!! these group parameter names by behavior, they are as important as the following class, add/remove variables with caution
# these choices affect how solver-items are selected and added to the solver
slv_dflt_options = dict(
combos="default",
ign_combos=None,
only_combos=None,
add_obj=True,
slv_vars="*",
add_vars=None,
ign_vars=None,
only_vars=None,
only_active=True,
activate=None,
deactivate=None,
dxdt=None,
weights=None,
both_match=True,
obj=None,
)
# KW Defaults for the local context (what saves state for contexts / reverts etc)
dflt_parse_kw = dict(
fail_revert=True,
revert_last=True,
revert_every=True,
exit_on_failure=True,
pre_exec=True,
post_exec=True,
opt_fail=True,
level_name="top",
post_callback=None,
success_thresh=10,
copy_system=False,
run_solver=False,
min_kw=None,
save_mode="all",
x_start=None,
save_on_exit=False,
enter_refresh=True,
)
# can be found on session._<parm> or session.<parm>
root_defined = dict(
last_time=0,
time=0,
dt=0,
update_refs=None,
post_update_refs=None,
sys_refs=None,
slv_kw=None,
minimizer_kw=None,
data=None,
weights=None,
dxdt=None,
run_start=None,
run_end=None,
run_time=None,
all_refs=None,
num_refs=None,
converged=None,
comp_changed=False,
)
save_modes = ["vars", "nums", "all", "prob"]
transfer_kw = ["system", "_dxdt"]
root_possible = list(root_defined.keys()) + list("_" + k for k in root_defined.keys())
# TODO: output options extend_dataframe=True,return_dataframe=True,condensed_dataframe=True,return_system=True,return_problem=True,return_df=True,return_data=True
# TODO: connect save_data() output to _data table.
# TODO: move dataframe mixin here, system should return a dataframe, and the problem should be able to save data to the dataframe, call this class Problem(). With default behavior it could seem like a normal dataframe is returned on problem.return(*state,exit,revert...)
# Special exception classes handled in exit
class IllegalArgument(Exception):
"""an exception to exit the problem context as specified"""
pass
class ProblemExit(Exception):
"""an exception to exit the problem context, without error"""
revert: bool
prob: "ProblemExec"
def __init__(self, prob: "ProblemExec", revert: bool = None):
self.revert = revert
self.prob = prob
def __str__(self) -> str:
return f"ProblemExit[{self.prob}|rvt={self.revert}]"
class ProblemExitAtLevel(ProblemExit):
"""an exception to exit the problem context, without error"""
level: str
def __init__(self, prob: "ProblemExec", level: str, revert=None):
assert level is not None, "level must be defined"
assert isinstance(level, str), "level must be a string"
self.prob = prob
self.level = level.lower()
self.revert = revert
def __str__(self) -> str:
return f"ProblemExit[{self.prob}|lvl={self.level}|rvt={self.revert}]"
# TODO: develop subproblem strategy (multiple root cache problem in class cache)
# TODO: determine when components are updated, and refresh the system references accordingly.
# TODO: Map attributes/properties by component key and then autofix refs! (this is a big one), no refresh required. Min work
# TODO: component graph with pyee anything changed system to make a lazy observer system.
# TODO: plot levels to manage report output (with pdf graph publishing)
class ProblemExec:
"""
    Represents the execution context for a problem in the system. The ProblemExec class provides a uniform set of options for managing the state of the system and its solvables, establishing the selection of combos or de/active attributes of Solvables. Once created, any further entrances to ProblemExec return the same instance until the final exit is called.
## params:
- _problem_id: uuid for subproblems, or True for top level, None means uninitialized
"""
full_update = True # TODO: cant justify setting this to false for performance gains. accuracy comes first. see event based update
    # TODO: convert this to a system based cache where there is a unique problem for each system instance. On subproblem, copy a system and add it to the dictionary.
class_cache = None # ProblemExec is assigned below
    # this is class-wide, don't redefine it please
problems_dict = weakref.WeakValueDictionary()
system: "System"
session: "ProblemExec"
session_id = None
# problem state / per level
problem_id = None
entered: bool = False
exited: bool = False
# solution control (point to singleton/subproblem via magic getattr)
_last_time: float = 0
_time: float = 0
_dt: float = 0
_update_refs: dict
_post_update_refs: dict
_sys_refs: dict
_slv_kw: dict
_minimizer_kw: dict
_data: list
_weights: dict
x_start: dict
_dxdt: float
_run_start: float
_run_end: float
_run_time: float
_converged: bool
# Interior Context Options
enter_refresh: bool = (
True # TODO: allow this off (or lower impact) with event system
)
save_on_exit: bool = False
save_mode: str = "all"
level_name: str = None # target this context with the level name
level_number: int = 0 # TODO: keep track of level on the global context
pre_exec: bool = True
post_exec: bool = True
fail_revert: bool = True
revert_last: bool = True
revert_every: bool = True
exit_on_failure: bool = True
opt_fail: bool = True
raise_on_unknown: bool = True
copy_system: bool = False
success_thresh = (
1e6 # if system has `success_thresh` it will be assigned to the context
)
post_callback: callable = (
None # callback that will be called on the system each time it is reverted, it should take args(system,current_problem_exec)
)
run_solver: bool = (
False # for transient #i would love this to be=true, but there's just too much possible variation in application to make it so without some kind of control / continuity strategy. Dynamics are natural responses anyways, so solver use should be an advanced case for now (MPC/Filtering/ect later)
)
def __getattr__(self, name):
"""
        Called when an attribute is not found in the usual places, e.g. when interior contexts (anything that is not the root) are created without the top level's attributes; some attributes are then looked up on the parent session.
"""
# interior context lookup (when in active context, ie session exists)
if hasattr(self.class_cache, "session") and name in root_possible:
# revert to the parent session
if self.session_id != True and name.startswith("_"):
# self.info(f'get parent private {name}')
return getattr(self.class_cache.session, name)
elif name in root_defined:
# self.info(f'get parent public {name}')
return getattr(self.class_cache.session, "_" + name)
if name in root_defined: # public interface
# self.info(f'get root fallback {name}')
return self.__getattribute__("_" + name)
# Default behaviour
return self.__getattribute__(name)
def __init__(self, system, kw_dict=None, Xnew=None, ctx_fail_new=False, **opts):
"""
Initializes the ProblemExec.
#TODO: exit system should abide by update / signals options
#TODO: provide data storage options for dataframe / table storage history/ record keeping (ss vs transient data)
#TODO: create an option to copy the system and run operations on it, and options for applying the state from the optimized copy to the original system
:param system: The system to be executed.
        :param Xnew: The new state of the system to set at the start of the problem context; it will revert after the problem exits. Optional.
:param ctx_fail_new: Whether to raise an error if no execution context is available, use in utility methods ect. Default is False.
:param kw_dict: A keyword argument dictionary to be parsed for solver options, and removed from the outer context. Changes are made to this dictionary, so they are removed automatically from the outer context, and thus no longer passed to interior vars.
        :param dxdt: The dynamics integration method. Default is None, meaning dynamic vars are not considered for minimization unless otherwise specified. Steady state can be specified by dxdt=0, in which case all dynamic vars are treated as solver variables with the constraint that their rate of change is zero. If a dictionary is passed, the dynamic vars are treated as solver variables with the constraint that their rate of change equals the value in the dictionary, and all other unspecified rates are zero (steady). See the sketch at the end of this parameter list.
#### Solver Selection Options
        :param combos: The selection of combos. Default is 'default' (select the default combos).
:param ign_combos: The combos to be ignored.
:param only_combos: The combos to be selected.
:param add_obj: Whether to add an objective to the solver. Default is True.
:param slv_vars: The selection of solvables. Default is '*' (select all).
:param add_vars: The solvables to be added to the solver.
:param ign_vars: The solvables to be ignored.
:param only_vars: The solvables to be selected.
:param only_active: Whether to select only active items. Default is True.
:param activate: The solvables to be activated.
:param deactivate: The solvables to be deactivated.
        :param fail_revert: Whether to revert the system state when an error occurs in the context. Default is True.
        :param revert_last: Whether to revert the state when the outermost (last) context exits. Default is True.
        :param revert_every: Whether to revert the state at the exit of every context. Default is True.
:param exit_on_failure: Whether to exit on failure, or continue on. Default is True.
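        A sketch of the dxdt forms (assuming `sys` as in the module example; the var name in the dictionary form is hypothetical):

        .. code-block:: python

            #steady state: all dynamic rates constrained to zero
            with ProblemExec(sys, dxdt=0) as pe:
                pe.solve_min(pe.Xref, pe.Yref)

            #hold a specific rate: comp.x changes at 1.0, all other rates held at zero
            with ProblemExec(sys, dxdt={'comp.x': 1.0}) as pe:
                pe.solve_min(pe.Xref, pe.Yref)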
"""
self.dynamics_updated = False # it is known
if kw_dict is None:
# kw_dict is stateful so you can mix system & context args together, and ensure context args are removed. in the case this is unused, we'll create an empty dict to avoid errors
kw_dict = {}
        # storage options
if opts.pop("persist", False) or kw_dict.pop("persist", False):
self.persist_contexts()
# temp solver storage #TODO
# self.solver_hist = expiringdict.ExpiringDict(100, 60)
if self.log_level < 5:
if hasattr(self.class_cache, "session"):
self.debug(
f"subctx{self.level_number}| keywords: {kw_dict} and misc: {opts}"
)
else:
self.debug(f"context| keywords: {kw_dict} and misc: {opts}")
# special cases for parsing
# parse the options to change behavior of the context
level_name = None
if opts and "level_name" in opts:
level_name = opts.pop("level_name").lower()
if kw_dict and "level_name" in kw_dict:
level_name = kw_dict.pop("level_name").lower()
# solver min-args wrt defaults
min_kw = None
if opts and "min_kw" in opts:
min_kw = kw_dict.pop("min_kw")
if kw_dict and "min_kw" in kw_dict:
min_kw = kw_dict.pop("min_kw")
mkw = min_kw_dflt.copy()
if min_kw is None:
min_kw = mkw
else:
mkw.update(min_kw)
self._minimizer_kw = mkw
# Merge kwdict(stateful) and opts (current level)
# solver vars should be static for a problem and subcontexts, however the default vars can change. Subproblems allow for the solver vars to be changed on its creation.
opt_in, opt_out = {}, {}
if opts:
            # these go to the context instance options
opt_in = {k: v for k, v in opts.items() if k in dflt_parse_kw}
# these go to system establishment
opt_out = {k: v for k, v in opts.items() if k not in opt_in}
if kw_dict is None:
kw_dict = {}
else:
            # these go to the context instance options
kw_in = {k: v for k, v in kw_dict.items() if k in dflt_parse_kw}
opt_in.update(kw_in)
kw_out = {k: v for k, v in kw_dict.items() if k not in opt_in}
# these go to system establishment
opt_out.update(kw_out)
# remove problem options from dict (otherwise passed along to system!)
for k in kw_in:
kw_dict.pop(k)
        # Define the handling of rate integrals
        dxdt = None  # by default don't consider dynamics
        if "dxdt" in opts and opts["dxdt"] is not None:
            dxdt = opts.pop("dxdt")
        elif "dxdt" in kw_dict and kw_dict["dxdt"] is not None:
            dxdt = kw_dict.pop("dxdt")
if dxdt is not None and dxdt is not False:
if dxdt == 0:
pass
elif dxdt is True:
pass
elif isinstance(dxdt, dict): # dxdt is a dictionary
# provide a set of values or function to have the solver solve for
pass
else:
raise IllegalArgument(f"bad dxdt value {dxdt}")
# Check if there is an existing problem session, else create it
if hasattr(self.class_cache, "session"):
# mirror the state of session (exactly)
copy_vals = {
k: v
for k, v in self.class_cache.session.__dict__.items()
if k in dflt_parse_kw or k in transfer_kw
}
self.__dict__.update(copy_vals)
self._problem_id = int(uuid.uuid4())
self.problems_dict[self._problem_id] = self # carry that weight
self.session_id = int(uuid.uuid4())
self.class_cache.session._prob_levels[self.level_name] = self
# error if the system is different (it shouldn't be!)
if self.system is not system:
# TODO: subproblems allow different systems, but the top level should be the same
# idea - use (system,pid) as key for problems_dict, (system,True) would be root problem. This breaks checking for `class_cache.session` though one could gather that from the root problem key`
raise IllegalArgument(
f"somethings wrong! change of comp! {self.system} -> {system}"
)
# modify things from the input
if level_name is None:
# your new id
self.level_name = "ctx_" + str(int(self._problem_id))[0:15]
else:
self.level_name = level_name
if opt_in:
self.__dict__.update(opt_in) # level options ASAP
self.temp_state = Xnew # input state exception to this
if log.log_level < 5:
self.msg(f"setting execution context with {opt_in}| {opt_out}")
            # each context may request to be reverted, so we store the state of each execution context, overriding the outer context's x_start
self.set_checkpoint()
elif ctx_fail_new:
raise IllegalArgument(f"no execution context available")
else:
# add the prob options to the context and establish system
self.__dict__.update(opt_in)
self._problem_id = True # this is the top level
self.problems_dict[self._problem_id] = self # carry that weight
self._prob_levels = {}
self._converged = None
self._dxdt = dxdt
self.reset_data()
# supply the level name default as top if not set
if level_name is None:
self.level_name = "top"
else:
self.level_name = level_name
self.temp_state = Xnew
self.establish_system(system, kw_dict=kw_dict, **opt_out)
# Finally we record where we started!
self.set_checkpoint()
if log.log_level < 10:
self.info(f"new execution context for {system}| {opts} | {self._slv_kw}")
elif log.log_level <= 3:
self.msg(f"new execution context for {system}| {self._slv_kw}")
def reset_data(self):
"""reset the data storage"""
# the data storage!!
# TODO: add buffer type, or disk cache
self._data = {} # index:row_dict
self._index = 0 # works for time or index
def establish_system(self, system, kw_dict, **kwargs):
"""caches the system references, and parses the system arguments"""
from engforge.solver import SolverMixin
from engforge.system import System
if self.copy_system:
system = system.copy_config_at_state()
# place me here after system has been modified
self.system = system
# pass args without creating singleton (yet)
self.session_id = int(uuid.uuid4())
self._run_start = datetime.datetime.now()
self.name = system.name + "-" + str(self.session_id)[:8]
if log.log_level < 5:
self.info(f"establish {system}| {kw_dict} {kwargs}")
assert isinstance(
self.system, SolverMixin
), "only solveable interfaces are supported for execution context"
self.system._last_context = self # set the last context to this one
if hasattr(self.system, "success_thresh") and isinstance(
self.system.success_thresh, (int, float)
):
self.success_thresh = self.system.success_thresh
        # Extract solver vars and set them on this object; they will be distributed to any further execution contexts via the monkey patch above
in_kw = self.get_extra_kws(kwargs, slv_dflt_options, use_defaults=False)
self._slv_kw = self.get_extra_kws(kw_dict, slv_dflt_options, rmv=True)
self._slv_kw.update(in_kw) # update with input!
self.refresh_references()
# Get solver weights
self._weights = self._slv_kw.get("weights", None)
# Grab inputs and set to system
for k, v in dflt_parse_kw.items():
if k in self._slv_kw:
setattr(self, k, self._slv_kw[k])
if log.log_level < 5:
self.msg(f"established sys context: {self} {self._slv_kw}")
@property
def sesh(self):
"""caches the property for the session"""
if hasattr(self, "inst_sesh"):
return self.inst_sesh
sesh = self.get_sesh()
return sesh
def get_sesh(self, sesh=None):
"""get the session"""
out = sesh
if not sesh:
if hasattr(self.class_cache, "session"):
out = self.class_cache.session
elif self._problem_id == True:
out = self
if out:
self.inst_sesh = out
return out
# @classmethod
# def cls_get_sesh(cls, sesh=None):
# """get the session"""
# out = sesh
# if not sesh:
# if hasattr(cls.class_cache, "session"):
# out = self.class_cache.session
# elif self._problem_id == True:
# out = self
# if out:
# self.inst_sesh = out
# return out
# @property
# def index(self):
# sesh = self.get_sesh()
# if not sesh._data:
# return 0
# else:
# return max(list(sesh._data.keys()))
#
# @property
# def last_index(self):
# sesh = self.index
# if sesh == 0:
# return None
# else:
# return sesh - 1
#
# @property
# def next_index(self):
# sesh = self.index
# if sesh == 0:
# return 1
# else:
# return sesh + 1
# Update Methods
def refresh_references(self, sesh=None):
"""refresh the system references"""
if sesh is None:
sesh = self.sesh
if self.log_level < 5:
self.warning(f"refreshing system references")
sesh.full_refresh(sesh=sesh)
sesh.min_refresh(sesh=sesh)
def update_methods(self, sesh=None):
# Get the update method refs
sesh = sesh if sesh is not None else self.sesh
sesh._update_refs = sesh.system.collect_update_refs()
# TODO: find subsystems that are not-subsolvers and execute them
sesh._post_update_refs = sesh.system.collect_post_update_refs()
sesh.update_dynamics(sesh=sesh)
def update_dynamics(self, sesh=None):
# apply changes to the dynamics models
sesh = sesh if sesh is not None else self.sesh
if self.dynamic_comps:
self.info(f"update dynamics")
self.system.setup_global_dynamics()
def full_refresh(self, sesh=None):
"""a more time consuming but throughout refresh of the system"""
if self.log_level < 5:
self.info(f"full refresh")
        sesh = sesh if sesh is not None else self.sesh
        check_dynamics = sesh.check_dynamics
sesh._num_refs = sesh.system.system_references(
numeric_only=True,
none_ok=True,
only_inst=False,
ignore_none_comp=False,
recache=True,
)
sesh._sys_refs = sesh.system.solver_vars(
check_dynamics=check_dynamics,
addable=sesh._num_refs,
**sesh._slv_kw,
)
sesh.update_methods(sesh=sesh)
def min_refresh(self, sesh=None):
"""what things need to be refreshed per execution, this is important whenever items are replaced"""
        # TODO: replace this function with an event based responsibility model.
sesh = sesh if sesh is not None else self.sesh
if self.log_level < 5:
self.info(f"min refresh")
if sesh.full_update:
# TODO: dont require this
sesh.full_refresh(sesh=sesh)
# final ref's after update
# after updates
sesh._all_refs = sesh.system.system_references(
recache=True,
check_config=False,
ignore_none_comp=False,
none_ok=True,
only_inst=False,
)
# sesh._attr_sys_key_map = sesh.attribute_sys_key_map
# Problem Variable Definitions
sesh.Xref = sesh.all_problem_vars
sesh.Yref = sesh.sys_solver_objectives()
cons = {} # TODO: parse additional constraints
sesh.constraints = sesh.sys_solver_constraints(cons)
def print_all_info(self, keys: str = None, comps: str = None):
"""
Print all the information of each component's dictionary.
Parameters:
        keys (str, optional): A comma-separated set of patterns to match dictionary keys. Only keys matching these patterns will be included in the output.
        comps (str, optional): A comma-separated list of component sys names to filter. Only information for these components will be printed.
Returns: None (except stdout :)
"""
from pprint import pprint
        keys = keys.split(",") if keys else None
        comps = (comps + ",").split(",") if comps else None  # always include the top level
print(f"CONTEXT: {self}")
mtch = lambda key, ptrns: any(
[fnmatch.fnmatch(key.lower(), ptn.lower()) for ptn in ptrns]
)
# check your comps
itrs = self.all_comps.copy()
itrs[""] = Ref(self.system, "", True, False)
# check your comps
for cn, comp in itrs.items():
if comps is not None and not mtch(cn, comps):
continue
dct = comp.value().as_dict
if keys: # filter keys
dct = {k: v for k, v in dct.items() if mtch(k, keys)}
if dct:
print(f'INFO: {cn if cn else "<problem.system>"}')
pprint(dct)
print("-" * 80)
@property
def check_dynamics(self):
sesh = self.sesh
return sesh._dxdt is not None and sesh._dxdt is not False
# Context Manager Interface
def __enter__(self):
# Set the new state
if self.entered:
# TODO: enable env-var STRICT MODE to fail on things like this
self.warning(f"context already entered!")
elif self.log_level < 10:
self.debug(
f"enter context: {self.level_name} {self._dxdt} {self.dynamics_updated}"
)
# Important managed updates / refs from Xnew input
self.activate_temp_state()
self.entered = True
# signals / updates
if self.pre_exec:
self.pre_execute()
# TODO: create a component-slot ref-update graph, and update the system references accordingly.
# TODO: map the signals to the system references, and update the system references accordingly.
# TODO:
        # transients won't update components/methods dynamically (or shouldn't), so we can update the system references once and be done with it; that isn't necessary unless a component changes or has a unique reference update system in general (economics / component-iterators)
sesh = self.sesh
if not sesh._dxdt is True and self.enter_refresh:
sesh.min_refresh(sesh=sesh)
elif sesh.dynamics_updated:
sesh.update_dynamics(sesh=sesh)
# Check for existing session
if sesh not in [None, self]:
self.msg(f"entering existing execution context")
if not isinstance(self, self.class_cache):
self.warning(f"change of execution class!")
# global level number
self.class_cache.level_number += 1
self.class_cache.session._prob_levels[self.level_name] = self
return self.class_cache.session # appear as top
# return New
self.class_cache.session = self
self.class_cache.level_number = 0
if self.log_level < 10:
refs = {k: v for k, v in self.sesh._sys_refs.get("attrs", {}).items() if v}
self.debug(
f"creating execution context for {self.system}| {self._slv_kw}| {refs}"
)
return self # return the local problem context, use self.sesh to get top values
def __exit__(self, exc_type, exc_value, traceback):
        # define exit action; to handle the error here return True, otherwise the error propagates up to the top level
self.exited = True
if self.log_level < 10:
self.debug(f"exit action {exc_type} {exc_value}")
        # Last opportunity to update the system state
if self.post_exec:
            # a component custom callback + signals
self.post_execute()
if self.post_callback:
# a context custom callback
self.post_callback()
# save state to dataframe
if self.save_on_exit:
self.save_data()
# sesh = self.sesh #this should be here
# if self.level_name in sesh._prob_levels:
# sesh._prob_levels.pop(self.level_name)
        # Exit Scenario (boolean return is important for context manager exit handling in the hierarchy)
if isinstance(exc_value, ProblemExit):
if self.log_level < 7:
self.debug(f"exit action {exc_type}| {exc_value.__dict__}")
# first things first
if exc_value.revert:
self.revert_to_start()
if self.pre_exec:
self.pre_execute()
lvl_match = False
# Decide our exit conditon (if we should exit)
if isinstance(exc_value, ProblemExitAtLevel):
# should we stop?
lvl_match = exc_value.level == self.level_name
if lvl_match:
if self.log_level <= 11:
self.debug(f"exit at level {exc_value}")
ext = True
else:
if self.log_level <= 5:
self.msg(f"exit not at level {exc_value}")
ext = False
# Check if we missed a level name and its the top level, if so then we raise a real error!
# always exit with level_name='top' at outer context
if (
not ext
and self.class_cache.session is self
and exc_value.level == "top"
):
if self.log_level <= 11:
self.debug(f"exit at top")
ext = True # top override
elif self.class_cache.session is self and not ext:
# never ever leave the top level without deleting the session
self.class_cache.level_number = 0
if type(self.problems_dict) is not dict:
self.problems_dict.pop(self._problem_id, None)
del self.class_cache.session
raise KeyError(f"cant exit to level! {exc_value.level} not found!!")
else:
if self.log_level <= 18:
self.info(f"problem exit revert={exc_value.revert}")
ext = True # basic exit is one level up
self.clean_context()
if type(self.problems_dict) is not dict:
self.problems_dict.pop(self._problem_id, None)
return ext
        # default exit scenarios
elif exc_type is not None:
ext = self.error_action(exc_value)
else:
ext = self.exit_action()
self.clean_context()
if type(self.problems_dict) is not dict:
self.problems_dict.pop(self._problem_id, None)
return ext
def debug_levels(self):
"""debug the levels of the context"""
if hasattr(self.class_cache, "session"):
for k, v in self.class_cache.session._prob_levels.items():
self.info(f"level: {k} | {v} | {v.x_start}")
else:
raise IllegalArgument(f"no session available")
# Multi Context Exiting:
# TODO: rethink this
def persist_contexts(self):
"""convert all contexts to a new storage format"""
self.info(f"persisting contexts!")
current_problems = self.problems_dict
ProblemExec.problems_dict = {}
for k, v in current_problems.items():
self.problems_dict[k] = v # you will go on!
def discard_contexts(self):
"""discard all contexts"""
current_problems = self.problems_dict
ProblemExec.problems_dict = weakref.WeakValueDictionary()
for k, v in current_problems.items():
ProblemExec.problems_dict[k] = v # you will go on!
def reset_contexts(self, fail_if_discardmode=True):
"""reset all contexts to a new storage format"""
if isinstance(self.problems_dict, dict):
ProblemExec.problems_dict = {}
elif fail_if_discardmode:
raise IllegalArgument(
f"cant reset contexts! {self.problems_dict} while not in persistance mode"
)
def exit_with_state(self):
raise ProblemExit(self, revert=False)
def exit_and_revert(self):
raise ProblemExit(self, revert=True)
def exit_to_level(self, level: str, revert=False):
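        """raise ProblemExitAtLevel to unwind enclosing contexts until the one whose level_name matches `level`, optionally reverting state along the way"""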
raise ProblemExitAtLevel(self, level=level, revert=revert)
def exit_action(self):
"""handles the exit action wrt system"""
EOL = self.class_cache.session is self or self.level_name == "top"
if self.revert_last and EOL:
if self.log_level <= 8:
self.debug(f"revert last!")
self.debug(f"revert to{self.x_start}")
self.revert_to_start()
# run execute
if self.pre_exec:
self.pre_execute()
elif self.revert_every:
if self.log_level <= 8:
self.debug(f"revert to{self.x_start}")
self.revert_to_start()
# run execute
if self.pre_exec:
self.pre_execute()
# TODO: add exit on success option
return True # continue as normal
def error_action(self, error):
"""handles the error action wrt to the problem"""
if self.log_level <= 11:
self.debug(f" with input: {self.kwargs}")
if self.fail_revert:
self.revert_to_start()
if self.exit_on_failure:
self.error(error, f"error in execution context")
return False # send me up
else:
self.warning(f"error in execution context: {error}")
return True # our problem will go on
def save_data(self, index=None, force=False, **add_data):
"""save data to the context"""
sesh = self.sesh
if not self.exited and self.post_exec:
# a context custom callback
sesh.post_execute()
if force or not sesh.data or sesh.system.anything_changed:
out = sesh.output_state
if index is None and sesh._dxdt == True: # integration
index = sesh._time
elif index is None:
index = sesh._index
if add_data:
out.update(add_data)
if sesh._dxdt == True:
out["time"] = sesh._time
out["index"] = index
sesh._data[index] = out
# if we are integrating, then we dont increment the index
if sesh._dxdt != True:
sesh._index += 1
# reset the data for changed items
sesh.system._anything_changed = False
self.debug(f"data saved = {index}")
elif self.log_level < 15:
self.warning(f"no data saved, nothing changed")
def clean_context(self):
if hasattr(self.class_cache, "session") and self.class_cache.session is self:
if self.log_level <= 8:
self.debug(f"closing execution session")
self.class_cache.level_number = 0
del self.class_cache.session
elif hasattr(self.class_cache, "session"):
# global level number
self.class_cache.level_number -= 1
# if we are the top level, then we mark the session runtime/messages
if self.session_id == True:
self._run_end = datetime.datetime.now()
self._run_time = self._run_end - self._run_start
if self.log_level <= 10:
self.debug(
f"EXIT[{self.system.identity}] run time: {self._run_time}",
lvl=5,
)
# time context
def set_time(self, time, dt):
self._last_time = lt = self._time
self._time = time
dt_calc = time - lt
self._dt = dt if dt_calc <= 0 else dt_calc
# self.system.set_time(time) #system times / subcomponents too
def integrate(self, endtime, dt=0.001, max_step_dt=0.01, X0=None, **kw):
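        """integrate the system's transient state from the current time to `endtime` with scipy's solve_ivp, using `integral_rate` as the rate function; `dt` sets the output grid and `max_step_dt` caps the solver step"""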
# Unpack Transient Problem
sesh = self.sesh
intl_refs = sesh.integrator_var_refs # order forms problem basis
sesh.prv_ingtegral_refs = intl_refs # for rate function
refs = sesh._sys_refs
system = sesh.system
min_kw = sesh._minimizer_kw
if min_kw is None:
min_kw = {}
if dt > max_step_dt:
self.warning(f"dt {dt} > max_step_dt {max_step_dt}!")
dt = max_step_dt
if self.log_level < 15:
self.info(f"simulating {system},{sesh}| int:{intl_refs} | refs: {refs}")
if not intl_refs:
raise Exception(f"no transient parameters found")
x_cur = {k: v.value(v.comp, sesh) for k, v in intl_refs.items()}
if self.log_level < 10:
self.debug(f"initial state {X0} {intl_refs}| {refs}")
if X0 is None:
# get current
X0 = x_cur
# add any missing solver vars existing in the system
        if set(X0) != set(x_cur):
            x_cur.update(X0)  # dict.update returns None, so merge then reassign
            X0 = x_cur
# this will fail if X0 doesn't have solver vars!
X0 = np.array([X0[p] for p in intl_refs])
Time = np.arange(sesh.system.time, endtime + dt, dt)
rate_kw = {"min_kw": min_kw, "dt": dt}
        # get the problem variables
Xss = sesh.problem_opt_vars
Yobj = sesh.final_objectives
# run the simulation from the current state to the endtime
ans = solve_ivp(
sesh.integral_rate,
[sesh.system.time, endtime],
X0,
method="RK45",
t_eval=Time,
max_step=max_step_dt,
args=(dt, Xss, Yobj),
**kw,
)
print(ans)
return ans
def integral_rate(self, t, x, dt, Xss=None, Yobj=None, **kw):
"""provides the dynamic rate of the system at time t, and state x"""
sesh = self.sesh
intl_refs = sesh.prv_ingtegral_refs # cached in self.integral()
refs = sesh._sys_refs
system = sesh.system
out = {p: np.nan for p in intl_refs}
Xin = {p: x[i] for i, p in enumerate(intl_refs)}
if self.log_level < 10:
self.info(f"sim_iter {t} {x} {Xin}")
with ProblemExec(
system,
level_name="tr_slvr",
Xnew=Xin,
revert_last=False,
revert_every=False,
dxdt=True,
) as pbx:
# test for record time
self.set_time(t, dt)
# save data at the start
pbx.save_data() # TODO: check_enable/ rate_check
# ad hoc time integration
for name, trdct in pbx.integrators.items():
if self.log_level <= 10:
self.info(
f"updating {trdct.var}|{trdct.var_ref.value(self.system,self)}<-{trdct.rate}|{trdct.current_rate}|{trdct.rate_ref.value(trdct.comp,system.last_context)}"
)
print(getattr(self.system, trdct.var, None))
print(getattr(self.system, trdct.rate, None))
out[trdct.var] = trdct.current_rate
# dynamics
for compnm, compdict in pbx.dynamic_comps.items():
comp = compdict # ["comp"]
if not comp.dynamic_state_vars and not comp.dynamic_input_vars:
continue # nothing to do...
Xds = np.array([r.value() for r in comp.Xt_ref.values()])
Uds = np.array([r.value() for r in comp.Ut_ref.values()])
# time updated in step
# system.info(f'comp {comp} {compnm} {Xds} {Uds}')
dxdt = comp.step(t, dt, Xds, Uds, True)
for i, (p, ref) in enumerate(comp.Xt_ref.items()):
out[(f"{compnm}." if compnm else "") + p] = dxdt[i]
# solvers
if self.run_solver and Xss and Yobj and self.solveable:
# TODO: add in any transient
with ProblemExec(
system,
level_name="ss_slvr",
revert_last=False,
revert_every=False,
dxdt=True,
) as pbx:
ss_out = pbx.solve_min(Xss, Yobj, **self._minimizer_kw)
if ss_out["ans"].success:
if self.log_level <= 9:
self.info(
f'exiting solver {t} {ss_out["Xans"]} {ss_out["Xstart"]}'
)
pbx.set_ref_values(ss_out["Xans"], scope="intgrl")
pbx.exit_to_level("ss_slvr", False)
else:
self.warning(
f'solver failed to converge {ss_out["ans"].message} {ss_out["Xans"]} {ss_out["X0"]}'
)
if pbx.opt_fail:
pbx.exit_to_level("sim", pbx.fail_revert)
else:
pbx.exit_to_level("ss_slvr", pbx.fail_revert)
V_dxdt = np.array([out[p] for p in intl_refs])
if self.log_level <= 10:
self.info(f"exiting transient {t} {V_dxdt} {Xin}")
pbx.exit_to_level("tr_slvr", False)
if any(np.isnan(V_dxdt)):
self.warning(f"solver got infeasible: {V_dxdt}|{Xin}")
pbx.exit_and_revert()
# TODO: handle this better, seems to cause a warning
raise ValueError(f"infeasible! nan result {V_dxdt} {out} {Xin}")
elif self.log_level <= 5:
self.debug(f"rate {self._dt} {t:5.3f}| {x}<-{V_dxdt} {Xin}")
return V_dxdt
def solve_min(self, Xref=None, Yref=None, output=None, **kw):
"""
        Solve the minimization problem using the given vars and constraints, and set the system state to the solution, based on the following inputs.
        If no explicit objectives are given, a residual objective is constructed so the problem is solved as a root/feasibility problem over the given vars.
:param Xref: The reference input values.
:param Yref: The reference objective values to minimize.
:param output: The output dictionary to store the results. (default: None)
:param fail: Flag indicating whether to raise an exception if the solver doesn't converge. (default: True)
:param kw: Additional keyword arguments.
:return: The output dictionary containing the results.
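        A minimal usage sketch, inside an entered context `pe` (see the module example):

        .. code-block:: python

            out = pe.solve_min()   #defaults to pe.Xref and the final objectives
            if out['success']:
                print(out['Xans'], out['Yobj'])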
"""
sesh = self.sesh
if Xref is None:
Xref = sesh.Xref
if Yref is None:
Yref = sesh.final_objectives
thresh = kw.pop("thresh", sesh.success_thresh)
# TODO: options for solver detail in response
dflt = {
"Xstart": Ref.refset_get(Xref, prob=sesh),
"Ystart": Ref.refset_get(Yref, prob=sesh),
"Xans": None,
"success": None,
"Xans": None,
"Yobj": None,
"Ycon": None,
"ans": None,
"weights": sesh._weights,
"constraints": sesh.constraints,
}
if output:
dflt.update(output)
output = dflt
else:
output = dflt
if len(Xref) == 0:
self.debug(f"no variables found for solver: {kw}")
# None for `ans` will not trigger optimization failure
return output
# override constraints input
kw.update(sesh.constraints)
if len(kw["bounds"]) != len(Xref):
raise ValueError(
f"bounds {len(sesh.constraints['bounds'])} != Xref {len(Xref)}"
)
if self.log_level < 10:
self.debug(f"minimize {Xref} {Yref} {kw}")
if sesh._weights is not None:
kw["weights"] = sesh._weights
sesh._ans = refmin_solve(sesh.system, self, Xref, Yref, **kw)
output["ans"] = sesh._ans
sesh.handle_solution(sesh._ans, Xref, Yref, output)
return output
def handle_solution(self, answer, Xref, Yref, output):
        # TODO: move exit condition handling somewhere else, reduce cross over from process_ans
sesh = self.sesh
if self.log_level < 10:
self.info(f"handiling solution: {answer}")
thresh = sesh.success_thresh
vars = list(Xref)
# Output Results
Xa = {p: answer.x[i] for i, p in enumerate(vars)}
output["Xans"] = Xa
Ref.refset_input(Xref, Xa, scope="solvd")
Yout = {p: yit.value(yit.comp, self) for p, yit in Yref.items()}
output["Yobj"] = Yout
Ycon = {}
if sesh.constraints["constraints"]:
x_in = answer.x
for c, k in zip(sesh.constraints["constraints"], sesh.constraints["info"]):
cv = c["fun"](x_in, self, {})
Ycon[k] = cv
output["Ycon"] = Ycon
de = answer.fun
        if answer.success and (de < thresh if thresh else True):
sesh.system._converged = True # TODO: put in context
sesh._converged = True
output["success"] = True
elif answer.success:
# out of threshold condition
self.warning(
f"solver didnt meet threshold: {de} <? {thresh} ! {answer.x} -> residual: {answer.fun}"
)
sesh.system._converged = False
sesh._converged = False
output["success"] = False # only false with threshold
else:
sesh.system._converged = False
sesh._converged = False
if self.opt_fail:
raise Exception(f"solver didnt converge: {answer}")
else:
self.warning(f"solver didnt converge: {answer}")
output["success"] = False
return output
# Solver Parsing Methods
def sys_solver_objectives(self, **kw):
"""gathers variables from solver vars, and attempts to locate any input_vars to add as well. use exclude_vars to eliminate a variable from the solver"""
sys_refs = self.sesh._sys_refs
# Convert result per kind of objective (min/max ect)
objs = sys_refs.get("attrs", {}).get("solver.obj", {})
return {k: v for k, v in objs.items()}
def pos_obj(self, ref):
"""converts an objective to a positive value"""
def f(sys, prob):
return 1 + ref.value(sys, prob) ** 2
return ref.copy(key=f)
@property
def final_objectives(self) -> dict:
"""returns the final objective of the system, depending on mixed objective, equalities, and constraints"""
sesh = self.sesh
Yobj = sesh.problem_objs
Yeq = sesh.problem_eq
Xss = sesh.problem_opt_vars
if Yobj:
return Yobj # here is your application objective, sir
# now make up an objective
elif not Yobj and Yeq:
# TODO: handle case of Yineq == None with root solver
self.info(f"making Yobj from Yeq: {Yeq}")
Yobj = {k: self.pos_obj(v) for k, v in Yeq.items()}
elif not Yobj:
# minimize the product of all vars, so the smallest value is the best that satisfies all constraints
if self.session_id == True:
self.info(f"making Yobj from X: {Xss}")
def dflt(sys, prob) -> float:
out = 1
for k, v in prob.problem_opt_vars.items():
val = v.value(sys, prob)
                    # accumulate a smooth norm of each variable; values much
                    # larger than 1 add a large penalty, so minimizing this
                    # favors the smallest values that satisfy all constraints
                    out = out + (1 + val**2) ** 0.5
                return out
Yobj = {"smallness": Ref(sesh.system, dflt)}
return Yobj # our residual based objective
def sys_solver_constraints(self, add_con=None, combo_filter=True, **kw):
"""formatted as arguments for the solver"""
from engforge.solver_utils import create_constraint
sesh = self.sesh
Xrefs = sesh.Xref
system = sesh.system
sys_refs = sesh._sys_refs
all_refz = sesh.ref_attrs
extra_kw = self.kwargs
# TODO: move to kwarg parsing on setup
deactivated = (
ext_str_list(extra_kw, "deactivate", [])
if "deactivate" in extra_kw and extra_kw["deactivate"]
else []
)
activated = (
ext_str_list(extra_kw, "activate", [])
if "activate" in extra_kw and extra_kw["activate"]
else []
)
slv_inst = sys_refs.get("type", {}).get("solver", {})
trv_inst = {v.var: v for v in sys_refs.get("type", {}).get("time", {}).values()}
sys_refs = sys_refs.get("attrs", {})
if add_con is None:
add_con = {}
# The official definition of X var order
Nstates = len(Xrefs)
Xvars = list(Xrefs) # get names of solvers + dynamics
# constraints lookup
bnd_list = [[None, None]] * Nstates
con_list = []
con_info = [] # names of constraints
constraints = {
"constraints": con_list,
"bounds": bnd_list,
"info": con_info,
}
if isinstance(add_con, dict):
# Remove None Values
nones = {k for k, v in add_con.items() if v is None}
for ki in nones:
constraints.pop(ki, None)
assert all(
[callable(v) for k, v in add_con.items()]
), f"all custom input for constraints must be callable with X as argument"
constraints["constraints"].extend(
[v for k, v in add_con.items() if v is not None]
)
if add_con is False:
constraints = {} # youre free!
return constraints
# Add Constraints
ex_arg = {"con_args": (), **kw}
# TODO: dynamic limits
# Establish Anonymous Problem Constraint Refs
for slvr, ref in sesh.problem_opt_vars.items():
assert not all(
(slvr in slv_inst, slvr in trv_inst)
), f"solver and integrator share parameter {slvr} "
if slvr in slv_inst:
slv = slv_inst[slvr]
slv_var = True # mark a static varible
elif slvr in trv_inst:
slv = trv_inst[slvr]
slv_var = False # a dynamic variable
else:
self.warning(f"no solver instance for {slvr} ")
continue
slv_constraints = slv.constraints
if log.log_level < 7:
self.debug(f"constraints {slvr} {slv_constraints}")
# combine the independent variables into one, but allow multiple dependent objectives for constraints/objectives
for ctype in slv_constraints:
cval = ctype["value"]
kind = ctype["type"]
var = ctype["var"]
if log.log_level < 3:
self.msg(f"const: {slvr} {ctype}")
if cval is not None and slvr in Xvars:
# Check for combos & activation
combos = None
if "combos" in ctype:
combos = ctype["combos"]
combo_var = ctype["combo_var"]
active = ctype.get("active", True)
in_activate = (
any([arg_var_compare(combo_var, v) for v in activated])
if activated
else False
)
in_deactivate = (
any([arg_var_compare(combo_var, v) for v in deactivated])
if deactivated
else False
)
if log.log_level <= 5:
self.debug(f"filter combo: {ctype}=?{extra_kw}")
# Check active or activated
if not active and not activated:
if log.log_level < 3:
self.msg(f"skip con: inactive {var} {slvr} {ctype}")
continue
elif not active and not in_activate:
if log.log_level < 3:
self.msg(f"skip con: inactive {var} {slvr} {ctype}")
continue
elif active and in_deactivate:
if log.log_level < 3:
self.msg(f"skip con: deactivated {var} {slvr} ")
continue
if combos and combo_filter:
filt = filter_combos(combo_var, slv, extra_kw, combos)
if not filt:
if log.log_level < 5:
self.debug(
f"filtering constraint={filt} {var} |{combos}"
)
continue
if log.log_level < 10:
self.debug(f"adding var constraint {var,slvr,ctype,combos}")
# get the index of the variable
x_inx = Xvars.index(slvr)
# lookup rates for overlapping dynamic variables
rate_val = None
if sesh._dxdt is not None:
if isinstance(sesh._dxdt, dict) and not slv_var:
rate_val = sesh._dxdt.get(slvr, 0)
elif not slv_var:
rate_val = 0
# add the dynamic parameters when configured
if not slv_var and rate_val is not None:
# print(f'adding dynamic constraint {slvr} {rate_val}')
# if kind in ('min','max') and slvr in Xvars:
varref = Xrefs[slvr]
# varref = slv.rate_ref
# Ref Case
ccst = ref_to_val_constraint(
system,
varref.comp,
system.last_context,
Xrefs,
varref,
kind,
rate_val,
**kw,
)
# con_list.append(ccst)
con_info.append(
f"dxdt_{varref.comp.classname}.{slvr}_{kind}_{cval}"
)
con_list.append(ccst)
# elif slv_var:
elif slv_var:
# establish simple bounds w/ solver
if (
kind in ("min", "max")
and slvr in Xvars
and isinstance(cval, (int, float))
):
minv, maxv = bnd_list[x_inx]
bnd_list[x_inx] = [
cval if kind == "min" else minv,
cval if kind == "max" else maxv,
]
# add the bias of cval to the objective function
elif kind in ("min", "max") and slvr in Xvars:
varref = Xrefs[slvr]
# Ref Case
ccst = ref_to_val_constraint(
system,
varref.comp,
system.last_context,
Xrefs,
varref,
kind,
cval,
**kw,
)
con_info.append(f"val_{ref.comp.classname}_{kind}_{slvr}")
con_list.append(ccst)
else:
self.warning(
f"bad constraint: {cval} {kind} {slv_var}|{slvr}"
)
# Add Constraints
for slvr, ref in self.problem_ineq.items():
slv = slv_inst[slvr]
slv_constraints = slv.constraints
parent = self.get_parent_key(slvr, look_back_num=2) # get the parent comp
for ctype in slv_constraints:
cval = ctype["value"]
kind = ctype["type"]
if cval is not None:
name = f"ineq_{parent}{ref.comp.classname}.{slvr}_{kind}_{cval}"
if log.log_level < 5:
self.debug(f"filtering constraint {slvr} |{name}")
con_info.append(name)
con_list.append(
create_constraint(
system,
ref.comp,
Xrefs,
"ineq",
cval,
system.last_context,
**kw,
)
)
for slvr, ref in self.problem_eq.items():
parent = self.get_parent_key(slvr, look_back_num=2) # get the parent comp
if slvr in slv_inst and slvr in all_refz.get("solver.eq", {}):
slv = slv_inst[slvr]
slv_constraints = slv.constraints
for ctype in slv_constraints:
cval = ctype["value"]
kind = ctype["type"]
if cval is not None:
name = f"eq_{parent}{ref.comp.classname}.{slvr}_{kind}_{cval}"
if log.log_level < 5:
self.debug(f"filtering constraint {slvr} |{name}")
con_info.append(name)
con_list.append(
create_constraint(
system,
ref.comp,
Xrefs,
"eq",
cval,
system.last_context,
**kw,
)
)
else:
# This must be a dynamic rate
self.debug(f"dynamic rate eq {slvr} ")
con_info.append(f"eq_{parent}{ref.comp.classname}.{slvr}_rate")
con_list.append(
create_constraint(
system,
ref.comp,
Xrefs,
"eq",
ref,
system.last_context,
**kw,
)
)
return constraints
# General method to distribute input to internal components
@classmethod
def parse_default(self, key, defaults, input_dict, rmv=False, empty_str=True):
"""splits strings or lists and returns a list of options for the key, if nothing found returns None if fail set to True raises an exception, otherwise returns the default value"""
if key in input_dict:
# kwargs will no longer have key!
if not rmv:
option = input_dict.get(key)
else:
option = input_dict.pop(key)
# print(f'removing option {key} {option}')
if option is None:
return option, False
elif isinstance(option, (int, float, bool)):
return option, False
elif isinstance(option, str):
if not empty_str and not option:
return None, False
option = option.split(",")
return option, False
elif key in defaults:
return defaults[key], True
return None, None
# State Interfaces
@property
def record_state(self) -> dict:
"""records the state of the system using session"""
# refs = self.all_variable_refs
sesh = self.sesh
# only get used refs modified no need for properties
# TODO: more elegant solution
chk = self.temp_state if self.temp_state else {}
# FIXME: only record state as it changes
# refs = {k:v for k,v in sesh.all_comps_and_vars.items()} # if k in chk
refs = {k: v for k, v in sesh.all_comps_and_vars.items()}
return Ref.refset_get(refs, sys=sesh.system, prob=self)
@property
def output_state(self) -> dict:
"""records the state of the system"""
sesh = self.sesh
# TODO: add system_properties to num_refs / all_system_refs ect.
if "nums" == sesh.save_mode:
refs = sesh.num_refs
elif "all" == sesh.save_mode:
refs = sesh.all_system_references
elif "vars" == sesh.save_mode:
refs = self.all_variable_refs
elif "prob" == sesh.save_mode:
raise NotImplementedError(f"problem save mode not implemented")
else:
raise KeyError(f"unknown save mode {sesh.save_mode}, not in {save_modes}")
out = Ref.refset_get(refs, sys=sesh.system, prob=self)
# Integration
if sesh._dxdt == True:
out["time"] = sesh._time
return out
def get_ref_values(self, refs=None):
"""returns the values of the refs"""
sesh = self.sesh
if refs is None:
refs = sesh.all_system_references
return Ref.refset_get(refs, sys=self.system, prob=self)
def set_ref_values(self, values, refs=None, scope="sref"):
"""returns the values of the refs"""
# TODO: add checks for the refs
if refs is None:
sesh = self.sesh
refs = sesh.all_comps_and_vars
return Ref.refset_input(refs, values, scope=scope)
def change_sys_var(self, key, value, refs=None, doset=True, attr_key_map=None):
"""use this function to change the value of a system var and update the start state, multiple uses in the same context will not change the record preserving the start value
:param key: a string corresponding to a ref, or an `attrs.Attribute` of one of the system or its component's.
"""
if self.log_level < 5:
self.msg(f"setting var: {key} <= {value}")
# if isinstance(key,attrs.Attribute):
# if attr_key_map is None:
# attr_key_map = self.attribute_sys_key_map
# key = attr_key_map[key] #change to system key format
if refs is None:
refs = self.sesh.all_comps_and_vars
if key in refs:
ref = refs[key]
if key not in self.x_start:
cur_value = ref.value()
self.x_start[key] = cur_value
if self.log_level < 5:
self.msg(f"setting var: {key} <= {value} from {cur_value}")
if doset:
ref.set_value(value)
elif isinstance(key, Ref):
ref = key
self.x_start[key] = key.value()
if doset:
ref.set_value(value)
def set_checkpoint(self):
"""sets the checkpoint"""
self.x_start = self.record_state
if log.log_level <= 7:
self.debug(f"set checkpoint: {list(self.x_start.values())}")
def revert_to_start(self):
sesh = self.sesh
if log.log_level < 5:
xs = list(self.x_start.values())
rs = list(self.record_state.values())
self.debug(f"reverting to start: {xs} -> {rs}")
# TODO: STRICT MODE Fail for refset_input
Ref.refset_input(
sesh.all_comps_and_vars, self.x_start, fail=False, scope="rvtst"
)
def activate_temp_state(self, new_state=None):
# TODO: determine when components change, and update refs accordingly!
sesh = self.sesh
# TODO: STRICT MODE Fail for refset_input
if new_state:
if self.log_level < 3:
self.debug(f"new-state: {self.temp_state}")
Ref.refset_input(
sesh.all_comps_and_vars, new_state, fail=False, scope="ntemp"
)
elif self.temp_state:
if self.log_level < 3:
self.debug(f"act-state: {self.temp_state}")
Ref.refset_input(
sesh.all_comps_and_vars,
self.temp_state,
fail=False,
scope="atemp",
)
elif self.log_level < 3:
self.debug(f"no-state: {new_state}")
# initial establishment costs / ect
if sesh.pre_exec:
sesh.pre_execute()
# System Events
def apply_pre_signals(self):
"""applies all pre signals"""
msg_lvl = self.log_level <= 2
if self.log_level < 5:
self.msg(f"applying pre signals", lvl=6)
for signame, sig in self.sesh.signals.items():
if sig.mode == "pre" or sig.mode == "both":
if msg_lvl:
self.msg(f"applying post signals: {signame}", lvl=3)
sig.apply()
def apply_post_signals(self):
"""applies all post signals"""
msg_lvl = self.log_level <= 2
if self.log_level < 5:
self.msg(f"applying post signals", lvl=6)
for signame, sig in self.sesh.signals.items():
if sig.mode == "post" or sig.mode == "both":
if msg_lvl:
self.msg(f"applying post signals: {signame}", lvl=3)
sig.apply()
def update_system(self, *args, **kwargs):
"""updates the system"""
for ukey, uref in self.sesh._update_refs.items():
self.debug(f"context updating {ukey}")
uref.value(*args, **kwargs)
def post_update_system(self, *args, **kwargs):
"""updates the system"""
for ukey, uref in self.sesh._post_update_refs.items():
self.debug(f"context post updating {ukey}")
uref.value(*args, **kwargs)
def pre_execute(self, *args, **kwargs):
"""Updates the pre/both signals after the solver has been executed. This is useful for updating the system state after the solver has been executed."""
if log.log_level < 5:
self.msg(f"pre execute")
sesh = self.sesh
sesh.apply_pre_signals()
sesh.update_system(*args, **kwargs)
def post_execute(self, *args, **kwargs):
"""Updates the post/both signals after the solver has been executed. This is useful for updating the system state after the solver has been executed."""
if log.log_level < 5:
self.msg(f"post execute")
sesh = self.sesh
sesh.apply_post_signals()
sesh.post_update_system(*args, **kwargs)
# Logging to class logger
@property
def identity(self):
return f"PROB|{str(self.session_id)[0:5]}"
@property
def log_level(self):
return log.log_level
def msg(self, msg, *a, **kw):
if log.log_level < 5:
log.msg(
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
def debug(self, msg, *a, **kw):
if log.log_level <= 15:
log.debug(
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
def warning(self, msg, *a, **kw):
log.warning(
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
def info(self, msg, *a, **kw):
log.info(
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
def error(self, error, msg, *a, **kw):
log.error(
error,
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
def critical(self, msg, *a, **kw):
log.critical(
f"{self.identity}|[{self.level_number}-{self.level_name}] {msg}",
*a,
**kw,
)
# Safe Access Methods
@property
def ref_attrs(self):
return self.sesh._sys_refs.get("attrs", {}).copy()
@property
def attr_inst(self):
return self.sesh._sys_refs.get("type", {}).copy()
@property
def dynamic_comps(self):
return self.sesh._sys_refs.get("dynamic_comps", {}).copy()
# Instances
@property
def integrators(self):
return self.attr_inst.get("time", {}).copy()
@property
def signal_inst(self):
return self.attr_inst.get("signal", {}).copy()
@property
def solver_inst(self):
return self.attr_inst.get("solver", {}).copy()
@property
def kwargs(self):
"""copy of slv_kw args"""
return self.sesh._slv_kw.copy()
@property
def dynamic_state(self):
return self.ref_attrs.get("dynamics.state", {}).copy()
@property
def dynamic_rate(self):
return self.ref_attrs.get("dynamics.rate", {}).copy()
@property
def problem_input(self):
return self.ref_attrs.get("dynamics.input", {}).copy()
@property
def integrator_vars(self):
return self.ref_attrs.get("time.var", {}).copy()
@property
def integrator_rates(self):
return self.ref_attrs.get("time.rate", {}).copy()
# Y solver variables
@property
def problem_objs(self):
return self.ref_attrs.get("solver.obj", {}).copy()
@property
def problem_eq(self):
base_eq = self.ref_attrs.get("solver.eq", {}).copy()
if self._dxdt in [True, None] or self._dxdt is False:
return base_eq # wysiwyg
base_eq.update(self.dynamic_rate_eq)
if len(base_eq) > len(self.Xref):
self.warning(
f"problem_eq has more items than variables {base_eq} {self.Xref}"
)
# TODO: in this case combine the rate_eq to prevent this error
return base_eq
@property
def dynamic_rate_eq(self):
# Add Dynamics Rate Equalities
base_eq = {}
if self._dxdt in [True, None] or self._dxdt is False: # zero case
return base_eq
Xrefs = self.Xref
system = self.system
        # henceforth, the rate shall be equal to something!
base_eq.update(self.dynamic_rate)
base_eq.update(self.filter_vars(self.integrator_rates))
out = {}
# TODO: handle callable case
if self._dxdt is not None and (
isinstance(self._dxdt, dict) or abs(self._dxdt) > 0
):
# add in the dynamic rate constraints
for var, varref in base_eq.items():
if isinstance(self._dxdt, dict) and var in self._dxdt:
rate_val = self._dxdt.get(var, 0) # zero default rate
elif isinstance(self._dxdt, (int, float)):
rate_val = self._dxdt
con_ref = ref_to_val_constraint(
system,
varref.comp,
system.last_context,
Xrefs,
varref,
"min",
rate_val,
return_ref=True,
)
out[var] = con_ref
else:
out[var] = varref # zip rate
else:
# thou shalt be zero (no ref modifications)
out = base_eq
return out
@property
def problem_ineq(self):
return self.ref_attrs.get("solver.ineq", {}).copy()
@property
def signals_source(self):
return self.ref_attrs.get("signal.source", {}).copy()
@property
def signals_target(self):
return self.ref_attrs.get("signal.target", {}).copy()
@property
def signals(self):
return self.ref_attrs.get("signal.signal", {}).copy()
# formatted output
@property
def is_active(self):
"""checks if the context has been entered and not exited"""
return self.entered and not self.exited
@classmethod
def cls_is_active(cls):
"""checks if the cache has a session"""
if cls.class_cache and hasattr(cls.class_cache, "session"):
return True
return False
@property
def solveable(self):
"""checks the system's references to determine if its solveabl"""
if self.sesh.problem_opt_vars:
# TODO: expand this
return True
return False
@property
def integrator_rate_refs(self):
"""combine the dynamic state and the integrator rates to get the transient state of the system, but convert their keys to the target var names"""
dc = self.dynamic_state.copy()
for int_name, intinst in self.integrators.items():
if intinst.var in dc:
raise KeyError(
f"conflict with integrator name {intinst.var} and dynamic state"
)
dc.update({intinst.var: intinst.rate_ref})
return dc
@property
def integrator_var_refs(self):
"""combine the dynamic state and the integrator rates to get the transient state of the system, but convert their keys to the target var names"""
dc = self.dynamic_state.copy()
for int_name, intinst in self.integrators.items():
if intinst.var_ref in dc:
raise KeyError(
f"conflict with integrator name {intinst.var_ref} and dynamic state"
)
dc.update({intinst.var: intinst.var_ref})
return dc
# Dataframe support
@property
def numeric_data(self):
"""return a list of sorted data rows by item and filter each row to remove invalid data"""
sesh = self.sesh
filter_non_numeric = lambda kv: (
False if isinstance(kv[1], (list, dict, tuple)) else True
)
f_numrow = lambda in_dict: dict(filter(filter_non_numeric, in_dict.items()))
return [
f_numrow(kv[-1]) for kv in sorted(sesh.data.items(), key=lambda kv: kv[0])
]
@property
def dataframe(self) -> pd.DataFrame:
"""returns the dataframe of the system"""
res = pd.DataFrame(self.numeric_data)
self.system.format_columns(res)
return res
    # TODO: expose option for saving all or part of the system information, for now lets default to all (safety first, then performance :)
def get_parent_key(self, key, look_back_num=1):
"""returns the parent key of the key"""
if not key:
return ""
elif key.count(".") >= look_back_num:
return ".".join(key.split(".")[:-look_back_num]) + "."
return ""
# Dynamics Interface
def filter_vars(self, refs: list):
"""selects only settable refs"""
return {
f"{self.get_parent_key(k)}{v.key}": v
for k, v in refs.items()
if v.allow_set
}
# X solver variable refs
@property
def problem_opt_vars(self) -> dict:
"""solver variables"""
return self.ref_attrs.get("solver.var", {}).copy()
@property
def all_problem_vars(self) -> dict:
"""solver variables + dynamics states when dynamic_solve is True"""
varx = self.ref_attrs.get("solver.var", {}).copy()
# Add the dynamic states to be optimized (ignore if integrating)
sesh = self.sesh
if sesh.dynamic_solve and not sesh._dxdt is True:
varx.update(sesh.dynamic_state)
varx.update(self.filter_vars(sesh.integrator_vars))
return varx
@property
def dynamic_solve(self) -> bool:
"""indicates if the system is dynamic"""
sesh = self.sesh
dxdt = sesh._dxdt
if dxdt is None or dxdt is False:
return False
if dxdt is True:
return True
in_type = isinstance(dxdt, (dict, float, int))
bool_type = isinstance(dxdt, bool) and dxdt == True
if in_type or bool_type:
return True
return False
@property
def all_variable_refs(self) -> dict:
sesh = self.sesh
ing = self.integrator_vars
stt = self.dynamic_state
vars = self.problem_opt_vars
return {**ing, **stt, **vars}
@property
def all_variables(self) -> dict:
"""returns all variables in the system"""
return self.all_refs["attributes"]
@property
def all_comps(self) -> dict:
"""returns all variables in the system"""
return self.all_refs["components"]
@property
def all_components(self) -> dict:
"""returns all variables in the system"""
return self.all_refs["components"]
@property
def all_comps_and_vars(self) -> dict:
        # TODO: ensure system refs are fresh per system runtime events
sesh = self.sesh
refs = sesh.all_refs
attrs = refs["attributes"].copy()
comps = refs["components"].copy()
attrs.update(comps)
return attrs
@property
def attribute_sys_key_map(self) -> dict:
"""returns an attribute:key mapping to lookup the key from the attribute"""
sesh = self.sesh
attrvars = sesh.all_refs["attributes"]
comps = {c: k for k, c in self.all_components.items()}
return {
refget_attr(v): refget_key(v, sesh.system, comps)
for k, v in attrvars.items()
}
@property
def all_system_references(self) -> dict:
sesh = self.sesh
refs = sesh.all_refs
out = {}
out.update(refs["attributes"])
out.update(refs["properties"])
return out
def __str__(self):
# TODO: expand this
return f"ProblemContext[{self.level_name:^12}][{str(self.session_id)[0:8]}-{str(self._problem_id)[0:8]}][{self.system.identity}]"
refget_attr = lambda ref: getattr(ref.comp.__class__.__attrs_attrs__, ref.key)
###FIXME: ref.comp needs to be reliably in comps
refget_key = (
lambda ref, slf, comps: f'{comps[ref.comp]+"." if slf != ref.comp else ""}{ref.key}'
)
# TODO: move all system_reference concept inside problem context, remove from system/tabulation ect.
# TODO: use prob.register/change(comp,key='') to add components to the problem context, mapping subcomponents to the problem context
# TODO: make a graph of all problem dependencies and use that to determine the order of operations, and subsequent updates.
# subclass before altering please!
ProblemExec.class_cache = ProblemExec
class Problem(ProblemExec, DataframeMixin):
# TODO: implement checks to ensure that problem is defined as the top level context to be returned to
# TODO: also define return options for data/system/dataframe and indexing
pass
@property
def level_name(self):
return "top" # fixed top output, garuntees exit to here.
@level_name.setter
def level_name(self, value):
raise AttributeError(f"cannot set level_name of top level problem context")