""" Abstract and generic rule system for triggering bot commands. This module provides a flexible rule system that allows defining conditions under which commands should be triggered. Rules can be combined using logical AND/OR operations. """ from abc import ABC, abstractmethod from datetime import datetime, time, timedelta from typing import Optional, Dict, Any, List, Iterable class CommandContext: """Context object containing information about the current command execution.""" def __init__( self, command: str, user_id: str, reaction_count: int = 0, current_time: Optional[datetime] = None, message_data: Optional[Dict[str, Any]] = None ): self.command = command self.user_id = user_id self.reaction_count = reaction_count self.current_time = current_time or datetime.now() self.message_data = message_data or {} class CommandState: """Tracks state for command usage (usage count, last user, last time).""" def __init__(self): # Per-command state: command -> {usage_count, last_user, last_time} self._command_states: Dict[str, Dict[str, Any]] = {} def get_state(self, command: str) -> Dict[str, Any]: """Get state for a specific command.""" if command not in self._command_states: self._command_states[command] = { 'usage_count': 0, 'last_user': None, 'last_time': None } return self._command_states[command] def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None): """Record that a command was used.""" state = self.get_state(command) state['usage_count'] += 1 state['last_user'] = user_id state['last_time'] = timestamp or datetime.now() def get_usage_count(self, command: str) -> int: """Get total usage count for a command.""" return self.get_state(command)['usage_count'] def get_last_user(self, command: str) -> Optional[str]: """Get the last user who used a command.""" return self.get_state(command)['last_user'] def get_last_time(self, command: str) -> Optional[datetime]: """Get the last time a command was used.""" return 
self.get_state(command)['last_time'] def reset_command(self, command: str): """Reset state for a specific command.""" if command in self._command_states: del self._command_states[command] def reset_all(self): """Reset all command states.""" self._command_states.clear() class Rule(ABC): """Abstract base class for command trigger rules.""" @abstractmethod def evaluate(self, context: CommandContext, state: CommandState) -> bool: """ Evaluate whether the rule passes. Args: context: The current command context state: The command state tracker Returns: True if the rule passes (command should be allowed), False otherwise """ pass @abstractmethod def get_failure_message(self) -> Optional[str]: """ Get a message explaining why the rule failed. Returns: A human-readable message or None if no message should be sent """ pass class MinReactionsRule(Rule): """Rule that passes if the message has at least n reactions.""" def __init__(self, min_reactions: int, failure_message: Optional[str] = None): self.min_reactions = min_reactions self.failure_message = failure_message def evaluate(self, context: CommandContext, state: CommandState) -> bool: return context.reaction_count >= self.min_reactions def get_failure_message(self) -> Optional[str]: return self.failure_message class MaxUsageRule(Rule): """Rule that passes if total command usage has not exceeded n.""" def __init__(self, max_usage: int, failure_message: Optional[str] = None): self.max_usage = max_usage self.failure_message = failure_message def evaluate(self, context: CommandContext, state: CommandState) -> bool: current_usage = state.get_usage_count(context.command) return current_usage < self.max_usage def get_failure_message(self) -> Optional[str]: return self.failure_message class CooldownRule(Rule): """Rule that passes if sufficient time has passed since last command use.""" def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None): self.cooldown = cooldown self.failure_message = failure_message 
def evaluate(self, context: CommandContext, state: CommandState) -> bool: last_time = state.get_last_time(context.command) if last_time is None: return True return context.current_time - last_time >= self.cooldown def get_failure_message(self) -> Optional[str]: return self.failure_message class NotUserRule(Rule): """Rule that passes if the command was not last used by a specific user.""" def __init__(self, excluded_user: str, failure_message: Optional[str] = None): self.excluded_user = excluded_user self.failure_message = failure_message def evaluate(self, context: CommandContext, state: CommandState) -> bool: last_user = state.get_last_user(context.command) if last_user is None: return True return last_user != self.excluded_user def get_failure_message(self) -> Optional[str]: return self.failure_message class NotSameUserRule(Rule): """Rule that passes if the current user is not the same as the last user.""" def __init__(self, failure_message: Optional[str] = None): self.failure_message = failure_message def evaluate(self, context: CommandContext, state: CommandState) -> bool: last_user = state.get_last_user(context.command) if last_user is None: return True return last_user != context.user_id def get_failure_message(self) -> Optional[str]: return self.failure_message class TimeWindowRule(Rule): """Rule that passes if current time is within the specified window.""" def __init__( self, start_time: time, end_time: time, failure_message: Optional[str] = None ): self.start_time = start_time self.end_time = end_time self.failure_message = failure_message def evaluate(self, context: CommandContext, state: CommandState) -> bool: current_time = context.current_time.time() # Handle window that spans midnight (e.g., 22:00 to 06:00) if self.start_time <= self.end_time: return self.start_time <= current_time <= self.end_time else: return current_time >= self.start_time or current_time <= self.end_time def get_failure_message(self) -> Optional[str]: return self.failure_message 
class AlwaysPassRule(Rule):
    """Rule that always passes. Useful as a default or placeholder."""

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        return True

    def get_failure_message(self) -> Optional[str]:
        return None


class AlwaysFailRule(Rule):
    """Rule that always fails. Useful for disabling commands."""

    def __init__(self, failure_message: Optional[str] = None):
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        return False

    def get_failure_message(self) -> Optional[str]:
        return self.failure_message


class AndRule(Rule):
    """Composite rule that passes only if ALL child rules pass.

    Children are evaluated in order and evaluation stops at the first failure;
    that child's failure message is reported by ``get_failure_message``. An
    empty rule list passes.
    """

    def __init__(self, rules: List[Rule]):
        self.rules = rules
        # Child that failed during the most recent evaluate(), if any.
        self._failed_rule: Optional[Rule] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        self._failed_rule = None
        for rule in self.rules:
            if not rule.evaluate(context, state):
                self._failed_rule = rule
                return False
        return True

    def get_failure_message(self) -> Optional[str]:
        if self._failed_rule is not None:
            return self._failed_rule.get_failure_message()
        return None


class OrRule(Rule):
    """Composite rule that passes if ANY child rule passes.

    Evaluation stops at the first passing child. An empty rule list fails.
    On failure only this rule's own ``failure_message`` is reported; the
    children's messages are not consulted.
    """

    def __init__(self, rules: List[Rule], failure_message: Optional[str] = None):
        self.rules = rules
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        for rule in self.rules:
            if rule.evaluate(context, state):
                return True
        return False

    def get_failure_message(self) -> Optional[str]:
        return self.failure_message


class GlobalCooldownRule(Rule):
    """Rule that enforces a cooldown across all commands (not per-command).

    The rule keeps its own last-use timestamp instead of reading the
    per-command ``CommandState``; it is armed via ``record_usage``.
    """

    def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None):
        self.cooldown = cooldown
        self.failure_message = failure_message
        self._last_global_time: Optional[datetime] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        if self._last_global_time is None:
            return True
        return context.current_time - self._last_global_time >= self.cooldown

    def record_usage(self, timestamp: Optional[datetime] = None):
        """Record that a command was used globally at ``timestamp`` (default: now)."""
        self._last_global_time = timestamp if timestamp is not None else datetime.now()

    def get_failure_message(self) -> Optional[str]:
        return self.failure_message


class RuleEngine:
    """
    Engine for evaluating rules and managing command state.

    This provides a higher-level interface for working with rules: global
    rules are checked first, then the command-specific rule (or the default
    rule, which initially always passes).
    """

    def __init__(self):
        self.state = CommandState()
        self._command_rules: Dict[str, Rule] = {}
        self._default_rule: Rule = AlwaysPassRule()
        self._global_rules: List[Rule] = []

    def set_command_rule(self, command: str, rule: Rule):
        """Set the rule for a specific command."""
        self._command_rules[command] = rule

    def set_command_rules(self, commands: Iterable[str], rule: Rule):
        """Set the same rule (same instance) for multiple commands."""
        for command in commands:
            self._command_rules[command] = rule

    def set_default_rule(self, rule: Rule):
        """Set the default rule for commands without specific rules."""
        self._default_rule = rule

    def add_global_rule(self, rule: Rule):
        """Add a rule that applies to all commands."""
        self._global_rules.append(rule)

    def clear_global_rules(self):
        """Clear all global rules."""
        self._global_rules.clear()

    def get_rule(self, command: str) -> Rule:
        """Get the rule for a specific command, falling back to the default rule."""
        return self._command_rules.get(command, self._default_rule)

    def evaluate(self, context: CommandContext) -> tuple:
        """
        Evaluate whether a command should be triggered.

        Args:
            context: The command context

        Returns:
            Tuple of (should_trigger: bool, failure_message: Optional[str])
        """
        # First check global rules
        for global_rule in self._global_rules:
            if not global_rule.evaluate(context, self.state):
                return False, global_rule.get_failure_message()

        # Then check command-specific rule
        rule = self.get_rule(context.command)
        if not rule.evaluate(context, self.state):
            return False, rule.get_failure_message()

        return True, None

    def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None):
        """Record that ``user_id`` used ``command`` at ``timestamp`` (default: now).

        Updates the per-command state and arms every ``GlobalCooldownRule``
        that governs this command: top-level global rules, the command's own
        (or default) rule, and any nested inside ``AndRule``/``OrRule``.
        """
        # Resolve "now" once so the per-command state and every global
        # cooldown agree on the exact same instant.
        if timestamp is None:
            timestamp = datetime.now()
        self.state.record_usage(command, user_id, timestamp)

        roots: List[Rule] = list(self._global_rules)
        roots.append(self.get_rule(command))
        for cooldown in self._iter_global_cooldowns(roots):
            cooldown.record_usage(timestamp)

    @staticmethod
    def _iter_global_cooldowns(rules: Iterable[Rule]) -> Iterable[GlobalCooldownRule]:
        """Yield each distinct GlobalCooldownRule reachable from ``rules`` through composites."""
        # Track ids so a rule instance shared across several places (or a
        # cyclic composite) is visited only once.
        seen = set()
        pending = list(rules)
        while pending:
            rule = pending.pop()
            if id(rule) in seen:
                continue
            seen.add(id(rule))
            if isinstance(rule, GlobalCooldownRule):
                yield rule
            elif isinstance(rule, (AndRule, OrRule)):
                pending.extend(rule.rules)