Coverage for custom_components/supernotify/scenario.py: 85%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-28 14:21 +0000

1import logging 

2from collections.abc import Iterator 

3from contextlib import contextmanager 

4from dataclasses import asdict 

5from typing import Any 

6 

7from homeassistant.components.trace import async_setup, async_store_trace # type: ignore 

8from homeassistant.components.trace.const import DATA_TRACE 

9from homeassistant.components.trace.models import ActionTrace 

10from homeassistant.const import ( 

11 CONF_ALIAS, 

12 CONF_CONDITION, 

13) 

14from homeassistant.core import Context, HomeAssistant 

15from homeassistant.helpers import condition 

16from homeassistant.helpers.trace import trace_get, trace_path 

17from homeassistant.helpers.typing import ConfigType 

18from voluptuous import Invalid 

19 

20from . import ATTR_DEFAULT, CONF_ACTION_GROUP_NAMES, CONF_DELIVERY, CONF_DELIVERY_SELECTION, CONF_MEDIA, ConditionVariables 

21 

# Module logger; every record emitted by this file is namespaced under __name__.
_LOGGER = logging.getLogger(name=__name__)

23 

24 

class Scenario:
    """A named notification scenario built from a configuration definition.

    Holds an optional Home Assistant condition plus delivery overrides
    (media, delivery selection, action groups, per-delivery config). The
    scenario whose name equals ``ATTR_DEFAULT`` is flagged as the default.
    """

    def __init__(self, name: str, scenario_definition: dict, hass: HomeAssistant) -> None:
        self.hass: HomeAssistant = hass
        self.name: str = name
        self.alias: str | None = scenario_definition.get(CONF_ALIAS)
        self.condition: ConfigType | None = scenario_definition.get(CONF_CONDITION)
        self.media: dict | None = scenario_definition.get(CONF_MEDIA)
        self.delivery_selection: str | None = scenario_definition.get(CONF_DELIVERY_SELECTION)
        self.action_groups: list[str] = scenario_definition.get(CONF_ACTION_GROUP_NAMES, [])
        self.delivery: dict = scenario_definition.get(CONF_DELIVERY) or {}
        self.default: bool = self.name == ATTR_DEFAULT
        self.last_trace: ActionTrace | None = None
        # Condition checker compiled by validate() from the *validated* condition
        # config; evaluate() reuses it rather than recompiling the raw config.
        self.condition_func = None

    async def validate(self) -> bool:
        """Validate the Home Assistant condition definition at initiation.

        On success the validated condition is compiled once and cached in
        ``condition_func`` for use by evaluate().

        Returns:
            True if there is no condition, or it validated and compiled;
            False (after logging) if the scenario should be disabled.
        """
        if not self.condition:
            return True
        try:
            validated = await condition.async_validate_condition_config(self.hass, self.condition)
            checker = await condition.async_from_config(self.hass, validated)
        except Exception as e:
            # Broad catch: condition config is user-supplied and any failure
            # should disable the scenario rather than break setup.
            _LOGGER.error("SUPERNOTIFY Disabling scenario %s with error validating %s: %s", self.name, self.condition, e)
            return False
        if checker is None:
            _LOGGER.warning("SUPERNOTIFY Disabling scenario %s with failed condition %s", self.name, self.condition)
            return False
        self.condition_func = checker
        return True

    def attributes(
        self, include_condition: bool = True, include_trace: bool = False
    ) -> dict[str, str | dict | bool | list[str] | None]:
        """Return scenario attributes.

        Args:
            include_condition: add the raw ``condition`` config under "condition".
            include_trace: add the last trace (if any) under "trace".
        """
        attrs: dict[str, str | dict | bool | list[str] | None] = {
            "name": self.name,
            "alias": self.alias,
            "media": self.media,
            "delivery_selection": self.delivery_selection,
            "action_groups": self.action_groups,
            "delivery": self.delivery,
            "default": self.default,
        }
        if include_condition:
            attrs["condition"] = self.condition
        if include_trace and self.last_trace:
            attrs["trace"] = self.last_trace.as_extended_dict()
        return attrs

    async def evaluate(self, condition_variables: ConditionVariables | None = None) -> bool:
        """Evaluate the scenario condition.

        Args:
            condition_variables: optional variables passed to the condition,
                converted with ``dataclasses.asdict``.

        Returns:
            True only when a condition is configured and it passes. False when
            there is no condition, it cannot be built, it fails, or it raises.
        """
        if not self.condition:
            return False
        test = self.condition_func
        if test is None:
            # validate() was not run or did not succeed: compile from the raw config.
            try:
                test = await condition.async_from_config(self.hass, self.condition)
                if test is None:
                    raise Invalid(f"Empty condition generated for {self.name}")
            except Exception as e:
                _LOGGER.error("SUPERNOTIFY Scenario %s condition create failed: %s", self.name, e)
                return False
        try:
            return bool(test(self.hass, asdict(condition_variables) if condition_variables else None))
        except Exception as e:
            _LOGGER.error(
                "SUPERNOTIFY Scenario %s condition eval failed: %s, vars: %s", self.name, e, condition_variables
            )
            return False

    async def trace(self, condition_variables: ConditionVariables | None = None, config: ConfigType | None = None) -> bool:
        """Evaluate the scenario while recording a Home Assistant action trace.

        Sets up the trace integration if ``DATA_TRACE`` is not yet in
        ``hass.data``, keeps the trace in ``last_trace`` and returns the
        result of evaluate().
        """
        result = False
        config = {} if config is None else config
        if DATA_TRACE not in self.hass.data:
            await async_setup(self.hass, config)
        with trace_action(self.hass, f"scenario_{self.name}", config) as scenario_trace:
            scenario_trace.set_trace(trace_get())
            self.last_trace = scenario_trace
            with trace_path(["condition", "conditions"]):  # type: ignore
                result = await self.evaluate(condition_variables)
        _LOGGER.info("SUPERNOTIFY Scenario %s trace: %s", self.name, scenario_trace.as_dict())
        return result

101 

102 

@contextmanager
def trace_action(
    hass: HomeAssistant,
    item_id: str,
    config: dict[str, Any],
    context: Context | None = None,
    stored_traces: int = 5,
) -> Iterator[ActionTrace]:
    """Wrap a scenario evaluation in a stored Home Assistant ActionTrace.

    The trace is created, handed to ``async_store_trace`` and yielded to the
    caller. When ``item_id`` is truthy, any exception raised in the body is
    recorded on the trace before being re-raised, and the trace is marked
    finished on exit.

    Args:
        hass: Home Assistant instance the trace is stored on.
        item_id: identifier for the trace.
        config: configuration attached to the trace.
        context: context for the trace; a fresh ``Context()`` when not given.
        stored_traces: number of traces to keep, passed to ``async_store_trace``.
    """
    action_trace = ActionTrace(item_id, config, None, context or Context())
    async_store_trace(hass, action_trace, stored_traces)

    record = bool(item_id)
    try:
        yield action_trace
    except Exception as err:
        if record:
            action_trace.set_error(err)
        raise
    finally:
        if record:
            action_trace.finished()