signal-bot/main.py
copilot-swe-agent[bot] 730073eda6 Address code review: improve error handling for trap command
Co-authored-by: kuhyx <147418882+kuhyx@users.noreply.github.com>
2025-12-01 15:18:12 +00:00

307 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import asyncio
import websockets
import requests
import base64
import json
from datetime import datetime, time, timedelta
from fastapi import FastAPI
from rule34Py import rule34Py
# Create FastAPI app
app = FastAPI()
# Signal account number, read from the environment; the default is a placeholder.
PHONE_NUMBER = os.getenv('PHONE_NUMBER', '1234567890')
# HTTP endpoints on localhost:9922 (presumably a signal-cli REST API instance -- TODO confirm).
# NOTE(review): RECEIVE_URL is not referenced elsewhere in the visible code;
# listen_to_server builds its own ws:// URL.
RECEIVE_URL = f"http://localhost:9922/v1/receive/{PHONE_NUMBER}"
# Base URL for attachments: GET lists them, DELETE with an id appended removes one.
REMOVE_ATTACHMENT_URL = f"http://localhost:9922/v1/attachments/"
SEND_URL = 'http://localhost:9922/v2/send'
# Group id compared against incoming messages' groupInfo.groupId.
GROUP_ID = os.getenv('GROUP_ID', '')
# Recipient used when the bot replies to that group.
GROUP_ID_SEND = os.getenv('GROUP_ID_SEND', '')
# NOTE(review): CAT_API is read here but not referenced elsewhere in the visible code.
CAT_API = os.getenv('CAT_API', '')
# Global (not per-user) command cooldown state used by trigger_command:
# when the last command ran, and whether the "wait" warning was already sent.
last_command_time = None
warning_sent = False
# Maps user UUID -> datetime of that user's last successful trap command,
# used for the 24-hour per-user rate limit.
user_trap_usage = {}
class StringCounter:
    """Per-key occurrence tally that also remembers a display name per key."""

    def __init__(self):
        # key -> {'common_name': <name given on first sight>, 'count': <int>}
        self.string_map = {}

    async def update_string_map(self, key, common_name):
        """Record one more occurrence of ``key`` and return the whole map.

        A key seen for the first time is stored with ``common_name`` and a
        count of 1; on later calls only the count changes and the name
        passed in is ignored.
        """
        if key not in self.string_map:
            self.string_map[key] = {'common_name': common_name, 'count': 1}
        else:
            self.string_map[key]['count'] += 1
        return self.string_map

    def get_common_name(self, key):
        """Return the name stored for ``key``, or ``None`` if it is unknown."""
        return self.string_map[key]['common_name'] if key in self.string_map else None
def download_image(image_url, timeout=30):
    """Download ``image_url`` and return its content as a base64 string.

    The payload is encoded straight from memory. No temporary file is
    written, so this works regardless of the URL's final path segment
    (trailing slash, query string) or the working directory, and nothing
    is left behind on disk when an error occurs.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a
            stalled download cannot block the caller indefinitely.

    Returns:
        The response body, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
    """
    image_response = requests.get(image_url, timeout=timeout)
    image_response.raise_for_status()  # Ensure the request was successful
    return base64.b64encode(image_response.content).decode('utf-8')
def fetch_and_download_image(api_url, image_url_key):
    """Ask a JSON API for an image URL, then download that image as base64.

    ``image_url_key`` is either a single key into the decoded JSON, or a
    list of keys/indices applied one after another (for example
    ``[0, 'url']`` for a list of objects).

    Returns whatever ``download_image`` returns for the resolved URL.
    Raises ``requests`` errors for a failed API call and ``KeyError`` /
    ``IndexError`` / ``TypeError`` if the key path does not match the JSON.
    """
    api_response = requests.get(api_url)
    api_response.raise_for_status()  # Ensure the request was successful
    payload = api_response.json()

    # Normalise the single-key form to a one-step path so both forms share a walk.
    key_path = image_url_key if isinstance(image_url_key, list) else [image_url_key]
    image_url = payload
    for step in key_path:
        image_url = image_url[step]

    return download_image(image_url)
async def send_image(base64_attachments, recipients=GROUP_ID_SEND):
    """POST one base64-encoded attachment to ``SEND_URL`` for ``recipients``.

    ``recipients`` is a single recipient id; it is wrapped in a list for
    the request body. The outcome is only printed: HTTP 200 and 201 count
    as success, any other status is printed together with the response text.
    """
    payload = {
        "base64_attachments": [base64_attachments],
        "number": PHONE_NUMBER,
        "recipients": [recipients],
    }
    response = requests.post(SEND_URL, json=payload)
    if response.status_code in (200, 201):
        print("Request was successful.")
        return
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)
def send_message(message_content, recipients=PHONE_NUMBER):
    """POST a text message to ``SEND_URL`` for a single recipient.

    Args:
        message_content: Anything; it is converted with ``str()`` before
            sending (callers pass both strings and dicts).
        recipients: A single recipient id; wrapped in a list for the
            request body. Defaults to the bot's own number.

    The outcome is only printed. HTTP 200 and 201 both count as success,
    matching ``send_image``, which posts to the same endpoint; previously a
    201 response was reported as a failure here.
    """
    data = {
        "message": str(message_content),
        "number": PHONE_NUMBER,
        "recipients": [recipients]
    }
    response = requests.post(SEND_URL, json=data)
    if response.status_code in (200, 201):
        print("Request was successful.")
    else:
        print(f"Request failed with status code: {response.status_code} {data}")
        print(response.text)
def get_trap_image():
    """Return a random rule34 post's image as base64, or ``None`` on failure.

    Any exception raised while picking the post or downloading its image
    is printed and swallowed so the caller only has to check for ``None``.
    """
    try:
        random_post = rule34Py().random_post()
        encoded_image = download_image(random_post.image)
    except Exception as e:
        print(f"Error fetching trap image: {e}")
        return None
    return encoded_image
def can_use_trap(user_uuid):
    """Tell whether ``user_uuid`` may run the trap command right now.

    A user with no recorded usage is always allowed; otherwise at least
    24 hours must have passed since the recorded usage.
    """
    try:
        previous_use = user_trap_usage[user_uuid]
    except KeyError:
        return True
    cooldown = timedelta(hours=24)
    return cooldown <= datetime.now() - previous_use
def get_trap_cooldown_remaining(user_uuid):
    """Return how much of the 24-hour trap cooldown is left for a user.

    The result is a ``timedelta`` that is never negative: it is zero for
    users with no recorded usage and for users whose cooldown has expired.
    """
    zero = timedelta(0)
    try:
        previous_use = user_trap_usage[user_uuid]
    except KeyError:
        return zero
    time_left = timedelta(hours=24) - (datetime.now() - previous_use)
    return time_left if time_left > zero else zero
async def handle_trap_command(recipient, user_uuid):
    """Serve the trap command for one user, enforcing the 24-hour limit.

    While the user is on cooldown, a message with the remaining hours and
    minutes is sent instead. Otherwise an image is fetched; the usage
    timestamp is recorded only when the fetch succeeded, so a failed fetch
    does not consume the user's daily use.
    """
    if not can_use_trap(user_uuid):
        seconds_left = get_trap_cooldown_remaining(user_uuid).total_seconds()
        hours = int(seconds_left // 3600)
        minutes = int((seconds_left % 3600) // 60)
        send_message(f"Mozesz uzyc !trap raz na 24h. Poczekaj jeszcze {hours}h {minutes}m.", recipient)
        return

    image = get_trap_image()
    if not image:
        send_message("Nie udalo sie pobrac obrazka. Sprobuj ponownie pozniej.", recipient)
        return

    user_trap_usage[user_uuid] = datetime.now()
    await send_image(image, recipient)
def message_message(inside_message):
    """Return the ``'message'`` entry of a message dict, or ``None`` if absent."""
    return inside_message.get('message')
def message_group_id(inside_message):
    """Return ``groupInfo.groupId`` from a message dict.

    An empty dict is returned when either level is missing.
    """
    group_info = inside_message.get('groupInfo', {})
    return group_info.get('groupId', {})
def extract_message_content(message):
    """Parse a raw JSON message and return its message payload dict.

    ``envelope.dataMessage`` is preferred; when it is missing or empty the
    payload is taken from ``envelope.syncMessage.sentMessage`` instead.
    An empty dict is returned when neither is present.
    """
    envelope = json.loads(message).get('envelope', {})
    data_message = envelope.get('dataMessage', {})
    if data_message != {}:
        return data_message
    return envelope.get('syncMessage', {}).get('sentMessage', {})
def extract_source_uuid(message):
    """Return the ``'sourceUuid'`` entry of an already-parsed dict.

    An empty dict is returned when the key is missing.
    """
    return message.get('sourceUuid', {})
# Dispatch table used by trigger_command: each key is a tuple of exact trigger
# strings, each value a callable taking the recipient. The callables return the
# coroutine produced by send_image, which trigger_command awaits.
command_map = {
# Cat triggers (plain and stylised-Unicode spellings): the image URL is read
# from the first element's 'url' field of the API response.
("!kot", "!koty", "!kots", "!cat", "!cats", "!meow", "!miau", "!ᴋᴏᴛ", "!𝓴𝓸𝓽", "!𝗸𝗼𝘁"): lambda recipient: send_image(fetch_and_download_image("https://api.thecatapi.com/v1/images/search", [0, 'url']), recipient),
# Dog triggers: the image URL is read from the response's 'message' field.
("!pies", "!psy", "!dog", "!dogs", "!woof", "!szczek", "!𝗽𝗶𝗲𝘀", "!͓̽p͓̽i͓̽e͓̽s͓̽"): lambda recipient: send_image(fetch_and_download_image("https://dog.ceo/api/breeds/image/random", 'message'), recipient),
}
def extract_source_name(message):
    """Return the ``'sourceName'`` entry of an already-parsed dict.

    An empty dict is returned when the key is missing.
    """
    return message.get('sourceName', {})
# NOTE(review): not referenced elsewhere in the visible code; message counts
# are kept in StringCounter.string_map instead. Possibly a leftover -- confirm.
USER_MESSAGE_COUNT = {}
async def count_messages(message_content, counter):
    """Tally one message for its sender, if the message is countable.

    Nothing happens for a falsy ``message_content`` or when
    ``should_count`` rejects it. Otherwise the sender's UUID and name are
    recorded in ``counter`` and the whole updated map is sent to the bot's
    own number.
    """
    if not message_content or not should_count(message_content):
        return
    sender_uuid = extract_source_uuid(message_content)
    sender_name = extract_source_name(message_content)
    await counter.update_string_map(sender_uuid, sender_name)
    send_message(counter.string_map, PHONE_NUMBER)
async def scheduled_task(counter):
    """Forever: sleep until the next 21:37 local time, post the tally, reset it.

    Each cycle sends ``counter.string_map`` to ``GROUP_ID_SEND`` and then
    replaces it with a fresh empty dict.
    """
    while True:
        now = datetime.now()
        target_time = datetime.combine(now.date(), time(21, 37))
        # Today's slot has already passed, so aim for tomorrow's.
        if target_time < now:
            target_time = target_time + timedelta(days=1)
        await asyncio.sleep((target_time - now).total_seconds())

        send_message(counter.string_map, GROUP_ID_SEND)
        counter.string_map = {}
async def trigger_command(message_content, recipient, user_uuid=None):
    """Run the bot command contained in a message, if there is one.

    Only messages whose text starts with ``"!"`` are treated as commands.
    A global 10-second cooldown applies across all users: within it,
    commands are ignored and a single warning is sent. The trap commands
    are handled separately because they are also rate-limited per user;
    every other command is looked up in ``command_map``.

    Args:
        message_content: Message payload dict; its ``'message'`` entry is
            the text to inspect.
        recipient: Where replies are sent.
        user_uuid: Sender UUID, required for the trap commands.

    Errors raised while handling a command are reported to ``recipient``
    rather than propagated.
    """
    global last_command_time, warning_sent
    message_value = message_message(message_content)
    try:
        # Slice rather than index: an empty message text yields "" here and is
        # simply not a command, instead of raising IndexError and posting an
        # error message to the chat.
        if message_value is not None and message_value[:1] == "!":
            current_time = datetime.now()
            if last_command_time and current_time - last_command_time < timedelta(seconds=10):
                # Warn only once per cooldown window to avoid flooding the chat.
                if not warning_sent:
                    send_message("BEEP BOOP POCZEKAJ 10 SEKUND.", recipient)
                    warning_sent = True
                return
            # Handle trap command separately due to per-user rate limiting
            if message_value in ("!trap", "!traps"):
                if user_uuid:
                    await handle_trap_command(recipient, user_uuid)
                    last_command_time = current_time
                    warning_sent = False
                else:
                    send_message("Nie mozna zidentyfikowac uzytkownika.", recipient)
                return
            for command_triggers, command_function in command_map.items():
                if message_value in command_triggers:
                    await command_function(recipient)
                    # The cooldown starts only after a command actually ran.
                    last_command_time = current_time
                    warning_sent = False
                    break
    except TypeError:
        send_message(f"trigger_command, TypeError {message_content}", recipient)
    except Exception as e:
        send_message(f"trigger_command, unknown error {message_content}: {str(e)}", recipient)
async def send_to_group(message_content, counter, message):
    """Process a message only if it belongs to the configured group.

    ``message_content`` is the extracted payload and ``message`` the raw
    JSON text it came from. For messages from ``GROUP_ID`` the envelope is
    counted and then any command in the payload is executed, with replies
    going to ``GROUP_ID_SEND``.
    """
    if message_group_id(message_content) != GROUP_ID:
        return
    envelope = json.loads(message).get('envelope', {})
    await count_messages(envelope, counter)
    sender_uuid = extract_source_uuid(envelope)
    await trigger_command(message_content, GROUP_ID_SEND, sender_uuid)
async def remove_attachment(attachment_id):
    """Send a DELETE for one attachment id and print the outcome.

    HTTP 200 and 204 count as success; any other status is printed along
    with the response object.
    """
    response = requests.delete(REMOVE_ATTACHMENT_URL + attachment_id)
    if response.status_code in (200, 204):
        print("Request remove_attachment was successful.")
        return
    print(f"Request remove_attachment failed with status code: {response.status_code}")
    print(response)
async def get_attachments():
    """List the stored attachments and delete each of them in turn.

    On any status other than HTTP 200 the failure is printed and nothing
    is deleted.
    """
    response = requests.get(REMOVE_ATTACHMENT_URL)
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response)
        return

    attachments = json.loads(response.content)
    print("attachments: ", attachments)
    for attachment in attachments:
        print("attachment: ", attachment)
        await remove_attachment(attachment)
def is_message_reaction(message):
    """Tell whether a raw JSON message carries a reaction.

    A reaction is looked for in three places inside the envelope:
    ``dataMessage.reaction``, ``syncMessage.reaction`` and
    ``syncMessage.sentMessage.reaction``. The last one is where
    ``extract_message_content`` reads sync payloads from, so checking it
    keeps the two functions consistent.

    Only a non-empty reaction counts: a missing, ``null`` or empty
    ``reaction`` is not a reaction. ``null`` intermediate objects are
    treated as absent instead of raising.

    Args:
        message: Raw JSON text of one received message.

    Returns:
        ``True`` if any of the locations holds a non-empty reaction,
        otherwise ``False``.
    """
    envelope = json.loads(message).get('envelope') or {}
    data_message = envelope.get('dataMessage') or {}
    sync_message = envelope.get('syncMessage') or {}
    sent_message = sync_message.get('sentMessage') or {}
    return any(
        container.get('reaction')
        for container in (data_message, sync_message, sent_message)
    )
def should_count(message_content):
    """Decide whether a message should be included in the tally.

    A message is rejected only when its ``dataMessage.sticker`` entry is
    present and not an empty dict; everything else is counted. Each step
    is traced with ``print``.
    """
    print("should_count triggered")
    # NOTE: a filter on 'destinationNumber' != PHONE_NUMBER existed here but is disabled.
    sticker = message_content.get("dataMessage", {}).get("sticker", {})
    print("sticker ", sticker)
    if sticker == {}:
        print("counting message: ", message_content)
        return True
    print("not counting because message has a sticker ", message_content)
    return False
async def listen_to_server(counter):
    """Consume messages from the receive websocket and process each one.

    Reactions are skipped; every other message is printed, its payload is
    extracted and handed to ``send_to_group``. A closed connection is
    printed and ends the function.
    """
    ws_url = f"ws://localhost:9922/v1/receive/{PHONE_NUMBER}?send_read_receipts=false"
    async with websockets.connect(ws_url) as websocket:
        print("Connected to signal server")
        try:
            async for message in websocket:
                if is_message_reaction(message):
                    continue
                print("message: ", message)
                message_content = extract_message_content(message)
                await send_to_group(message_content, counter, message)
        except websockets.ConnectionClosed as e:
            print(f"Connection closed: {e}")
# Startup hook that launches the bot's background tasks
@app.on_event("startup")
async def start_tasks():
    """Start the websocket listener and the daily report on a shared counter.

    NOTE(review): this handler awaits both tasks, and the scheduled task
    loops forever, so the handler itself does not return while they run --
    confirm that blocking the startup event this way is intended.
    """
    counter = StringCounter()
    listener = asyncio.create_task(listen_to_server(counter))
    daily_report = asyncio.create_task(scheduled_task(counter))
    await asyncio.gather(listener, daily_report)