mirror of
https://github.com/kuhyx/signal-bot.git
synced 2026-07-04 13:03:06 +02:00
- Remove unused imports from main.py - Add clearer documentation for extract_envelope_source_uuid - Use Iterable[str] type hint instead of tuple Co-authored-by: kuhyx <147418882+kuhyx@users.noreply.github.com>
361 lines
12 KiB
Python
361 lines
12 KiB
Python
"""
Abstract and generic rule system for triggering bot commands.

This module provides a flexible rule system that allows defining conditions
under which commands should be triggered. Rules can be combined using
logical AND/OR operations.
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from datetime import datetime, time, timedelta
|
|
from typing import Optional, Dict, Any, List, Iterable
|
|
|
|
|
|
class CommandContext:
    """Context object containing information about the current command execution.

    Attributes:
        command: Name of the command being evaluated.
        user_id: Identifier of the user issuing the command.
        reaction_count: Reaction count that rules compare against thresholds.
        current_time: Moment the command is evaluated at.
        message_data: Free-form extra data attached to the message.
    """

    def __init__(
        self,
        command: str,
        user_id: str,
        reaction_count: int = 0,
        current_time: Optional[datetime] = None,
        message_data: Optional[Dict[str, Any]] = None
    ):
        """
        Build a context for one command invocation.

        Args:
            command: Name of the command being evaluated
            user_id: Identifier of the user issuing the command
            reaction_count: Number of reactions; defaults to 0
            current_time: Evaluation time; defaults to ``datetime.now()``
            message_data: Extra message data; defaults to a new empty dict
        """
        self.command = command
        self.user_id = user_id
        self.reaction_count = reaction_count
        # Explicit None checks: only a missing argument triggers the default,
        # so a caller-supplied (even empty) dict is stored as-is rather than
        # being swapped for a fresh one.
        self.current_time = datetime.now() if current_time is None else current_time
        self.message_data = {} if message_data is None else message_data
|
|
|
|
|
|
class CommandState:
    """Per-command usage bookkeeping: how often, by whom, and when last."""

    def __init__(self):
        # Maps command name -> {'usage_count', 'last_user', 'last_time'}.
        self._command_states: Dict[str, Dict[str, Any]] = {}

    def get_state(self, command: str) -> Dict[str, Any]:
        """Return the mutable state record for a command, creating it on first access."""
        try:
            return self._command_states[command]
        except KeyError:
            # First time this command is seen: start from a zeroed record.
            fresh = {'usage_count': 0, 'last_user': None, 'last_time': None}
            self._command_states[command] = fresh
            return fresh

    def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None):
        """Count one more use of a command and remember who used it and when.

        When no timestamp is given, the current local time is stored.
        """
        entry = self.get_state(command)
        entry.update(
            usage_count=entry['usage_count'] + 1,
            last_user=user_id,
            last_time=timestamp or datetime.now(),
        )

    def get_usage_count(self, command: str) -> int:
        """Return how many times a command has been recorded."""
        entry = self.get_state(command)
        return entry['usage_count']

    def get_last_user(self, command: str) -> Optional[str]:
        """Return the user of the most recent recorded use, or None if never used."""
        entry = self.get_state(command)
        return entry['last_user']

    def get_last_time(self, command: str) -> Optional[datetime]:
        """Return the time of the most recent recorded use, or None if never used."""
        entry = self.get_state(command)
        return entry['last_time']

    def reset_command(self, command: str):
        """Forget everything recorded for one command (no-op if unknown)."""
        self._command_states.pop(command, None)

    def reset_all(self):
        """Forget everything recorded for every command."""
        self._command_states.clear()
|
|
|
|
|
|
class Rule(ABC):
    """Interface that every command-trigger rule implements.

    Subclasses must provide both ``evaluate`` and ``get_failure_message``.
    """

    @abstractmethod
    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """
        Decide whether the command is allowed under this rule.

        Args:
            context: The current command context
            state: The command state tracker

        Returns:
            True when the rule passes (the command should be allowed),
            False when it does not
        """

    @abstractmethod
    def get_failure_message(self) -> Optional[str]:
        """
        Describe why the rule rejected the command.

        Returns:
            A human-readable explanation, or None when nothing should be sent
        """
|
|
|
|
|
|
class MinReactionsRule(Rule):
    """Gate a command on the message having collected enough reactions."""

    def __init__(self, min_reactions: int, failure_message: Optional[str] = None):
        """Store the inclusive reaction threshold and the optional rejection text."""
        self.failure_message = failure_message
        self.min_reactions = min_reactions

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True when the context's reaction count reaches the threshold."""
        required = self.min_reactions
        return context.reaction_count >= required

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class MaxUsageRule(Rule):
    """Allow a command only while its recorded usage count is below a cap."""

    def __init__(self, max_usage: int, failure_message: Optional[str] = None):
        """Store the usage cap and the optional rejection text."""
        self.failure_message = failure_message
        self.max_usage = max_usage

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True while the command's usage count is strictly under the cap."""
        return state.get_usage_count(context.command) < self.max_usage

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class CooldownRule(Rule):
    """Allow a command only once enough time has elapsed since its last use."""

    def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None):
        """Store the minimum gap between uses and the optional rejection text."""
        self.failure_message = failure_message
        self.cooldown = cooldown

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True if the command was never used or the cooldown has elapsed."""
        previous = state.get_last_time(context.command)
        if previous is not None:
            elapsed = context.current_time - previous
            return elapsed >= self.cooldown
        # No recorded use yet, so there is nothing to cool down from.
        return True

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class NotUserRule(Rule):
    """Reject a command when its most recent user is one particular user."""

    def __init__(self, excluded_user: str, failure_message: Optional[str] = None):
        """Store the user to exclude and the optional rejection text."""
        self.failure_message = failure_message
        self.excluded_user = excluded_user

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True if the command has no last user or it differs from the excluded one."""
        previous_user = state.get_last_user(context.command)
        return previous_user is None or previous_user != self.excluded_user

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class NotSameUserRule(Rule):
    """Reject a command when the caller is also the one who used it last."""

    def __init__(self, failure_message: Optional[str] = None):
        """Store the optional rejection text."""
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True if the command has no last user or it differs from the caller."""
        previous_user = state.get_last_user(context.command)
        return previous_user is None or previous_user != context.user_id

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class TimeWindowRule(Rule):
    """Allow a command only during a daily time-of-day window (bounds inclusive)."""

    def __init__(
        self,
        start_time: time,
        end_time: time,
        failure_message: Optional[str] = None
    ):
        """Store the window bounds and the optional rejection text."""
        self.start_time, self.end_time = start_time, end_time
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True when the context's time of day falls inside the window."""
        now = context.current_time.time()
        opens, closes = self.start_time, self.end_time

        if opens > closes:
            # Window wraps past midnight (e.g., 22:00 to 06:00): it is open
            # late in the day or early in the next one.
            return now >= opens or now <= closes

        # Ordinary same-day window.
        return opens <= now <= closes

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class AlwaysPassRule(Rule):
    """Unconditionally permissive rule; handy as a default or placeholder."""

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True regardless of context or state."""
        return True

    def get_failure_message(self) -> Optional[str]:
        """Return None: this rule never fails, so it has nothing to report."""
        return None
|
|
|
|
|
|
class AlwaysFailRule(Rule):
    """Unconditionally rejecting rule; handy for switching a command off."""

    def __init__(self, failure_message: Optional[str] = None):
        """Store the optional rejection text."""
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return False regardless of context or state."""
        return False

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class AndRule(Rule):
    """Composite rule requiring every child rule to pass."""

    def __init__(self, rules: List[Rule]):
        """Store the child rules; they are evaluated in the order given."""
        self.rules = rules
        # Remembers which child rejected the most recent evaluation, if any.
        self._failed_rule: Optional[Rule] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Evaluate children in order, stopping at (and remembering) the first failure."""
        self._failed_rule = None
        for child in self.rules:
            if child.evaluate(context, state):
                continue
            # Short-circuit: later children are not consulted.
            self._failed_rule = child
            break
        return self._failed_rule is None

    def get_failure_message(self) -> Optional[str]:
        """Return the message of the child that failed the last evaluation, else None."""
        culprit = self._failed_rule
        return None if culprit is None else culprit.get_failure_message()
|
|
|
|
|
|
class OrRule(Rule):
    """Composite rule satisfied as soon as one child rule passes."""

    def __init__(self, rules: List[Rule], failure_message: Optional[str] = None):
        """Store the child rules and the text to report when none of them pass."""
        self.failure_message = failure_message
        self.rules = rules

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True at the first passing child; False if none pass (or there are none)."""
        return any(child.evaluate(context, state) for child in self.rules)

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class GlobalCooldownRule(Rule):
    """Cooldown shared by all commands rather than tracked per command.

    The rule keeps its own "last used" timestamp, which only advances when
    ``record_usage`` is called on this instance.
    """

    def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None):
        """Store the minimum gap between uses and the optional rejection text."""
        self.failure_message = failure_message
        self.cooldown = cooldown
        # None until the first recorded use.
        self._last_global_time: Optional[datetime] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """Return True if nothing was recorded yet or the cooldown has elapsed."""
        previous = self._last_global_time
        if previous is not None:
            elapsed = context.current_time - previous
            return elapsed >= self.cooldown
        return True

    def record_usage(self, timestamp: Optional[datetime] = None):
        """Mark a use at the given time, or at the current local time when omitted."""
        self._last_global_time = timestamp or datetime.now()

    def get_failure_message(self) -> Optional[str]:
        """Return the configured rejection text (may be None)."""
        return self.failure_message
|
|
|
|
|
|
class RuleEngine:
    """
    Engine for evaluating rules and managing command state.

    This provides a higher-level interface for working with rules: it owns a
    CommandState, maps commands to rules (with a default fallback), and keeps
    a list of global rules that are checked before any command-specific rule.
    """

    def __init__(self):
        self.state = CommandState()
        self._command_rules: Dict[str, Rule] = {}
        # Used for any command without an explicit rule; permissive by default.
        self._default_rule: Rule = AlwaysPassRule()
        self._global_rules: List[Rule] = []

    def set_command_rule(self, command: str, rule: Rule):
        """Set the rule for a specific command."""
        self._command_rules[command] = rule

    def set_command_rules(self, commands: Iterable[str], rule: Rule):
        """Set the same rule for multiple commands (the rule instance is shared)."""
        for command in commands:
            self._command_rules[command] = rule

    def set_default_rule(self, rule: Rule):
        """Set the default rule for commands without specific rules."""
        self._default_rule = rule

    def add_global_rule(self, rule: Rule):
        """Add a rule that applies to all commands."""
        self._global_rules.append(rule)

    def clear_global_rules(self):
        """Clear all global rules."""
        self._global_rules.clear()

    def get_rule(self, command: str) -> Rule:
        """Get the rule for a specific command, falling back to the default rule."""
        return self._command_rules.get(command, self._default_rule)

    def evaluate(self, context: CommandContext) -> tuple:
        """
        Evaluate whether a command should be triggered.

        Global rules are checked first, in the order they were added; the
        first failing rule decides the outcome. Only then is the
        command-specific (or default) rule consulted.

        Args:
            context: The command context

        Returns:
            Tuple of (should_trigger: bool, failure_message: Optional[str])
        """
        for global_rule in self._global_rules:
            if not global_rule.evaluate(context, self.state):
                return False, global_rule.get_failure_message()

        rule = self.get_rule(context.command)
        if not rule.evaluate(context, self.state):
            return False, rule.get_failure_message()

        return True, None

    def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None):
        """
        Record that a command was used.

        Updates the per-command state and stamps every GlobalCooldownRule
        reachable from the global rules, including ones nested inside
        AndRule/OrRule composites.

        Args:
            command: The command that was used
            user_id: The user who used it
            timestamp: When it was used; defaults to the current local time
        """
        # Resolve the default once so the per-command state and all global
        # cooldowns are stamped with the same instant.
        when = timestamp if timestamp is not None else datetime.now()
        self.state.record_usage(command, user_id, when)

        for cooldown_rule in self._iter_global_cooldowns():
            cooldown_rule.record_usage(when)

    def _iter_global_cooldowns(self) -> Iterable[GlobalCooldownRule]:
        """Yield each GlobalCooldownRule under the global rules exactly once."""
        seen = set()
        pending = list(self._global_rules)
        while pending:
            rule = pending.pop()
            # Identity-based guard: a rule instance shared between composites
            # (or a composite that contains itself) is visited only once.
            if id(rule) in seen:
                continue
            seen.add(id(rule))
            if isinstance(rule, GlobalCooldownRule):
                yield rule
            elif isinstance(rule, (AndRule, OrRule)):
                # Composites never record usage themselves, so descend into
                # their children to find cooldown rules that need stamping.
                pending.extend(rule.rules)
|