From 9df0d20c688fb4694708ed72bad5ad6233ed9469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sofus=20Albert=20H=C3=B8gsbro=20Rose?= Date: Thu, 2 May 2024 20:59:30 +0200 Subject: [PATCH] feat: Finished Gaussian Pulse node. Also fixed several bugs along the way. Full speed aheaad on the sources! --- TODO.md | 18 +- .../maxwell_sim_nodes/contracts/flow_kinds.py | 76 ++++++- .../maxwell_sim_nodes/contracts/node_types.py | 7 +- .../maxwell_sim_nodes/nodes/__init__.py | 6 +- .../nodes/analysis/math/map_math.py | 34 +-- .../maxwell_sim_nodes/nodes/analysis/viz.py | 101 +++++++-- .../inputs/constants/physical_constant.py | 25 ++- .../nodes/inputs/wave_constant.py | 4 +- .../nodes/sources/__init__.py | 23 ++- .../nodes/sources/temporal_shapes/__init__.py | 18 +- ...mporal_shape.py => expr_temporal_shape.py} | 0 .../gaussian_pulse_temporal_shape.py | 149 -------------- .../temporal_shapes/pulse_temporal_shape.py | 194 ++++++++++++++++++ ...mporal_shape.py => wave_temporal_shape.py} | 0 .../maxwell_sim_nodes/sockets/base.py | 22 +- .../maxwell_sim_nodes/sockets/expr.py | 65 ++++-- .../utils/extra_sympy_units.py | 31 +-- 17 files changed, 507 insertions(+), 266 deletions(-) rename src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/{array_temporal_shape.py => expr_temporal_shape.py} (100%) delete mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/gaussian_pulse_temporal_shape.py create mode 100644 src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/pulse_temporal_shape.py rename src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/{continuous_wave_temporal_shape.py => wave_temporal_shape.py} (100%) diff --git a/TODO.md b/TODO.md index 1ef2c5b..83987cf 100644 --- a/TODO.md +++ b/TODO.md @@ -57,6 +57,11 @@ +# VALIDATE +- [ ] Does the imaginary part of a complex phasor scale with the real part? Ex. when doing `V/m -> V/um` conversion, does the phase also scale by 1 million? + + + # Nodes ## Analysis @@ -144,8 +149,8 @@ ## Sources - [x] Temporal Shapes / Gaussian Pulse Temporal Shape - [x] Temporal Shapes / Continuous Wave Temporal Shape -- [ ] Temporal Shapes / Symbolic Temporal Shape - - [ ] Specify a Sympy function to generate appropriate array based on +- [ ] Temporal Shapes / Expr Temporal Shape + - [ ] Specify a Sympy function / data to generate appropriate array based on - [ ] Temporal Shapes / Data Temporal Shape - [x] Point Dipole Source @@ -585,4 +590,11 @@ Unreported: - [ ] Shader visualizations approximated from medium `nk` into a shader node graph, aka. a generic BSDF. -- [ ] Web importer that gets material data from refractiveindex.info. + + +- [ ] Easy conversion of lazyarrayrange to mu/sigma frequency for easy computation of pulse fits from data. + + +- [ ] IDEA: Hand-craft a faster `spu.convert_to`. + +- [ ] We should probably communicate with the `sympy` upstream about our deep usage of unit systems. They might be interested in the various workarounds :) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py index 8ba12f6..b342bf5 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/flow_kinds.py @@ -485,13 +485,13 @@ class LazyArrayRangeFlow: if isinstance(self.start, spux.SympyType): start_mathtype = spux.MathType.from_expr(self.start) else: - start_mathtype = spux.MathType.from_pytype(self.start) + start_mathtype = spux.MathType.from_pytype(type(self.start)) # Get Stop Mathtype if isinstance(self.stop, spux.SympyType): stop_mathtype = spux.MathType.from_expr(self.stop) else: - stop_mathtype = spux.MathType.from_pytype(self.stop) + stop_mathtype = spux.MathType.from_pytype(type(self.stop)) # Check Equal if start_mathtype != stop_mathtype: @@ -739,6 +739,10 @@ class LazyArrayRangeFlow: msg = f'Invalid kind: {kind}' raise TypeError(msg) + @functools.cached_property + def realize_array(self) -> ArrayFlow: + return self.realize() + #################### # - Params @@ -748,21 +752,52 @@ class ParamsFlow: func_args: list[spux.SympyExpr] = dataclasses.field(default_factory=list) func_kwargs: dict[str, spux.SympyExpr] = dataclasses.field(default_factory=dict) + symbols: frozenset[spux.Symbol] = frozenset() + + @functools.cached_property + def sorted_symbols(self) -> list[sp.Symbol]: + """Retrieves all symbols by concatenating int, real, and complex symbols, and sorting them by name. + + Returns: + All symbols valid for use in the expression. + """ + return sorted(self.symbols, key=lambda sym: sym.name) + #################### # - Scaled Func Args #################### - def scaled_func_args(self, unit_system: spux.UnitSystem): + def scaled_func_args( + self, + unit_system: spux.UnitSystem, + symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), + ): """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" + if not all(sym in self.symbols for sym in symbol_values): + msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" + raise ValueError(msg) + return [ - spux.convert_to_unit_system(func_arg, unit_system, use_jax_array=True) - for func_arg in self.func_args + spux.convert_to_unit_system(arg, unit_system, use_jax_array=True) + if arg not in symbol_values + else symbol_values[arg] + for arg in self.func_args ] - def scaled_func_kwargs(self, unit_system: spux.UnitSystem): + def scaled_func_kwargs( + self, + unit_system: spux.UnitSystem, + symbol_values: dict[spux.Symbol, spux.SympyExpr] = MappingProxyType({}), + ): """Return the function arguments, scaled to the unit system, stripped of units, and cast to jax-compatible arguments.""" + if not all(sym in self.symbols for sym in symbol_values): + msg = f"Symbols in {symbol_values} don't perfectly match the ParamsFlow symbols {self.symbols}" + raise ValueError(msg) + return { arg_name: spux.convert_to_unit_system(arg, unit_system, use_jax_array=True) - for arg_name, arg in self.func_args + if arg not in symbol_values + else symbol_values[arg] + for arg_name, arg in self.func_kwargs.items() } #################### @@ -780,16 +815,19 @@ class ParamsFlow: return ParamsFlow( func_args=self.func_args + other.func_args, func_kwargs=self.func_kwargs | other.func_kwargs, + symbols=self.symbols | other.symbols, ) def compose_within( self, enclosing_func_args: list[spux.SympyExpr] = (), enclosing_func_kwargs: dict[str, spux.SympyExpr] = MappingProxyType({}), + enclosing_symbols: frozenset[spux.Symbol] = frozenset(), ) -> typ.Self: return ParamsFlow( func_args=self.func_args + list(enclosing_func_args), func_kwargs=self.func_kwargs | dict(enclosing_func_kwargs), + symbols=self.symbols | enclosing_symbols, ) @@ -804,8 +842,6 @@ class InfoFlow: default_factory=dict ) ## TODO: Rename to dim_idxs - ## TODO: Add PhysicalType - @functools.cached_property def dim_lens(self) -> dict[str, int]: return {dim_name: len(dim_idx) for dim_name, dim_idx in self.dim_idx.items()} @@ -820,6 +856,13 @@ class InfoFlow: def dim_units(self) -> dict[str, spux.Unit]: return {dim_name: dim_idx.unit for dim_name, dim_idx in self.dim_idx.items()} + @functools.cached_property + def dim_physical_types(self) -> dict[str, spux.PhysicalType]: + return { + dim_name: spux.PhysicalType.from_unit(dim_idx.unit) + for dim_name, dim_idx in self.dim_idx.items() + } + @functools.cached_property def dim_idx_arrays(self) -> list[jax.Array]: return [ @@ -850,6 +893,21 @@ class InfoFlow: #################### # - Methods #################### + def rescale_dim_idxs(self, new_dim_idxs: dict[str, LazyArrayRangeFlow]) -> typ.Self: + return InfoFlow( + # Dimensions + dim_names=self.dim_names, + dim_idx={ + _dim_name: new_dim_idxs.get(_dim_name, dim_idx) + for _dim_name, dim_idx in self.dim_idx.items() + }, + # Outputs + output_name=self.output_name, + output_shape=self.output_shape, + output_mathtype=self.output_mathtype, + output_unit=self.output_unit, + ) + def delete_dimension(self, dim_name: str) -> typ.Self: """Delete a dimension.""" return InfoFlow( diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py index 73c4053..c5fe7b1 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/contracts/node_types.py @@ -43,10 +43,9 @@ class NodeType(blender_type_enum.BlenderTypeEnum): # Sources ## Sources / Temporal Shapes - GaussianPulseTemporalShape = enum.auto() - ContinuousWaveTemporalShape = enum.auto() - SymbolicTemporalShape = enum.auto() - DataTemporalShape = enum.auto() + PulseTemporalShape = enum.auto() + WaveTemporalShape = enum.auto() + ExprTemporalShape = enum.auto() ## Sources / PointDipoleSource = enum.auto() PlaneWaveSource = enum.auto() diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py index 739465e..dd3ed23 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/__init__.py @@ -6,7 +6,7 @@ from . import ( monitors, outputs, # simulations, - # sources, + sources, # structures, # utilities, ) @@ -15,7 +15,7 @@ BL_REGISTER = [ *analysis.BL_REGISTER, *inputs.BL_REGISTER, *outputs.BL_REGISTER, - # *sources.BL_REGISTER, + *sources.BL_REGISTER, *mediums.BL_REGISTER, # *structures.BL_REGISTER, *bounds.BL_REGISTER, @@ -27,7 +27,7 @@ BL_NODES = { **analysis.BL_NODES, **inputs.BL_NODES, **outputs.BL_NODES, - # **sources.BL_NODES, + **sources.BL_NODES, **mediums.BL_NODES, # **structures.BL_NODES, **bounds.BL_NODES, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py index 2ff4423..a371fe9 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/math/map_math.py @@ -4,7 +4,6 @@ import enum import typing as typ import bpy -import jax import jax.numpy as jnp import sympy as sp @@ -138,6 +137,9 @@ class MapOperation(enum.StrEnum): @staticmethod def by_element_shape(shape: tuple[int, ...] | None) -> list[typ.Self]: MO = MapOperation + if shape == 'noshape': + return [] + # By Number if shape is None: return [ @@ -259,7 +261,7 @@ class MapOperation(enum.StrEnum): ), ## TODO: Matrix -> Vec ## TODO: Matrix -> Matrices - }.get(self, info) + }.get(self, info)() class MapMathNode(base.MaxwellSimNode): @@ -346,10 +348,10 @@ class MapMathNode(base.MaxwellSimNode): bl_label = 'Map Math' input_sockets: typ.ClassVar = { - 'Expr': sockets.ExprSocketDef(), + 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Array), } output_sockets: typ.ClassVar = { - 'Expr': sockets.ExprSocketDef(), + 'Expr': sockets.ExprSocketDef(active_kind=ct.FlowKind.Array), } #################### @@ -366,12 +368,12 @@ class MapMathNode(base.MaxwellSimNode): if has_info: return info.output_shape - return None + return 'noshape' output_shape: tuple[int, ...] | None = bl_cache.BLField(None) def search_operations(self) -> list[ct.BLEnumElement]: - if self.expr_output_shape is not None: + if self.expr_output_shape != 'noshape': return [ operation.bl_enum_element(i) for i, operation in enumerate( @@ -401,8 +403,8 @@ class MapMathNode(base.MaxwellSimNode): run_on_init=True, ) def on_input_changed(self): - # if self.operation not in MapOperation.by_element_shape(self.expr_output_shape): - self.operation = bl_cache.Signal.ResetEnumItems + if self.operation not in MapOperation.by_element_shape(self.expr_output_shape): + self.operation = bl_cache.Signal.ResetEnumItems @events.on_value_changed( # Trigger @@ -449,7 +451,7 @@ class MapMathNode(base.MaxwellSimNode): mapper = input_sockets['Mapper'] has_expr = not ct.FlowSignal.check(expr) - has_mapper = not ct.FlowSignal.check(expr) + has_mapper = not ct.FlowSignal.check(mapper) if has_expr and operation is not None: if not has_mapper: @@ -494,11 +496,11 @@ class MapMathNode(base.MaxwellSimNode): # - Compute Auxiliary: Info / Params #################### @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Info, props={'active_socket_set', 'operation'}, - input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Info}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': ct.FlowKind.Info}, ) def compute_data_info(self, props: dict, input_sockets: dict) -> ct.InfoFlow: operation = props['operation'] @@ -512,13 +514,13 @@ class MapMathNode(base.MaxwellSimNode): return ct.FlowSignal.FlowPending @events.computes_output_socket( - 'Data', + 'Expr', kind=ct.FlowKind.Params, - input_sockets={'Data'}, - input_socket_kinds={'Data': ct.FlowKind.Params}, + input_sockets={'Expr'}, + input_socket_kinds={'Expr': ct.FlowKind.Params}, ) def compute_data_params(self, input_sockets: dict) -> ct.ParamsFlow | ct.FlowSignal: - return input_sockets['Data'] + return input_sockets['Expr'] #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py index 21a9d05..3e055cd 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/analysis/viz.py @@ -2,8 +2,6 @@ import enum import typing as typ import bpy -import jax -import jax.numpy as jnp import jaxtyping as jtyp import matplotlib.axis as mpl_ax import sympy as sp @@ -196,6 +194,7 @@ class VizNode(base.MaxwellSimNode): #################### input_sockets: typ.ClassVar = { 'Expr': sockets.ExprSocketDef( + active_kind=ct.FlowKind.Array, symbols={_x := sp.Symbol('x', real=True)}, default_value=2 * _x, ), @@ -284,16 +283,57 @@ class VizNode(base.MaxwellSimNode): socket_name='Expr', input_sockets={'Expr'}, run_on_init=True, - input_socket_kinds={'Expr': ct.FlowKind.Info}, + input_socket_kinds={'Expr': {ct.FlowKind.Info, ct.FlowKind.Params}}, input_sockets_optional={'Expr': True}, ) def on_any_changed(self, input_sockets: dict): - if not ct.FlowSignal.check_single( - input_sockets['Expr'], ct.FlowSignal.FlowPending - ): + info = input_sockets['Expr'][ct.FlowKind.Info] + params = input_sockets['Expr'][ct.FlowKind.Params] + + has_info = not ct.FlowSignal.check(info) + has_params = not ct.FlowSignal.check(params) + + # Reset Viz Mode/Target + has_nonpending_info = not ct.FlowSignal.check_single( + info, ct.FlowSignal.FlowPending + ) + if has_nonpending_info: self.viz_mode = bl_cache.Signal.ResetEnumItems self.viz_target = bl_cache.Signal.ResetEnumItems + # Provide Sockets for Symbol Realization + ## -> This happens if Params contains not-yet-realized symbols. + if has_info and has_params and params.symbols: + if set(self.loose_input_sockets) != { + sym.name for sym in params.symbols if sym.name in info.dim_names + }: + self.loose_input_sockets = { + sym.name: sockets.ExprSocketDef( + active_kind=ct.FlowKind.LazyArrayRange, + shape=None, + mathtype=info.dim_mathtypes[sym.name], + physical_type=info.dim_physical_types[sym.name], + default_min=( + info.dim_idx[sym.name].start + if not sp.S(info.dim_idx[sym.name].start).is_infinite + else sp.S(0) + ), + default_max=( + info.dim_idx[sym.name].start + if not sp.S(info.dim_idx[sym.name].stop).is_infinite + else sp.S(1) + ), + default_steps=50, + ) + for sym in sorted( + params.symbols, key=lambda el: info.dim_names.index(el.name) + ) + if sym.name in info.dim_names + } + + elif self.loose_input_sockets: + self.loose_input_sockets = {} + @events.on_value_changed( prop_name='viz_mode', run_on_init=True, @@ -309,39 +349,62 @@ class VizNode(base.MaxwellSimNode): props={'viz_mode', 'viz_target', 'colormap'}, input_sockets={'Expr'}, input_socket_kinds={ - 'Expr': {ct.FlowKind.Array, ct.FlowKind.LazyValueFunc, ct.FlowKind.Info} + 'Expr': {ct.FlowKind.LazyValueFunc, ct.FlowKind.Info, ct.FlowKind.Params} }, + unit_systems={'BlenderUnits': ct.UNITS_BLENDER}, + all_loose_input_sockets=True, stop_propagation=True, ) def on_show_plot( - self, - managed_objs: dict, - input_sockets: dict, - props: dict, + self, managed_objs, props, input_sockets, loose_input_sockets, unit_systems ): # Retrieve Inputs - array_flow = input_sockets['Expr'][ct.FlowKind.Array] info = input_sockets['Expr'][ct.FlowKind.Info] + params = input_sockets['Expr'][ct.FlowKind.Params] + + has_info = not ct.FlowSignal.check(info) + has_params = not ct.FlowSignal.check(params) - # Check Flow if ( - any(ct.FlowSignal.check(inp) for inp in [array_flow, info]) + not has_info + or not has_params or props['viz_mode'] is None or props['viz_target'] is None ): return - # Viz Target + # Compute Data + lazy_value_func = input_sockets['Expr'][ct.FlowKind.LazyValueFunc] + symbol_values = ( + loose_input_sockets + if not params.symbols + else { + sym: loose_input_sockets[sym.name] + .realize_array.rescale_to_unit(info.dim_units[sym.name]) + .values + for sym in params.sorted_symbols + } + ) + data = lazy_value_func.func_jax( + *params.scaled_func_args( + unit_systems['BlenderUnits'], symbol_values=symbol_values + ), + **params.scaled_func_kwargs( + unit_systems['BlenderUnits'], symbol_values=symbol_values + ), + ) + if params.symbols: + info = info.rescale_dim_idxs(loose_input_sockets) + + # Visualize by-Target if props['viz_target'] == VizTarget.Plot2D: managed_objs['plot'].mpl_plot_to_image( - lambda ax: VizMode.to_plotter(props['viz_mode'])( - array_flow.values, info, ax - ), + lambda ax: VizMode.to_plotter(props['viz_mode'])(data, info, ax), bl_select=True, ) if props['viz_target'] == VizTarget.Pixels: managed_objs['plot'].map_2d_to_image( - array_flow.values, + data, colormap=props['colormap'], bl_select=True, ) diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py index fa3304d..2425ced 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/constants/physical_constant.py @@ -66,6 +66,8 @@ class PhysicalConstantNode(base.MaxwellSimNode): # - UI #################### def draw_props(self, _, col: bpy.types.UILayout) -> None: + col.prop(self, self.blfields['physical_type'], text='') + row = col.row(align=True) row.prop(self, self.blfields['mathtype'], text='') row.prop(self, self.blfields['size'], text='') @@ -74,25 +76,34 @@ class PhysicalConstantNode(base.MaxwellSimNode): # - Events #################### @events.on_value_changed( - prop_name={'physical_type', 'mathtype', 'size'}, + # Trigger + prop_name={'physical_type'}, run_on_init=True, - props={'physical_type', 'mathtype', 'size'}, + # Loaded + props={'physical_type'}, ) - def on_mathtype_or_size_changed(self, props) -> None: + def on_physical_type_changed(self, props) -> None: """Change the input/output expression sockets to match the mathtype and size declared in the node.""" - shape = props['size'].shape - # Set Input Socket Physical Type if self.inputs['Value'].physical_type != props['physical_type']: self.inputs['Value'].physical_type = props['physical_type'] - self.search_mathtypes = bl_cache.Signal.ResetEnumItems - self.search_sizes = bl_cache.Signal.ResetEnumItems + self.mathtype = bl_cache.Signal.ResetEnumItems + self.size = bl_cache.Signal.ResetEnumItems + @events.on_value_changed( + # Trigger + prop_name={'mathtype', 'size'}, + run_on_init=True, + # Loaded + props={'physical_type', 'mathtype', 'size'}, + ) + def on_mathtype_or_size_changed(self, props) -> None: # Set Input Socket Math Type if self.inputs['Value'].mathtype != props['mathtype']: self.inputs['Value'].mathtype = props['mathtype'] # Set Input Socket Shape + shape = props['size'].shape if self.inputs['Value'].shape != shape: self.inputs['Value'].shape = shape diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py index 3fbe7a0..1bf60e6 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/inputs/wave_constant.py @@ -147,7 +147,9 @@ class WaveConstantNode(base.MaxwellSimNode): if has_freq: return input_sockets['Freq'] - return sci_constants.vac_speed_of_light / input_sockets['WL'] + return spu.convert_to( + sci_constants.vac_speed_of_light / input_sockets['WL'], spux.THz + ) @events.computes_output_socket( 'WL', diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/__init__.py index 05bcf13..5d14684 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/__init__.py @@ -1,24 +1,27 @@ -# from . import uniform_current_source -from . import plane_wave_source, point_dipole_source, temporal_shapes - -# from . import gaussian_beam_source -# from . import astigmatic_gaussian_beam_source -# from . import tfsf_source +from . import ( + # astigmatic_gaussian_beam_source, + # gaussian_beam_source, + # plane_wave_source, + # point_dipole_source, + temporal_shapes, + # tfsf_source, + # uniform_current_source, +) BL_REGISTER = [ *temporal_shapes.BL_REGISTER, - *point_dipole_source.BL_REGISTER, + #*point_dipole_source.BL_REGISTER, # *uniform_current_source.BL_REGISTER, - *plane_wave_source.BL_REGISTER, + #*plane_wave_source.BL_REGISTER, # *gaussian_beam_source.BL_REGISTER, # *astigmatic_gaussian_beam_source.BL_REGISTER, # *tfsf_source.BL_REGISTER, ] BL_NODES = { **temporal_shapes.BL_NODES, - **point_dipole_source.BL_NODES, + #**point_dipole_source.BL_NODES, # **uniform_current_source.BL_NODES, - **plane_wave_source.BL_NODES, + #**plane_wave_source.BL_NODES, # **gaussian_beam_source.BL_NODES, # **astigmatic_gaussian_beam_source.BL_NODES, # **tfsf_source.BL_NODES, diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/__init__.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/__init__.py index e908e5b..0f32b5d 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/__init__.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/__init__.py @@ -1,15 +1,13 @@ -from . import gaussian_pulse_temporal_shape - -# from . import continuous_wave_temporal_shape -# from . import array_temporal_shape +# from . import expr_temporal_shape, pulse_temporal_shape, wave_temporal_shape +from . import pulse_temporal_shape BL_REGISTER = [ - *gaussian_pulse_temporal_shape.BL_REGISTER, - # *continuous_wave_temporal_shape.BL_REGISTER, - # *array_temporal_shape.BL_REGISTER, + *pulse_temporal_shape.BL_REGISTER, + # *wave_temporal_shape.BL_REGISTER, + # *expr_temporal_shape.BL_REGISTER, ] BL_NODES = { - **gaussian_pulse_temporal_shape.BL_NODES, - # **continuous_wave_temporal_shape.BL_NODES, - # **array_temporal_shape.BL_NODES, + **pulse_temporal_shape.BL_NODES, + # **wave_temporal_shape.BL_NODES, + # **expr_temporal_shape.BL_NODES, } diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/array_temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/expr_temporal_shape.py similarity index 100% rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/array_temporal_shape.py rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/expr_temporal_shape.py diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/gaussian_pulse_temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/gaussian_pulse_temporal_shape.py deleted file mode 100644 index f563bb7..0000000 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/gaussian_pulse_temporal_shape.py +++ /dev/null @@ -1,149 +0,0 @@ - -import bpy -import numpy as np -import sympy.physics.units as spu -import tidy3d as td - -from blender_maxwell.utils import extra_sympy_units as spuex - -from .... import contracts as ct -from .... import managed_objs, sockets -from ... import base, events - - -class GaussianPulseTemporalShapeNode(base.MaxwellSimNode): - node_type = ct.NodeType.GaussianPulseTemporalShape - bl_label = 'Gaussian Pulse Temporal Shape' - # bl_icon = ... - - #################### - # - Sockets - #################### - input_sockets = { - # "amplitude": sockets.RealNumberSocketDef( - # label="Temporal Shape", - # ), ## Should have a unit of some kind... - 'Freq Center': sockets.PhysicalFreqSocketDef( - default_value=500 * spuex.terahertz, - ), - 'Freq Std.': sockets.PhysicalFreqSocketDef( - default_value=200 * spuex.terahertz, - ), - 'Phase': sockets.PhysicalAngleSocketDef(), - 'Delay rel. AngFreq': sockets.RealNumberSocketDef( - default_value=5.0, - ), - 'Remove DC': sockets.BoolSocketDef( - default_value=True, - ), - } - output_sockets = { - 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(), - } - - managed_obj_types = { - 'amp_time': managed_objs.ManagedBLImage, - } - - #################### - # - Properties - #################### - plot_time_start: bpy.props.FloatProperty( - name='Plot Time Start (ps)', - description='The instance ID of a particular MaxwellSimNode instance, used to index caches', - default=0.0, - update=(lambda self, context: self.on_prop_changed('plot_time_start', context)), - ) - plot_time_end: bpy.props.FloatProperty( - name='Plot Time End (ps)', - description='The instance ID of a particular MaxwellSimNode instance, used to index caches', - default=5, - update=(lambda self, context: self.on_prop_changed('plot_time_start', context)), - ) - - #################### - # - UI - #################### - def draw_props(self, _, layout): - layout.label(text='Plot Settings') - split = layout.split(factor=0.6) - - col = split.column() - col.label(text='t-Range (ps)') - - col = split.column() - col.prop(self, 'plot_time_start', text='') - col.prop(self, 'plot_time_end', text='') - - #################### - # - Output Socket Computation - #################### - @events.computes_output_socket( - 'Temporal Shape', - input_sockets={ - 'Freq Center', - 'Freq Std.', - 'Phase', - 'Delay rel. AngFreq', - 'Remove DC', - }, - ) - def compute_source(self, input_sockets: dict) -> td.GaussianPulse: - if ( - (_freq_center := input_sockets['Freq Center']) is None - or (_freq_std := input_sockets['Freq Std.']) is None - or (_phase := input_sockets['Phase']) is None - or (time_delay_rel_ang_freq := input_sockets['Delay rel. AngFreq']) is None - or (remove_dc_component := input_sockets['Remove DC']) is None - ): - raise ValueError('Inputs not defined') - - cheating_amplitude = 1.0 - freq_center = spu.convert_to(_freq_center, spu.hertz) / spu.hertz - freq_std = spu.convert_to(_freq_std, spu.hertz) / spu.hertz - phase = spu.convert_to(_phase, spu.radian) / spu.radian - - return td.GaussianPulse( - amplitude=cheating_amplitude, - phase=phase, - freq0=freq_center, - fwidth=freq_std, - offset=time_delay_rel_ang_freq, - remove_dc_component=remove_dc_component, - ) - - @events.on_show_plot( - managed_objs={'amp_time'}, - props={'plot_time_start', 'plot_time_end'}, - output_sockets={'Temporal Shape'}, - stop_propagation=True, - ) - def on_show_plot( - self, - managed_objs: dict, - output_sockets: dict, - props: dict, - ): - temporal_shape = output_sockets['Temporal Shape'] - plot_time_start = props['plot_time_start'] * 1e-15 - plot_time_end = props['plot_time_end'] * 1e-15 - - times = np.linspace(plot_time_start, plot_time_end) - - managed_objs['amp_time'].mpl_plot_to_image( - lambda ax: temporal_shape.plot_spectrum(times, ax=ax), - bl_select=True, - ) - - -#################### -# - Blender Registration -#################### -BL_REGISTER = [ - GaussianPulseTemporalShapeNode, -] -BL_NODES = { - ct.NodeType.GaussianPulseTemporalShape: ( - ct.NodeCategory.MAXWELLSIM_SOURCES_TEMPORALSHAPES - ) -} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/pulse_temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/pulse_temporal_shape.py new file mode 100644 index 0000000..267e6a1 --- /dev/null +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/pulse_temporal_shape.py @@ -0,0 +1,194 @@ +"""Implements the `PulseTemporalShapeNode`.""" + +import functools +import typing as typ + +import bpy +import jax.numpy as jnp +import sympy as sp +import sympy.physics.units as spu +import tidy3d as td + +from blender_maxwell.utils import extra_sympy_units as spux + +from .... import contracts as ct +from .... import managed_objs, sockets +from ... import base, events + + +def _manual_amp_time(self, time: float) -> complex: + """Copied implementation of `pulse.amp_time` for `tidy3d` temporal shapes, which replaces use of `numpy` with `jax.numpy` for `jit`-ability. + + Since the function is detached from the method, `self` is not implicitly available. It should be pre-defined from a real source time object using `functools.partial`, before `jax.jit`ing. + + ## License + **This function is directly copied from `tidy3d`**. + As such, it should be considered available under the `tidy3d` license (as of writing, LGPL 2.1): + + ## Reference + Permalink to GitHub source code: + """ + twidth = 1.0 / (2 * jnp.pi * self.fwidth) + omega0 = 2 * jnp.pi * self.freq0 + time_shifted = time - self.offset * twidth + + offset = jnp.exp(1j * self.phase) + oscillation = jnp.exp(-1j * omega0 * time) + amp = jnp.exp(-(time_shifted**2) / 2 / twidth**2) * self.amplitude + + pulse_amp = offset * oscillation * amp + + # subtract out DC component + if self.remove_dc_component: + pulse_amp = pulse_amp * (1j + time_shifted / twidth**2 / omega0) + else: + # 1j to make it agree in large omega0 limit + pulse_amp = pulse_amp * 1j + + return pulse_amp + + +class PulseTemporalShapeNode(base.MaxwellSimNode): + node_type = ct.NodeType.PulseTemporalShape + bl_label = 'Gaussian Pulse Temporal Shape' + + #################### + # - Sockets + #################### + input_sockets: typ.ClassVar = { + 'max E': sockets.ExprSocketDef( + mathtype=spux.MathType.Complex, + physical_type=spux.PhysicalType.EField, + default_value=1 + 0j, + ), + 'μ Freq': sockets.ExprSocketDef( + physical_type=spux.PhysicalType.Freq, + default_unit=spux.THz, + default_value=500, + ), + 'σ Freq': sockets.ExprSocketDef( + physical_type=spux.PhysicalType.Freq, + default_unit=spux.THz, + default_value=200, + ), + 'Offset Time': sockets.ExprSocketDef(default_value=5, abs_min=2.5), + 'Remove DC': sockets.BoolSocketDef( + default_value=True, + ), + } + output_sockets: typ.ClassVar = { + 'Temporal Shape': sockets.MaxwellTemporalShapeSocketDef(), + 'E(t)': sockets.ExprSocketDef(active_kind=ct.FlowKind.Array), + } + + managed_obj_types: typ.ClassVar = { + 'plot': managed_objs.ManagedBLImage, + } + + #################### + # - UI + #################### + def draw_info(self, _: bpy.types.Context, layout: bpy.types.UILayout) -> None: + box = layout.box() + row = box.row() + row.alignment = 'CENTER' + row.label(text='Parameter Scale') + + # Split + split = box.split(factor=0.3, align=False) + + ## LHS: Parameter Names + col = split.column() + col.alignment = 'RIGHT' + col.label(text='Off t:') + + ## RHS: Parameter Units + col = split.column() + col.label(text='1 / 2π·σ(𝑓)') + + #################### + # - FlowKind: Value + #################### + @events.computes_output_socket( + 'Temporal Shape', + input_sockets={ + 'max E', + 'μ Freq', + 'σ Freq', + 'Offset Time', + 'Remove DC', + }, + unit_systems={'Tidy3DUnits': ct.UNITS_TIDY3D}, + scale_input_sockets={ + 'max E': 'Tidy3DUnits', + 'μ Freq': 'Tidy3DUnits', + 'σ Freq': 'Tidy3DUnits', + }, + ) + def compute_temporal_shape(self, input_sockets, unit_systems) -> td.GaussianPulse: + return td.GaussianPulse( + amplitude=sp.re(input_sockets['max E']), + phase=sp.im(input_sockets['max E']), + freq0=input_sockets['μ Freq'], + fwidth=input_sockets['σ Freq'], + offset=input_sockets['Offset Time'], + remove_dc_component=input_sockets['Remove DC'], + ) + + #################### + # - FlowKind: LazyValueFunc / Info / Params + #################### + @events.computes_output_socket( + 'E(t)', + kind=ct.FlowKind.LazyValueFunc, + output_sockets={'Temporal Shape'}, + ) + def compute_time_to_efield_lazy(self, output_sockets) -> td.GaussianPulse: + temporal_shape = output_sockets['Temporal Shape'] + jax_amp_time = functools.partial(_manual_amp_time, temporal_shape) + + ## TODO: Don't just partial() it up, do it property in the ParamsFlow! + ## -> Right now it's recompiled every time. + + return ct.LazyValueFuncFlow( + func=jax_amp_time, + func_args=[spux.PhysicalType.Time], + supports_jax=True, + ) + + @events.computes_output_socket( + 'E(t)', + kind=ct.FlowKind.Info, + ) + def compute_time_to_efield_info(self) -> td.GaussianPulse: + return ct.InfoFlow( + dim_names=['t'], + dim_idx={ + 't': ct.LazyArrayRangeFlow( + start=sp.S(0), stop=sp.oo, steps=0, unit=spu.second + ) + }, + output_name='E', + output_shape=None, + output_mathtype=spux.MathType.Complex, + output_unit=spu.volt / spu.um, + ) + + @events.computes_output_socket( + 'E(t)', + kind=ct.FlowKind.Params, + ) + def compute_time_to_efield_params(self) -> td.GaussianPulse: + sym_time = sp.Symbol('t', real=True, nonnegative=True) + return ct.ParamsFlow(func_args=[sym_time], symbols={sym_time}) + + +#################### +# - Blender Registration +#################### +BL_REGISTER = [ + PulseTemporalShapeNode, +] +BL_NODES = { + ct.NodeType.PulseTemporalShape: (ct.NodeCategory.MAXWELLSIM_SOURCES_TEMPORALSHAPES) +} diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/continuous_wave_temporal_shape.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/wave_temporal_shape.py similarity index 100% rename from src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/continuous_wave_temporal_shape.py rename to src/blender_maxwell/node_trees/maxwell_sim_nodes/nodes/sources/temporal_shapes/wave_temporal_shape.py diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py index 6f8697a..8523da8 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/base.py @@ -217,9 +217,14 @@ class MaxwellSimSocket(bpy.types.NodeSocket): Called by `self.on_prop_changed()` when `self.active_kind` was changed. """ self.display_shape = ( - 'SQUARE' if self.active_kind == ct.FlowKind.LazyArrayRange else 'CIRCLE' - ) # + ('_DOT' if self.use_units else '') - ## TODO: Valid Active Kinds should be a subset/subenum(?) of FlowKind + 'SQUARE' + if self.active_kind == ct.FlowKind.LazyArrayRange + else ('DIAMOND' if self.active_kind == ct.FlowKind.Array else 'CIRCLE') + ) + ( + '_DOT' + if hasattr(self, 'physical_type') and self.physical_type is not None + else '' + ) def on_socket_prop_changed(self, prop_name: str) -> None: """Called when a property has been updated. @@ -811,6 +816,7 @@ class MaxwellSimSocket(bpy.types.NodeSocket): { ct.FlowKind.Value: self.draw_value, ct.FlowKind.LazyArrayRange: self.draw_lazy_array_range, + ct.FlowKind.Array: self.draw_array, }[self.active_kind](col) # Info Drawing @@ -930,6 +936,16 @@ class MaxwellSimSocket(bpy.types.NodeSocket): col: Target for defining UI elements. """ + def draw_array(self, col: bpy.types.UILayout) -> None: + """Draws the socket array UI on its own line. + + Notes: + Should be overriden by individual socket classes, if they have an editable `FlowKind.Array`. + + Parameters: + col: Target for defining UI elements. + """ + #################### # - UI Methods: Auxilliary #################### diff --git a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py index fd63b3f..1ef2552 100644 --- a/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py +++ b/src/blender_maxwell/node_trees/maxwell_sim_nodes/sockets/expr.py @@ -194,10 +194,19 @@ class ExprBLSocket(base.MaxwellSimSocket): current_value = self.value current_lazy_array_range = self.lazy_array_range - self.unit = bl_cache.Signal.InvalidateCache + # Old Unit Not in Physical Type + ## -> This happens when dynamically altering self.physical_type + if self.unit in self.physical_type.valid_units: + self.unit = bl_cache.Signal.InvalidateCache - self.value = current_value - self.lazy_array_range = current_lazy_array_range + self.value = current_value + self.lazy_array_range = current_lazy_array_range + else: + self.unit = bl_cache.Signal.InvalidateCache + + # Workaround: Manually Jiggle FlowKind Invalidation + self.value = self.value + self.lazy_value_range = self.lazy_value_range #################### # - Property Callback @@ -238,10 +247,25 @@ class ExprBLSocket(base.MaxwellSimSocket): return mathtype, shape - def _to_raw_value(self, expr: spux.SympyExpr): + def _to_raw_value(self, expr: spux.SympyExpr, force_complex: bool = False): if self.unit is not None: - return spux.sympy_to_python(spux.scale_to_unit(expr, self.unit)) - return spux.sympy_to_python(expr) + pyvalue = spux.sympy_to_python(spux.scale_to_unit(expr, self.unit)) + else: + pyvalue = spux.sympy_to_python(expr) + + # Cast complex -> tuple[float, float] + if isinstance(pyvalue, complex) or ( + isinstance(pyvalue, int | float) and force_complex + ): + return (pyvalue.real, pyvalue.imag) + if isinstance(pyvalue, tuple) and all( + isinstance(v, complex) + or (isinstance(pyvalue, int | float) and force_complex) + for v in pyvalue + ): + return tuple([(v.real, v.imag) for v in pyvalue]) + + return pyvalue def _parse_expr_str(self, expr_spstr: str) -> None: expr = sp.sympify( @@ -346,7 +370,9 @@ class ExprBLSocket(base.MaxwellSimSocket): elif self.mathtype == MT_R: self.raw_value_float = self._to_raw_value(expr) elif self.mathtype == MT_C: - self.raw_value_complex = self._to_raw_value(expr) + self.raw_value_complex = self._to_raw_value( + expr, force_complex=True + ) elif self.shape == (2,): if self.mathtype == MT_Z: self.raw_value_int2 = self._to_raw_value(expr) @@ -355,9 +381,10 @@ class ExprBLSocket(base.MaxwellSimSocket): elif self.mathtype == MT_R: self.raw_value_float2 = self._to_raw_value(expr) elif self.mathtype == MT_C: - self.raw_value_complex2 = self._to_raw_value(expr) + self.raw_value_complex2 = self._to_raw_value( + expr, force_complex=True + ) elif self.shape == (3,): - log.critical(expr) if self.mathtype == MT_Z: self.raw_value_int3 = self._to_raw_value(expr) elif self.mathtype == MT_Q: @@ -365,7 +392,9 @@ class ExprBLSocket(base.MaxwellSimSocket): elif self.mathtype == MT_R: self.raw_value_float3 = self._to_raw_value(expr) elif self.mathtype == MT_C: - self.raw_value_complex3 = self._to_raw_value(expr) + self.raw_value_complex3 = self._to_raw_value( + expr, force_complex=True + ) #################### # - FlowKind: LazyArrayRange @@ -451,7 +480,7 @@ class ExprBLSocket(base.MaxwellSimSocket): ] elif value.mathtype == MT_C: self.raw_range_complex = [ - self._to_raw_value(bound * unit) + self._to_raw_value(bound * unit, force_complex=True) for bound in [value.start, value.stop] ] @@ -674,7 +703,7 @@ class ExprBLSocket(base.MaxwellSimSocket): _row.label(text=text) def draw_info(self, info: ct.InfoFlow, col: bpy.types.UILayout) -> None: - if self.show_info_columns: + if self.active_kind == ct.FlowKind.Array and self.show_info_columns: row = col.row() box = row.box() grid = box.grid_flow( @@ -725,9 +754,9 @@ class ExprBLSocket(base.MaxwellSimSocket): #################### class ExprSocketDef(base.SocketDef): socket_type: ct.SocketType = ct.SocketType.Expr - active_kind: typ.Literal[ct.FlowKind.Value, ct.FlowKind.LazyArrayRange] = ( - ct.FlowKind.Value - ) + active_kind: typ.Literal[ + ct.FlowKind.Value, ct.FlowKind.LazyArrayRange, ct.FlowKind.Array + ] = ct.FlowKind.Value # Socket Interface ## TODO: __hash__ like socket method based on these? @@ -740,15 +769,15 @@ class ExprSocketDef(base.SocketDef): default_unit: spux.Unit | None = None # FlowKind: Value - default_value: spux.SympyExpr = sp.RealNumber(0) + default_value: spux.SympyExpr = sp.S(0) abs_min: spux.SympyExpr | None = None ## TODO: Not used (yet) abs_max: spux.SympyExpr | None = None ## TODO: Not used (yet) ## TODO: Idea is to use this scalar uniformly for all shape elements ## TODO: -> But we may want to **allow** using same-shape for diff. bounds. # FlowKind: LazyArrayRange - default_min: spux.SympyExpr = sp.RealNumber(0) - default_max: spux.SympyExpr = sp.RealNumber(1) + default_min: spux.SympyExpr = sp.S(0) + default_max: spux.SympyExpr = sp.S(1) default_steps: int = 2 ## TODO: Configure lin/log/... scaling (w/enumprop in UI) diff --git a/src/blender_maxwell/utils/extra_sympy_units.py b/src/blender_maxwell/utils/extra_sympy_units.py index 2909848..f2a6cf7 100644 --- a/src/blender_maxwell/utils/extra_sympy_units.py +++ b/src/blender_maxwell/utils/extra_sympy_units.py @@ -98,6 +98,12 @@ class MathType(enum.StrEnum): if sp_obj.is_complex: return MathType.Complex + # Infinities + if sp_obj in [sp.oo, -sp.oo]: + return MathType.Real ## TODO: Strictly, could be ex. integer... + if sp_obj in [sp.zoo, -sp.zoo]: + return MathType.Complex + msg = f"Can't determine MathType from sympy object: {sp_obj}" raise ValueError(msg) @@ -755,13 +761,11 @@ def scaling_factor(unit_from: spu.Quantity, unit_to: spu.Quantity) -> Number: raise ValueError(msg) -_UNIT_STR_MAP = {sym.name: unit for sym, unit in UNIT_BY_SYMBOL.items()} - - @functools.cache def unit_str_to_unit(unit_str: str) -> Unit | None: - if unit_str in _UNIT_STR_MAP: - return _UNIT_STR_MAP[unit_str] + expr = sp.sympify(unit_str).subs(UNIT_BY_SYMBOL) + if expr.has(spu.Quantity): + return expr msg = f'No valid unit for unit string {unit_str}' raise ValueError(msg) @@ -812,7 +816,6 @@ class PhysicalType(enum.StrEnum): # Luminal LumIntensity = enum.auto() LumFlux = enum.auto() - Luminance = enum.auto() Illuminance = enum.auto() # Optics OrdinaryWaveVector = enum.auto() @@ -866,7 +869,7 @@ class PhysicalType(enum.StrEnum): PT.OrdinaryWaveVector: Dims.frequency, PT.AngularWaveVector: Dims.angle * Dims.frequency, PT.PoyntingVector: Dims.power / Dims.length**2, - } + }[self] @property def default_unit(self) -> list[Unit]: @@ -1072,7 +1075,7 @@ class PhysicalType(enum.StrEnum): @staticmethod def from_unit(unit: Unit) -> list[Unit]: - for physical_type in list[PhysicalType]: + for physical_type in list(PhysicalType): if unit in physical_type.valid_units: return physical_type @@ -1161,7 +1164,7 @@ class PhysicalType(enum.StrEnum): @staticmethod def to_name(value: typ.Self) -> str: - return sp_to_str(value.unit_dim) + return PhysicalType(value).name @staticmethod def to_icon(value: typ.Self) -> str: @@ -1208,6 +1211,7 @@ UNITS_SI: UnitSystem = { # Electrodynamics _PT.Current: spu.ampere, _PT.CurrentDensity: spu.ampere / spu.meter**2, + _PT.Voltage: spu.volt, _PT.Capacitance: spu.farad, _PT.Impedance: spu.ohm, _PT.Conductance: spu.siemens, @@ -1278,13 +1282,12 @@ def sympy_to_python( #################### # - Convert to Unit System #################### -def _flat_unit_system_units(unit_system: UnitSystem) -> SympyExpr: - return list(unit_system.values()) - - def convert_to_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: """Convert an expression to the units of a given unit system, with appropriate scaling.""" - return spu.convert_to(sp_obj, _flat_unit_system_units(unit_system)) + return spu.convert_to( + sp_obj, + {unit_system[PhysicalType.from_unit(unit)] for unit in get_units(sp_obj)}, + ) def strip_unit_system(sp_obj: SympyExpr, unit_system: UnitSystem) -> SympyExpr: