Coverage for custom_components/supernotify/scenario.py: 85%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-10-18 09:29 +0000

1import logging 

2from collections.abc import Iterator 

3from contextlib import contextmanager 

4from dataclasses import asdict 

5from typing import Any 

6 

7from homeassistant.components.trace import async_setup, async_store_trace # type: ignore[attr-defined] 

8from homeassistant.components.trace.const import DATA_TRACE 

9from homeassistant.components.trace.models import ActionTrace 

10from homeassistant.const import ( 

11 CONF_ALIAS, 

12 CONF_CONDITION, 

13) 

14from homeassistant.core import Context, HomeAssistant 

15from homeassistant.helpers import condition 

16from homeassistant.helpers.trace import trace_get, trace_path 

17from homeassistant.helpers.typing import ConfigType 

18from voluptuous import Invalid 

19 

20from . import ATTR_DEFAULT, CONF_ACTION_GROUP_NAMES, CONF_DELIVERY, CONF_DELIVERY_SELECTION, CONF_MEDIA, ConditionVariables 

21 

22_LOGGER = logging.getLogger(__name__) 

23 

24 

class Scenario:
    """A named scenario with an optional Home Assistant condition.

    Holds the values read from the scenario's configuration mapping, can
    validate and evaluate the configured condition, and keeps the most recent
    evaluation trace in ``last_trace``.
    """

    def __init__(self, name: str, scenario_definition: dict[str, Any], hass: HomeAssistant) -> None:
        """Populate the scenario from its configuration mapping.

        Args:
            name: Scenario name; also compared with ``ATTR_DEFAULT`` to set ``default``.
            scenario_definition: Mapping read with the ``CONF_*`` keys below.
            hass: Home Assistant instance used for condition validation, evaluation and tracing.
        """
        self.hass: HomeAssistant = hass
        self.name: str = name
        self.alias: str | None = scenario_definition.get(CONF_ALIAS)
        self.condition: ConfigType | None = scenario_definition.get(CONF_CONDITION)
        self.media: dict[str, Any] | None = scenario_definition.get(CONF_MEDIA)
        self.delivery_selection: str | None = scenario_definition.get(CONF_DELIVERY_SELECTION)
        # `or` rather than a .get() default so that a key present with a null
        # value is normalised too, matching how `delivery` is handled.
        self.action_groups: list[str] = scenario_definition.get(CONF_ACTION_GROUP_NAMES) or []
        self.delivery: dict[str, Any] = scenario_definition.get(CONF_DELIVERY) or {}
        self.default: bool = self.name == ATTR_DEFAULT
        self.last_trace: ActionTrace | None = None
        # Not read or written anywhere else in this class.
        self.condition_func = None

    async def validate(self) -> bool:
        """Validate Home Assistant condition definition at initiation.

        Returns:
            True when no condition is configured, or when the configured
            condition validates and a non-None condition can be built from it.
            False, after logging, when the built condition is None or any
            exception is raised along the way.
        """
        if self.condition:
            try:
                cond = await condition.async_validate_condition_config(self.hass, self.condition)
                if await condition.async_from_config(self.hass, cond) is None:
                    _LOGGER.warning("SUPERNOTIFY Disabling scenario %s with failed condition %s", self.name, self.condition)
                    return False
            except Exception as e:
                _LOGGER.error("SUPERNOTIFY Disabling scenario %s with error validating %s: %s", self.name, self.condition, e)
                return False
        return True

    def attributes(self, include_condition: bool = True, include_trace: bool = False) -> dict[str, Any]:
        """Return scenario attributes.

        Args:
            include_condition: Add the raw condition configuration under ``"condition"``.
            include_trace: Add ``last_trace.as_extended_dict()`` under ``"trace"``
                when a trace has been recorded.
        """
        attrs: dict[str, Any] = {
            "name": self.name,
            "alias": self.alias,
            "media": self.media,
            "delivery_selection": self.delivery_selection,
            "action_groups": self.action_groups,
            "delivery": self.delivery,
            "default": self.default,
        }
        if include_condition:
            attrs["condition"] = self.condition
        if include_trace and self.last_trace:
            attrs["trace"] = self.last_trace.as_extended_dict()
        return attrs

    def contents(self, minimal: bool = False) -> dict[str, Any]:
        """Archive friendly view of scenario.

        With ``minimal`` set, both the condition and the trace are left out.
        """
        return self.attributes(include_condition=not minimal, include_trace=not minimal)

    async def evaluate(self, condition_variables: ConditionVariables | None = None) -> bool:
        """Evaluate scenario conditions.

        Args:
            condition_variables: Optional variables, converted with
                ``dataclasses.asdict`` and passed to the condition test.

        Returns:
            True only when a condition is configured, can be built, and tests
            truthy. A scenario without a condition, a condition that cannot be
            built, or one that raises while being tested yields False; the
            failures are logged rather than raised.
        """
        if not self.condition:
            return False
        try:
            test = await condition.async_from_config(self.hass, self.condition)
            if test is None:
                raise Invalid(f"Empty condition generated for {self.name}")
        except Exception as e:
            _LOGGER.error("SUPERNOTIFY Scenario %s condition create failed: %s", self.name, e)
            return False

        # Built once so the error log reports exactly what the test was given,
        # and so the except branch never has to call back into the variables
        # object (a second failure there would hide the original error).
        variables: dict[str, Any] | None = None
        try:
            variables = asdict(condition_variables) if condition_variables else None
            if test(self.hass, variables):
                return True
        except Exception as e:
            _LOGGER.error(
                "SUPERNOTIFY Scenario condition eval failed: %s, vars: %s",
                e,
                variables or {},
            )
        return False

    async def trace(self, condition_variables: ConditionVariables | None = None, config: ConfigType | None = None) -> bool:
        """Trace scenario delivery.

        Evaluates the scenario inside a ``trace_action`` context, stores the
        resulting trace on ``last_trace`` and logs it at info level.

        Args:
            condition_variables: Passed through to ``evaluate``.
            config: Configuration handed to the trace setup and to the trace
                itself; an empty dict is used when omitted.

        Returns:
            The result of ``evaluate``.
        """
        result = False
        config = {} if config is None else config
        # Set up the trace component only if its data has not been registered yet.
        if DATA_TRACE not in self.hass.data:
            await async_setup(self.hass, config)
        with trace_action(self.hass, f"scenario_{self.name}", config) as scenario_trace:
            scenario_trace.set_trace(trace_get())
            self.last_trace = scenario_trace
            with trace_path(["condition", "conditions"]):
                result = await self.evaluate(condition_variables)
            _LOGGER.info(scenario_trace.as_dict())
        return result

107 

108 

@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any],
    context: Context | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Create, store and yield an ``ActionTrace`` for one scenario run.

    The trace is built from ``item_id``, ``config`` and the given context (a
    fresh ``Context`` when none is supplied) and handed to
    ``async_store_trace`` with ``stored_traces`` before control is yielded.
    An exception raised in the managed body is recorded on the trace and then
    re-raised; the trace is marked finished on exit either way. Both the
    recording and the finishing are skipped when ``item_id`` is falsy.
    """
    effective_context = context if context else Context()
    action_trace = ActionTrace(item_id, config, None, effective_context)
    async_store_trace(hass, action_trace, stored_traces)

    # item_id is never rebound, so its truthiness can be decided once up front.
    has_item_id = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if has_item_id:
            action_trace.set_error(err)
        raise
    finally:
        if has_item_id:
            action_trace.finished()

129 trace.finished()