signal-bot/rules.py

361 lines
12 KiB
Python
Raw Normal View History

"""
Abstract and generic rule system for triggering bot commands.
This module provides a flexible rule system that allows defining conditions
under which commands should be triggered. Rules can be combined using
logical AND/OR operations.
"""
from abc import ABC, abstractmethod
from datetime import datetime, time, timedelta
from typing import Optional, Dict, Any, List, Iterable
class CommandContext:
    """Snapshot of everything a rule may inspect about one command invocation.

    Args:
        command: Name of the command being invoked.
        user_id: Identifier of the user invoking it.
        reaction_count: Number of reactions on the triggering message.
        current_time: Moment of the invocation. Falls back to
            ``datetime.now()`` (naive local time) when omitted.
        message_data: Arbitrary extra payload. Falls back to a fresh empty
            dict when omitted or empty.
    """

    def __init__(
        self,
        command: str,
        user_id: str,
        reaction_count: int = 0,
        current_time: Optional[datetime] = None,
        message_data: Optional[Dict[str, Any]] = None
    ):
        self.command = command
        self.user_id = user_id
        self.reaction_count = reaction_count
        # datetime instances are always truthy, so this only replaces None.
        self.current_time = current_time if current_time else datetime.now()
        # An empty dict supplied by the caller is also swapped for a new one.
        self.message_data = message_data if message_data else {}
class CommandState:
    """In-memory usage bookkeeping, one record per command name.

    Each record is a dict with the keys ``usage_count`` (int), ``last_user``
    (user id or None) and ``last_time`` (datetime or None). Records are
    created lazily, so even a read-only query for an unknown command leaves
    a zeroed record behind.
    """

    def __init__(self):
        # command name -> {'usage_count', 'last_user', 'last_time'}
        self._command_states: Dict[str, Dict[str, Any]] = {}

    def get_state(self, command: str) -> Dict[str, Any]:
        """Return the live, mutable record for *command*, creating it on first access."""
        try:
            return self._command_states[command]
        except KeyError:
            fresh = {'usage_count': 0, 'last_user': None, 'last_time': None}
            self._command_states[command] = fresh
            return fresh

    def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None):
        """Count one more use of *command* by *user_id* at *timestamp* (default: now)."""
        record = self.get_state(command)
        record['usage_count'] = record['usage_count'] + 1
        record['last_user'] = user_id
        record['last_time'] = timestamp if timestamp else datetime.now()

    def get_usage_count(self, command: str) -> int:
        """Return how many uses of *command* have been recorded."""
        record = self.get_state(command)
        return record['usage_count']

    def get_last_user(self, command: str) -> Optional[str]:
        """Return who used *command* most recently, or None if nobody has."""
        record = self.get_state(command)
        return record['last_user']

    def get_last_time(self, command: str) -> Optional[datetime]:
        """Return when *command* was last used, or None if it never was."""
        record = self.get_state(command)
        return record['last_time']

    def reset_command(self, command: str):
        """Forget everything recorded for *command*; unknown names are ignored."""
        self._command_states.pop(command, None)

    def reset_all(self):
        """Forget the records of every command."""
        self._command_states.clear()
class Rule(ABC):
    """Interface implemented by every command trigger rule.

    A concrete rule answers two questions: ``evaluate`` says whether the
    command may run, and ``get_failure_message`` says what to tell the user
    when it may not.
    """

    @abstractmethod
    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        """
        Decide whether the invocation described by *context* may proceed.
        Args:
            context: Details of the invocation being checked.
            state: Usage history shared across invocations.
        Returns:
            True to allow the command, False to block it.
        """
        ...

    @abstractmethod
    def get_failure_message(self) -> Optional[str]:
        """
        Explain why evaluation failed.
        Returns:
            Text to send back to the user, or None to stay silent.
        """
        ...
class MinReactionsRule(Rule):
    """Allow the command once the triggering message has enough reactions.

    Passes when ``context.reaction_count`` is at least ``min_reactions``.
    Usage history is not consulted.
    """

    def __init__(self, min_reactions: int, failure_message: Optional[str] = None):
        self.min_reactions = min_reactions
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        reactions = context.reaction_count
        enough = reactions >= self.min_reactions
        return enough

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported when too few reactions.
        return self.failure_message
class MaxUsageRule(Rule):
    """Allow a command at most ``max_usage`` times in total.

    Passes while the recorded usage count for ``context.command`` is strictly
    below ``max_usage``. Once that many uses have been recorded, it fails.
    """

    def __init__(self, max_usage: int, failure_message: Optional[str] = None):
        self.max_usage = max_usage
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        used_so_far = state.get_usage_count(context.command)
        limit = self.max_usage
        return used_so_far < limit

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported when the quota is used up.
        return self.failure_message
class CooldownRule(Rule):
    """Per-command cooldown.

    Passes if the command has never been recorded, or if at least
    ``cooldown`` has elapsed between its last recorded use and
    ``context.current_time``.
    """

    def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None):
        self.cooldown = cooldown
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        previous = state.get_last_time(context.command)
        # Short-circuit: with no previous use there is nothing to cool down from.
        return previous is None or context.current_time - previous >= self.cooldown

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported while still cooling down.
        return self.failure_message
class NotUserRule(Rule):
    """Block the command if a particular user was the last one to use it.

    Passes when the command has no recorded last user, or when that user
    differs from ``excluded_user``. The current invoker is not consulted.
    """

    def __init__(self, excluded_user: str, failure_message: Optional[str] = None):
        self.excluded_user = excluded_user
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        previous_user = state.get_last_user(context.command)
        return previous_user is None or previous_user != self.excluded_user

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported on failure.
        return self.failure_message
class NotSameUserRule(Rule):
    """Prevent one user from invoking the same command twice in a row.

    Passes when the command has no recorded last user, or when that user is
    not ``context.user_id``.
    """

    def __init__(self, failure_message: Optional[str] = None):
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        previous_user = state.get_last_user(context.command)
        return previous_user is None or previous_user != context.user_id

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported on failure.
        return self.failure_message
class TimeWindowRule(Rule):
    """Allow the command only during a daily time-of-day window.

    Both bounds are inclusive. If ``start_time`` is later than ``end_time``,
    the window is taken to wrap past midnight (e.g. 22:00-06:00).
    Only the clock time of ``context.current_time`` (via ``.time()``) is
    compared.
    """

    def __init__(
        self,
        start_time: time,
        end_time: time,
        failure_message: Optional[str] = None
    ):
        self.start_time = start_time
        self.end_time = end_time
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        now = context.current_time.time()
        start, end = self.start_time, self.end_time
        if start > end:
            # Overnight window: late-evening part or early-morning part.
            return now >= start or now <= end
        return start <= now <= end

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported outside the window.
        return self.failure_message
class AlwaysPassRule(Rule):
    """Unconditionally allow the command; handy as a default or placeholder."""

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        # Neither the context nor the usage history affects the outcome.
        return True

    def get_failure_message(self) -> Optional[str]:
        # This rule never fails, so there is nothing to report.
        return None
class AlwaysFailRule(Rule):
    """Unconditionally block the command; handy for disabling it."""

    def __init__(self, failure_message: Optional[str] = None):
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        # Neither the context nor the usage history affects the outcome.
        return False

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported on every invocation.
        return self.failure_message
class AndRule(Rule):
    """Composite rule that passes only if ALL child rules pass.

    Children are evaluated in order, and evaluation stops at the first one that
    fails. That child is remembered so ``get_failure_message`` can report
    *its* message. An AndRule with no children always passes.

    Args:
        rules: Child rules. A list is stored as-is (not copied). Any other
            iterable is materialised into a new list, so one-shot iterables
            such as generators are not exhausted by the first evaluation.
    """

    def __init__(self, rules: Iterable[Rule]):
        # Without materialising, a generator would be empty from the second
        # evaluate() call on, and this rule would then pass unconditionally.
        self.rules: List[Rule] = rules if isinstance(rules, list) else list(rules)
        self._failed_rule: Optional[Rule] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        # Reset so a failure from an earlier call is never reported.
        self._failed_rule = None
        for rule in self.rules:
            if not rule.evaluate(context, state):
                self._failed_rule = rule
                return False
        return True

    def get_failure_message(self) -> Optional[str]:
        # Delegate to the child that failed in the most recent evaluate().
        # Return None if that evaluation passed or none has run yet.
        if self._failed_rule is None:
            return None
        return self._failed_rule.get_failure_message()
class OrRule(Rule):
    """Composite rule that passes if ANY child rule passes.

    Children are evaluated in order, and evaluation stops at the first one that
    passes. An OrRule with no children always fails. On failure this rule's
    own ``failure_message`` is reported; the children's messages are not
    consulted.

    Args:
        rules: Child rules. A list is stored as-is (not copied). Any other
            iterable is materialised into a new list, so one-shot iterables
            such as generators are not exhausted by the first evaluation.
        failure_message: Message reported when no child passes.
    """

    def __init__(self, rules: Iterable[Rule], failure_message: Optional[str] = None):
        # Without materialising, a generator would be empty from the second
        # evaluate() call on, and this rule would then fail unconditionally.
        self.rules: List[Rule] = rules if isinstance(rules, list) else list(rules)
        self.failure_message = failure_message

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        # any() short-circuits on the first passing child.
        return any(rule.evaluate(context, state) for rule in self.rules)

    def get_failure_message(self) -> Optional[str]:
        return self.failure_message
class GlobalCooldownRule(Rule):
    """Cooldown shared by every command this instance gates, not per-command.

    Unlike ``CooldownRule``, it ignores ``CommandState`` and keeps its own
    timestamp. That timestamp only changes when ``record_usage`` is called,
    so the rule always passes until the first such call.
    """

    def __init__(self, cooldown: timedelta, failure_message: Optional[str] = None):
        self.cooldown = cooldown
        self.failure_message = failure_message
        # Most recent use stamped via record_usage(); None until then.
        self._last_global_time: Optional[datetime] = None

    def evaluate(self, context: CommandContext, state: CommandState) -> bool:
        previous = self._last_global_time
        return previous is None or context.current_time - previous >= self.cooldown

    def record_usage(self, timestamp: Optional[datetime] = None):
        """Stamp *timestamp* (default: now) as the latest use this rule tracks."""
        self._last_global_time = timestamp if timestamp else datetime.now()

    def get_failure_message(self) -> Optional[str]:
        # Configured message (possibly None) reported while cooling down.
        return self.failure_message
class RuleEngine:
    """
    Engine for evaluating rules and managing command state.

    It holds four things:

    - a per-command rule table;
    - a default rule for commands without an entry (initially
      ``AlwaysPassRule``);
    - an ordered list of global rules, checked before the command's rule;
    - the shared ``CommandState`` usage tracker.
    """

    def __init__(self):
        self.state = CommandState()
        self._command_rules: Dict[str, Rule] = {}
        self._default_rule: Rule = AlwaysPassRule()
        self._global_rules: List[Rule] = []

    def set_command_rule(self, command: str, rule: Rule):
        """Set the rule for a specific command, replacing any previous one."""
        self._command_rules[command] = rule

    def set_command_rules(self, commands: Iterable[str], rule: Rule):
        """Assign the same rule instance to each of several commands."""
        for command in commands:
            self._command_rules[command] = rule

    def set_default_rule(self, rule: Rule):
        """Set the rule used for commands that have no specific rule."""
        self._default_rule = rule

    def add_global_rule(self, rule: Rule):
        """Append a rule that is checked for every command."""
        self._global_rules.append(rule)

    def clear_global_rules(self):
        """Remove all global rules."""
        self._global_rules.clear()

    def get_rule(self, command: str) -> Rule:
        """Return the command's specific rule, or the default rule if it has none."""
        return self._command_rules.get(command, self._default_rule)

    def evaluate(self, context: CommandContext) -> tuple:
        """
        Evaluate whether a command should be triggered.

        Global rules are checked first, in insertion order, and the first
        failure short-circuits. Only if all of them pass is the command's own
        rule (or the default rule) consulted. This method does not change the
        usage state; that is done separately by ``record_usage``.
        Args:
            context: The command context
        Returns:
            Tuple of (should_trigger: bool, failure_message: Optional[str])
        """
        for global_rule in self._global_rules:
            if not global_rule.evaluate(context, self.state):
                return False, global_rule.get_failure_message()
        rule = self.get_rule(context.command)
        if not rule.evaluate(context, self.state):
            return False, rule.get_failure_message()
        return True, None

    @staticmethod
    def _iter_rule_tree(roots: Iterable[Rule]) -> Iterable[Rule]:
        """Yield every rule reachable from *roots*, descending into AndRule/OrRule children, once per instance."""
        seen = set()
        pending = list(roots)
        while pending:
            rule = pending.pop()
            if id(rule) in seen:
                continue
            seen.add(id(rule))
            yield rule
            if isinstance(rule, (AndRule, OrRule)):
                pending.extend(rule.rules)

    def record_usage(self, command: str, user_id: str, timestamp: Optional[datetime] = None):
        """Record that a command was used.

        Updates the per-command usage state. It also stamps every
        ``GlobalCooldownRule`` that can gate this command: those among the
        global rules and those inside the command's rule (or the default
        rule), including ones nested in AndRule/OrRule composites. Each such
        instance is stamped once.
        """
        # Resolve "now" once so the command state and every cooldown clock
        # see exactly the same instant.
        when = timestamp or datetime.now()
        self.state.record_usage(command, user_id, when)
        roots = [*self._global_rules, self.get_rule(command)]
        for rule in self._iter_rule_tree(roots):
            if isinstance(rule, GlobalCooldownRule):
                rule.record_usage(when)