signal-bot/test_rules.py
copilot-swe-agent[bot] 8a6dbca7c9 Add abstract and generic rule system for command triggers
- Create rules.py module with Rule base class and implementations:
  - MinReactionsRule: Trigger if received n reactions or more
  - MaxUsageRule: Trigger only while total usage is still below n
  - CooldownRule: Trigger if cooldown time has elapsed
  - NotUserRule: Trigger if last user was not person x
  - NotSameUserRule: Prevent same user from using consecutively
  - TimeWindowRule: Trigger if time is between t1 and t2
  - AndRule/OrRule: Combine rules with AND/OR logic
  - GlobalCooldownRule: Rate limit across all commands

- Update main.py to use the new RuleEngine
- Add comprehensive test suite (60 tests)

Co-authored-by: kuhyx <147418882+kuhyx@users.noreply.github.com>
2025-12-01 15:13:25 +00:00

733 lines
25 KiB
Python

"""
Tests for the command trigger rules system.
"""
import pytest
from datetime import datetime, time, timedelta
from rules import (
CommandContext,
CommandState,
RuleEngine,
MinReactionsRule,
MaxUsageRule,
CooldownRule,
NotUserRule,
NotSameUserRule,
TimeWindowRule,
AlwaysPassRule,
AlwaysFailRule,
AndRule,
OrRule,
GlobalCooldownRule,
)
class TestCommandState:
    """Exercise the per-command bookkeeping kept by CommandState."""

    def test_initial_state(self):
        tracker = CommandState()
        # A command that was never recorded reports an empty history.
        assert tracker.get_usage_count("!test") == 0
        assert tracker.get_last_user("!test") is None
        assert tracker.get_last_time("!test") is None

    def test_record_usage(self):
        tracker = CommandState()
        when = datetime(2024, 1, 15, 12, 0, 0)
        tracker.record_usage("!test", "user123", when)
        assert tracker.get_usage_count("!test") == 1
        assert tracker.get_last_user("!test") == "user123"
        assert tracker.get_last_time("!test") == when

    def test_multiple_usages(self):
        tracker = CommandState()
        for who in ("user1", "user2", "user3"):
            tracker.record_usage("!test", who)
        assert tracker.get_usage_count("!test") == 3
        # Only the most recent caller is reported.
        assert tracker.get_last_user("!test") == "user3"

    def test_separate_commands(self):
        tracker = CommandState()
        expected_last_user = {"!cmd1": "user1", "!cmd2": "user2"}
        for cmd, who in expected_last_user.items():
            tracker.record_usage(cmd, who)
        # Each command keeps an independent count and last user.
        for cmd, who in expected_last_user.items():
            assert tracker.get_usage_count(cmd) == 1
            assert tracker.get_last_user(cmd) == who

    def test_reset_command(self):
        tracker = CommandState()
        tracker.record_usage("!test", "user1")
        tracker.reset_command("!test")
        assert tracker.get_usage_count("!test") == 0
        assert tracker.get_last_user("!test") is None

    def test_reset_all(self):
        tracker = CommandState()
        tracker.record_usage("!cmd1", "user1")
        tracker.record_usage("!cmd2", "user2")
        tracker.reset_all()
        for cmd in ("!cmd1", "!cmd2"):
            assert tracker.get_usage_count(cmd) == 0
class TestMinReactionsRule:
    """MinReactionsRule passes once reaction_count reaches its threshold."""

    @staticmethod
    def _check(reactions):
        # Evaluate a threshold-5 rule against a context carrying `reactions`.
        gate = MinReactionsRule(min_reactions=5)
        ctx = CommandContext(command="!test", user_id="user1", reaction_count=reactions)
        return gate.evaluate(ctx, CommandState())

    def test_passes_when_enough_reactions(self):
        assert self._check(5) is True

    def test_passes_when_more_reactions(self):
        assert self._check(10) is True

    def test_fails_when_fewer_reactions(self):
        assert self._check(4) is False

    def test_failure_message(self):
        gate = MinReactionsRule(min_reactions=5, failure_message="Need more reactions!")
        assert gate.get_failure_message() == "Need more reactions!"
class TestMaxUsageRule:
    """MaxUsageRule is checked against the command's recorded usage count."""

    @staticmethod
    def _evaluate_after(prior_uses):
        # Build a max-10 rule, record `prior_uses` uses of !test, then evaluate.
        limit_rule = MaxUsageRule(max_usage=10)
        ctx = CommandContext(command="!test", user_id="user1")
        history = CommandState()
        for _ in range(prior_uses):
            history.record_usage("!test", "user")
        return limit_rule.evaluate(ctx, history)

    def test_passes_when_below_limit(self):
        assert self._evaluate_after(0) is True

    def test_passes_at_limit_minus_one(self):
        assert self._evaluate_after(9) is True

    def test_fails_at_limit(self):
        assert self._evaluate_after(10) is False

    def test_fails_above_limit(self):
        assert self._evaluate_after(15) is False
class TestCooldownRule:
    """CooldownRule compares context.current_time with the command's last use."""

    @staticmethod
    def _evaluate_seconds_after(elapsed):
        # user1 uses !test at noon; user2 asks again `elapsed` seconds later.
        cooldown_rule = CooldownRule(cooldown=timedelta(seconds=10))
        history = CommandState()
        last_use = datetime(2024, 1, 15, 12, 0, 0)
        history.record_usage("!test", "user1", last_use)
        ctx = CommandContext(
            command="!test",
            user_id="user2",
            current_time=last_use + timedelta(seconds=elapsed),
        )
        return cooldown_rule.evaluate(ctx, history)

    def test_passes_when_no_previous_usage(self):
        cooldown_rule = CooldownRule(cooldown=timedelta(seconds=10))
        ctx = CommandContext(command="!test", user_id="user1")
        assert cooldown_rule.evaluate(ctx, CommandState()) is True

    def test_passes_when_cooldown_elapsed(self):
        assert self._evaluate_seconds_after(15) is True

    def test_fails_when_cooldown_not_elapsed(self):
        assert self._evaluate_seconds_after(5) is False

    def test_exact_cooldown_boundary(self):
        # Exactly the cooldown length is expected to be enough.
        assert self._evaluate_seconds_after(10) is True
class TestNotUserRule:
    """NotUserRule is checked against the command's last recorded user."""

    @staticmethod
    def _evaluate_with_last_user(last_user):
        # user1 invokes !test; `last_user` (None = no history) used it before.
        guard = NotUserRule(excluded_user="banned_user")
        ctx = CommandContext(command="!test", user_id="user1")
        history = CommandState()
        if last_user is not None:
            history.record_usage("!test", last_user)
        return guard.evaluate(ctx, history)

    def test_passes_when_no_previous_usage(self):
        assert self._evaluate_with_last_user(None) is True

    def test_passes_when_different_user(self):
        assert self._evaluate_with_last_user("other_user") is True

    def test_fails_when_excluded_user(self):
        assert self._evaluate_with_last_user("banned_user") is False
class TestNotSameUserRule:
    """NotSameUserRule compares the caller with the command's previous user."""

    @staticmethod
    def _evaluate(caller, previous=None):
        # `previous` (if given) is recorded as the last user of !test first.
        rule = NotSameUserRule()
        ctx = CommandContext(command="!test", user_id=caller)
        history = CommandState()
        if previous is not None:
            history.record_usage("!test", previous)
        return rule.evaluate(ctx, history)

    def test_passes_when_no_previous_usage(self):
        assert self._evaluate("user1") is True

    def test_passes_when_different_user(self):
        assert self._evaluate("user2", previous="user1") is True

    def test_fails_when_same_user(self):
        assert self._evaluate("user1", previous="user1") is False
class TestTimeWindowRule:
    """TimeWindowRule checks the hour of context.current_time against a window."""

    DAY_WINDOW = (time(9, 0), time(17, 0))
    # start > end: the window wraps past midnight (22:00 -> 06:00).
    NIGHT_WINDOW = (time(22, 0), time(6, 0))

    @staticmethod
    def _evaluate_at(window, hour):
        start, end = window
        rule = TimeWindowRule(start_time=start, end_time=end)
        ctx = CommandContext(
            command="!test",
            user_id="user1",
            current_time=datetime(2024, 1, 15, hour, 0, 0),
        )
        return rule.evaluate(ctx, CommandState())

    def test_passes_within_window(self):
        assert self._evaluate_at(self.DAY_WINDOW, 12) is True

    def test_passes_at_start(self):
        assert self._evaluate_at(self.DAY_WINDOW, 9) is True

    def test_passes_at_end(self):
        assert self._evaluate_at(self.DAY_WINDOW, 17) is True

    def test_fails_before_window(self):
        assert self._evaluate_at(self.DAY_WINDOW, 8) is False

    def test_fails_after_window(self):
        assert self._evaluate_at(self.DAY_WINDOW, 18) is False

    def test_overnight_window_passes_evening(self):
        assert self._evaluate_at(self.NIGHT_WINDOW, 23) is True

    def test_overnight_window_passes_morning(self):
        assert self._evaluate_at(self.NIGHT_WINDOW, 4) is True

    def test_overnight_window_fails_midday(self):
        assert self._evaluate_at(self.NIGHT_WINDOW, 12) is False
class TestAlwaysPassRule:
    """AlwaysPassRule accepts every context and has no failure message."""

    def test_always_passes(self):
        permissive = AlwaysPassRule()
        outcome = permissive.evaluate(
            CommandContext(command="!test", user_id="user1"), CommandState()
        )
        assert outcome is True

    def test_no_failure_message(self):
        assert AlwaysPassRule().get_failure_message() is None
class TestAlwaysFailRule:
    """AlwaysFailRule rejects every context and reports its configured message."""

    def test_always_fails(self):
        strict = AlwaysFailRule()
        outcome = strict.evaluate(
            CommandContext(command="!test", user_id="user1"), CommandState()
        )
        assert outcome is False

    def test_failure_message(self):
        strict = AlwaysFailRule(failure_message="Command disabled")
        assert strict.get_failure_message() == "Command disabled"
class TestAndRule:
    """AndRule passes only when every child rule passes."""

    @staticmethod
    def _run(rule):
        ctx = CommandContext(command="!test", user_id="user1")
        return rule.evaluate(ctx, CommandState())

    def test_passes_when_all_pass(self):
        combined = AndRule([AlwaysPassRule() for _ in range(3)])
        assert self._run(combined) is True

    def test_fails_when_one_fails(self):
        combined = AndRule([AlwaysPassRule(), AlwaysFailRule(), AlwaysPassRule()])
        assert self._run(combined) is False

    def test_failure_message_from_failed_rule(self):
        combined = AndRule([
            AlwaysPassRule(),
            AlwaysFailRule(failure_message="This one failed"),
            AlwaysPassRule(),
        ])
        # evaluate() runs first; the expected message is the failing child's.
        self._run(combined)
        assert combined.get_failure_message() == "This one failed"

    def test_empty_rules_pass(self):
        assert self._run(AndRule([])) is True
class TestOrRule:
    """OrRule passes when at least one child rule passes."""

    @staticmethod
    def _run(rule):
        ctx = CommandContext(command="!test", user_id="user1")
        return rule.evaluate(ctx, CommandState())

    def test_passes_when_one_passes(self):
        either = OrRule([AlwaysFailRule(), AlwaysPassRule(), AlwaysFailRule()])
        assert self._run(either) is True

    def test_passes_when_all_pass(self):
        either = OrRule([AlwaysPassRule() for _ in range(3)])
        assert self._run(either) is True

    def test_fails_when_all_fail(self):
        either = OrRule([AlwaysFailRule() for _ in range(3)])
        assert self._run(either) is False

    def test_failure_message(self):
        either = OrRule(
            [AlwaysFailRule(), AlwaysFailRule()],
            failure_message="None passed",
        )
        assert either.get_failure_message() == "None passed"

    def test_empty_rules_fail(self):
        # With no children there is nothing that can pass.
        assert self._run(OrRule([])) is False
class TestGlobalCooldownRule:
    """GlobalCooldownRule is fed timestamps through its own record_usage()."""

    @staticmethod
    def _evaluate_seconds_after(elapsed):
        limiter = GlobalCooldownRule(cooldown=timedelta(seconds=10))
        history = CommandState()
        last_fire = datetime(2024, 1, 15, 12, 0, 0)
        # The timestamp is recorded on the rule itself, not on CommandState.
        limiter.record_usage(last_fire)
        ctx = CommandContext(
            command="!test",
            user_id="user1",
            current_time=last_fire + timedelta(seconds=elapsed),
        )
        return limiter.evaluate(ctx, history)

    def test_passes_when_no_previous_usage(self):
        limiter = GlobalCooldownRule(cooldown=timedelta(seconds=10))
        ctx = CommandContext(command="!test", user_id="user1")
        assert limiter.evaluate(ctx, CommandState()) is True

    def test_passes_when_cooldown_elapsed(self):
        assert self._evaluate_seconds_after(15) is True

    def test_fails_when_cooldown_not_elapsed(self):
        assert self._evaluate_seconds_after(5) is False
class TestRuleEngine:
    """Tests for RuleEngine: default, per-command and global rule evaluation."""

    def test_default_always_pass(self):
        # With no rules configured, commands are allowed and no message is returned.
        engine = RuleEngine()
        context = CommandContext(command="!test", user_id="user1")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is True
        assert message is None

    def test_command_specific_rule(self):
        engine = RuleEngine()
        engine.set_command_rule("!test", AlwaysFailRule(failure_message="Nope"))
        context = CommandContext(command="!test", user_id="user1")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        # The failing rule's message is surfaced to the caller.
        assert message == "Nope"

    def test_command_specific_rule_different_command(self):
        engine = RuleEngine()
        engine.set_command_rule("!test", AlwaysFailRule(failure_message="Nope"))
        # A command without its own rule falls back to the default behaviour.
        context = CommandContext(command="!other", user_id="user1")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is True
        assert message is None

    def test_set_command_rules_multiple(self):
        engine = RuleEngine()
        aliases = ("!cat", "!kot", "!meow")
        # One rule can be attached to several command aliases in a single call.
        engine.set_command_rules(aliases, AlwaysFailRule())
        for cmd in aliases:
            context = CommandContext(command=cmd, user_id="user1")
            should_trigger, _ = engine.evaluate(context)
            assert should_trigger is False, cmd

    def test_global_rule_blocks(self):
        engine = RuleEngine()
        engine.add_global_rule(AlwaysFailRule(failure_message="Global block"))
        context = CommandContext(command="!test", user_id="user1")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Global block"

    def test_global_rule_evaluated_first(self):
        engine = RuleEngine()
        engine.add_global_rule(AlwaysFailRule(failure_message="Global"))
        engine.set_command_rule("!test", AlwaysFailRule(failure_message="Command"))
        context = CommandContext(command="!test", user_id="user1")
        _, message = engine.evaluate(context)
        # Both rules fail; the global rule's message must win, i.e. it is checked first.
        assert message == "Global"

    def test_record_usage(self):
        engine = RuleEngine()
        engine.record_usage("!test", "user1")
        # The engine's bookkeeping is exposed through `engine.state`.
        assert engine.state.get_usage_count("!test") == 1
        assert engine.state.get_last_user("!test") == "user1"

    def test_complex_rule_combination(self):
        """Combine MaxUsageRule, CooldownRule and TimeWindowRule with AndRule.

        Each component is made to fail on its own to check that the engine
        reports that component's failure message.
        """
        engine = RuleEngine()
        # Command can only be used 5 times total, with a 30-second cooldown,
        # and only between 9am and 5pm.
        rule = AndRule([
            MaxUsageRule(max_usage=5, failure_message="Max usage reached"),
            CooldownRule(cooldown=timedelta(seconds=30), failure_message="Cooldown active"),
            TimeWindowRule(
                start_time=time(9, 0),
                end_time=time(17, 0),
                failure_message="Outside working hours",
            ),
        ])
        engine.set_command_rule("!test", rule)
        noon = datetime(2024, 1, 15, 12, 0, 0)

        # First usage during working hours should pass.
        context = CommandContext(command="!test", user_id="user1", current_time=noon)
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True
        # Record the usage.
        engine.record_usage("!test", "user1", noon)

        # Five seconds later (from another user) is still inside the 30 s cooldown.
        context = CommandContext(
            command="!test",
            user_id="user2",
            current_time=noon + timedelta(seconds=5),
        )
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Cooldown active"

        # 20:00 is long past the cooldown but outside working hours.
        context = CommandContext(
            command="!test",
            user_id="user1",
            current_time=datetime(2024, 1, 15, 20, 0, 0),
        )
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Outside working hours"

        # Record the remaining 4 allowed uses (the last at 12:04). At 12:10 the
        # cooldown has expired and we are inside the window, so only the usage
        # cap can block the command.
        for minute in range(1, 5):
            engine.record_usage("!test", "user1", noon + timedelta(minutes=minute))
        context = CommandContext(
            command="!test",
            user_id="user2",
            current_time=noon + timedelta(minutes=10),
        )
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Max usage reached"
class TestIntegration:
    """Integration tests that drive RuleEngine through realistic scenarios."""

    def test_rate_limited_command(self):
        """A global 10-second cooldown applies across different commands.

        `!cat` is used first. A `!dog` request 5 seconds later is blocked;
        one 15 seconds later is allowed.
        """
        engine = RuleEngine()
        engine.add_global_rule(
            GlobalCooldownRule(
                cooldown=timedelta(seconds=10),
                failure_message="Wait 10 seconds",
            )
        )
        base_time = datetime(2024, 1, 15, 12, 0, 0)

        # First command passes.
        context = CommandContext(command="!cat", user_id="user1", current_time=base_time)
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True
        # GlobalCooldownRule keeps its own last-use timestamp; this test expects
        # engine.record_usage() to update it.
        engine.record_usage("!cat", "user1", base_time)

        # 5 seconds later, even a different command is blocked.
        context = CommandContext(
            command="!dog",
            user_id="user2",
            current_time=base_time + timedelta(seconds=5),
        )
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Wait 10 seconds"

        # 15 seconds later it passes.
        context = CommandContext(
            command="!dog",
            user_id="user2",
            current_time=base_time + timedelta(seconds=15),
        )
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True

    def test_reaction_gated_command(self):
        """A command that requires at least 5 reactions to trigger."""
        engine = RuleEngine()
        engine.set_command_rule(
            "!special",
            MinReactionsRule(min_reactions=5, failure_message="Need 5 reactions"),
        )

        # 3 reactions - fails.
        context = CommandContext(command="!special", user_id="user1", reaction_count=3)
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Need 5 reactions"

        # 5 reactions - passes (the threshold is inclusive).
        context = CommandContext(command="!special", user_id="user1", reaction_count=5)
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True

    def test_prevent_spam_from_same_user(self):
        """The same user cannot use a command twice in a row."""
        engine = RuleEngine()
        engine.set_command_rule(
            "!share",
            NotSameUserRule(failure_message="Wait for someone else"),
        )

        # First usage passes.
        context = CommandContext(command="!share", user_id="user1")
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True
        engine.record_usage("!share", "user1")

        # Same user fails.
        context = CommandContext(command="!share", user_id="user1")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Wait for someone else"

        # Different user passes.
        context = CommandContext(command="!share", user_id="user2")
        should_trigger, _ = engine.evaluate(context)
        assert should_trigger is True

    def test_limited_usage_command(self):
        """A command with a maximum total usage count of 3."""
        engine = RuleEngine()
        engine.set_command_rule(
            "!prize",
            MaxUsageRule(max_usage=3, failure_message="Prize limit reached"),
        )

        # First 3 uses (user0..user2) pass.
        for i in range(3):
            context = CommandContext(command="!prize", user_id=f"user{i}")
            should_trigger, _ = engine.evaluate(context)
            assert should_trigger is True
            engine.record_usage("!prize", f"user{i}")

        # 4th use (by the next user, user3) fails.
        context = CommandContext(command="!prize", user_id="user3")
        should_trigger, message = engine.evaluate(context)
        assert should_trigger is False
        assert message == "Prize limit reached"
if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the shell / CI
    # when this file is run directly.
    raise SystemExit(pytest.main([__file__, "-v"]))