mirror of
https://github.com/kuhyx/signal-bot.git
synced 2026-07-04 13:23:07 +02:00
366 lines
13 KiB
Python
366 lines
13 KiB
Python
import os
|
||
import asyncio
|
||
import websockets
|
||
import requests
|
||
import base64
|
||
import json
|
||
from datetime import datetime, time, timedelta
|
||
from fastapi import FastAPI
|
||
from rule34Py import rule34Py
|
||
|
||
# Create FastAPI app.
# The bot's background tasks are attached to this instance through the
# startup hook (`start_tasks`) defined near the end of this module.
app = FastAPI()
|
||
|
||
# --- Settings read from the environment (with fallbacks) -------------------
# Account number: used as the sending "number" and in the receive URLs.
PHONE_NUMBER = os.getenv("PHONE_NUMBER", "1234567890")
# Incoming messages are only processed when their group id equals GROUP_ID;
# replies are addressed to GROUP_ID_SEND.
GROUP_ID = os.getenv("GROUP_ID", "")
GROUP_ID_SEND = os.getenv("GROUP_ID_SEND", "")
# NOTE(review): CAT_API is not referenced anywhere else in the visible source.
CAT_API = os.getenv("CAT_API", "")

# --- HTTP endpoints of the local messaging service -------------------------
_API_BASE = "http://localhost:9922"
RECEIVE_URL = _API_BASE + "/v1/receive/" + PHONE_NUMBER
REMOVE_ATTACHMENT_URL = _API_BASE + "/v1/attachments/"
SEND_URL = _API_BASE + "/v2/send"

# --- Cooldown state shared by the command handlers -------------------------
# Time of the last executed command (None until one has run) and whether the
# "please wait" warning was already sent during the current cooldown window.
last_command_time = None
warning_sent = False
|
||
|
||
# Initialize rule34Py for trap images.
# One module-level client is shared by every call to `send_trap_image`.
r34_client = rule34Py()
|
||
|
||
class StringCounter:
    """Occurrence tally per key that also remembers a display name for each key."""

    def __init__(self):
        # key -> {'common_name': <name given on first sight>, 'count': <int>}
        self.string_map = {}

    async def update_string_map(self, key, common_name):
        """Count one more occurrence of ``key`` and return the whole map.

        ``common_name`` is stored only when the key is seen for the first
        time; later calls leave the stored name untouched.
        """
        if key not in self.string_map:
            self.string_map[key] = {'common_name': common_name, 'count': 1}
        else:
            self.string_map[key]['count'] += 1
        return self.string_map

    def get_common_name(self, key):
        """Return the name stored for ``key``, or ``None`` for an unknown key."""
        if key not in self.string_map:
            return None
        return self.string_map[key]['common_name']
|
||
|
||
|
||
class VotingSystem:
    """Collects votes from distinct users for commands that need agreement.

    A vote only counts for ``VOTE_TIMEOUT_MINUTES`` after it was cast; a
    command passes when ``REQUIRED_VOTES`` live votes are present.
    """

    REQUIRED_VOTES = 3
    VOTE_TIMEOUT_MINUTES = 15

    def __init__(self):
        # command -> {user_uuid: time the vote was (last) cast}
        self.votes = {}

    def add_vote(self, command, user_uuid):
        """Register a vote and return ``(vote_count, newly_passed)``.

        ``newly_passed`` is true only for the call in which a user who had no
        live vote brings the total to exactly ``REQUIRED_VOTES``. A repeated
        vote refreshes that user's timestamp without changing the count.
        """
        now = datetime.now()
        self.votes.setdefault(command, {})

        # Drop stale votes first so they neither count nor mark a user as
        # having already voted.
        self._cleanup_expired_votes(command, now)

        ballots = self.votes[command]
        is_repeat = user_uuid in ballots
        ballots[user_uuid] = now

        total = len(ballots)
        newly_passed = (not is_repeat) and total == self.REQUIRED_VOTES
        return total, newly_passed

    def _cleanup_expired_votes(self, command, current_time):
        """Discard votes for ``command`` that are VOTE_TIMEOUT_MINUTES old or older."""
        if command in self.votes:
            max_age = timedelta(minutes=self.VOTE_TIMEOUT_MINUTES)
            still_valid = {}
            for voter, cast_at in self.votes[command].items():
                if current_time - cast_at < max_age:
                    still_valid[voter] = cast_at
            self.votes[command] = still_valid

    def reset_votes(self, command):
        """Forget every vote for ``command`` (unknown commands are ignored)."""
        if command not in self.votes:
            return
        self.votes[command] = {}

    def get_vote_count(self, command):
        """Return the number of non-expired votes for ``command``."""
        self._cleanup_expired_votes(command, datetime.now())
        ballots = self.votes.get(command, {})
        return len(ballots)
|
||
|
||
|
||
# Global voting system instance.
# Module-level so the collected votes persist between incoming messages.
voting_system = VotingSystem()
|
||
|
||
def download_image(image_url, timeout=30):
    """Download ``image_url`` and return its content as a base64 string.

    The bytes are encoded straight from the HTTP response. No temporary file
    is written, so the call no longer depends on the URL's last path segment
    being a usable file name, leaves nothing behind on failure, and cannot
    collide with a concurrent download of a same-named image.

    Args:
        image_url: Absolute URL of the image.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot block the caller indefinitely.

    Returns:
        The response body, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    image_response = requests.get(image_url, timeout=timeout)
    image_response.raise_for_status()  # Ensure the request was successful
    return base64.b64encode(image_response.content).decode('utf-8')
|
||
|
||
def fetch_and_download_image(api_url, image_url_key, timeout=30):
    """Ask a JSON API for an image URL, then download that image as base64.

    Args:
        api_url: Endpoint that returns a JSON document.
        image_url_key: Where the image URL lives in that document: either a
            single key/index, or a list/tuple of keys and indices that is
            followed step by step (e.g. ``[0, 'url']``).
        timeout: Seconds to wait for the API before giving up.

    Returns:
        Whatever ``download_image`` returns for the resolved URL.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.RequestException: On connection problems or timeout.
        KeyError, IndexError, TypeError: If the key path does not match the
            shape of the JSON document.
    """
    response = requests.get(api_url, timeout=timeout)
    response.raise_for_status()  # Ensure the request was successful

    # A bare key is treated as a one-step path so both forms share one loop.
    if isinstance(image_url_key, (list, tuple)):
        path = image_url_key
    else:
        path = [image_url_key]

    image_url = response.json()
    for key in path:
        image_url = image_url[key]

    return download_image(image_url)
|
||
|
||
async def send_image(base64_attachments, recipients=GROUP_ID_SEND):
    """POST one base64-encoded attachment to ``recipients`` via SEND_URL.

    The outcome is only printed; nothing is returned or raised for a non-2xx
    answer. Status codes 200 and 201 are both treated as success.
    """
    payload = {
        "base64_attachments": [base64_attachments],
        "number": PHONE_NUMBER,
        "recipients": [recipients],
    }
    response = requests.post(SEND_URL, json=payload)
    if response.status_code not in (200, 201):
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return
    print("Request was successful.")
|
||
|
||
def send_message(message_content, recipients=PHONE_NUMBER):
    """POST a text message to ``recipients`` via SEND_URL.

    ``message_content`` is converted with ``str()``, so callers may pass any
    object (the bot passes dicts as well as strings). The outcome is only
    printed; nothing is returned or raised for a non-success answer.

    Both 200 and 201 count as success, matching ``send_image``, which posts
    to the same endpoint. Previously a 201 answer was logged as a failure.
    """
    data = {
        "message": str(message_content),
        "number": PHONE_NUMBER,
        "recipients": [recipients]
    }
    response = requests.post(SEND_URL, json=data)
    if response.status_code in (200, 201):
        print("Request was successful.")
    else:
        print(f"Request failed with status code: {response.status_code} {data}")
        print(response.text)
|
||
|
||
|
||
def message_message(inside_message):
    """Return the ``'message'`` entry of a message payload, or ``None`` if absent."""
    return inside_message.get('message')
|
||
|
||
def message_group_id(inside_message):
    """Return ``groupInfo.groupId`` of a message payload.

    An empty dict is returned when either level is missing.
    """
    group_info = inside_message.get('groupInfo', {})
    return group_info.get('groupId', {})
|
||
|
||
def extract_message_content(message):
    """Parse a raw JSON message and return its message payload.

    ``envelope.dataMessage`` is preferred. When it is missing or an empty
    dict, ``envelope.syncMessage.sentMessage`` is returned instead (an empty
    dict if that is missing too).
    """
    envelope = json.loads(message).get('envelope', {})
    payload = envelope.get('dataMessage', {})
    if payload != {}:
        return payload
    return envelope.get('syncMessage', {}).get('sentMessage', {})
|
||
|
||
|
||
def extract_source_uuid(message):
    """Return the ``'sourceUuid'`` entry of an already-parsed dict.

    An empty dict is returned when the entry is absent.
    """
    return message.get('sourceUuid', {})
|
||
|
||
# Trigger spellings for the image commands. Each tuple is one dict key below;
# a message has to equal one of the entries exactly to trigger the command.
_CAT_TRIGGERS = (
    "!kot", "!koty", "!kots", "!cat", "!cats", "!meow", "!miau",
    "!ᴋᴏᴛ", "!𝓴𝓸𝓽", "!𝗸𝗼𝘁",
)
_DOG_TRIGGERS = (
    "!pies", "!psy", "!dog", "!dogs", "!woof", "!szczek",
    "!𝗽𝗶𝗲𝘀", "!͓̽p͓̽i͓̽e͓̽s͓̽",
)


def _send_random_cat(recipient):
    """Fetch a random cat picture and return the coroutine that sends it."""
    picture = fetch_and_download_image("https://api.thecatapi.com/v1/images/search", [0, 'url'])
    return send_image(picture, recipient)


def _send_random_dog(recipient):
    """Fetch a random dog picture and return the coroutine that sends it."""
    picture = fetch_and_download_image("https://dog.ceo/api/breeds/image/random", 'message')
    return send_image(picture, recipient)


# Maps a tuple of trigger words to a callable taking the recipient. The
# callable downloads the picture synchronously and returns an awaitable.
command_map = {
    _CAT_TRIGGERS: _send_random_cat,
    _DOG_TRIGGERS: _send_random_dog,
}

# Commands that require voting (3+ votes within 15 minutes)
VOTING_COMMANDS = ("!traps", "!trap")
|
||
|
||
|
||
async def send_trap_image(recipient):
    """Send a random "trap"-tagged image from rule34 to ``recipient``.

    A text message is sent instead when no post (or no sample URL) comes
    back, or when anything in the fetch/download/send chain raises.
    """
    try:
        post = r34_client.random_post(["trap"])
        if not post or not post.sample:
            send_message("Nie znaleziono obrazka.", recipient)
            return
        encoded = download_image(post.sample)
        await send_image(encoded, recipient)
    except Exception as e:
        # Any failure is reported to the chat rather than propagated.
        send_message(f"Błąd podczas pobierania obrazka: {str(e)}", recipient)
|
||
|
||
def extract_source_name(message):
    """Return the ``'sourceName'`` entry of an already-parsed dict.

    An empty dict is returned when the entry is absent.
    """
    return message.get('sourceName', {})
|
||
|
||
|
||
# NOTE(review): nothing in the visible source reads or writes this dict;
# message counts are kept in a StringCounter instead. Confirm before removing.
USER_MESSAGE_COUNT = {}
|
||
|
||
|
||
async def count_messages(message_content, counter):
    """Count one message for its sender and report the running tally.

    Args:
        message_content: Parsed envelope dict of the incoming message.
        counter: Object with an awaitable ``update_string_map(key, name)``
            and a ``string_map`` attribute.

    Nothing happens when the envelope is empty, when ``should_count``
    rejects it, or when it carries no sender UUID. The UUID check matters
    because ``extract_source_uuid`` falls back to an empty dict, which is
    unhashable and would raise ``TypeError`` when used as the counter key.
    """
    if not message_content or not should_count(message_content):
        return

    uuid = extract_source_uuid(message_content)
    if not uuid:
        return

    source_name = extract_source_name(message_content)
    await counter.update_string_map(uuid, source_name)
    # The current tally is sent to the bot's own number after every update.
    send_message(counter.string_map, PHONE_NUMBER)
|
||
|
||
|
||
async def scheduled_task(counter):
    """Forever: sleep until the next 21:37, post the tally, then clear it.

    The tally (``counter.string_map``) is sent to GROUP_ID_SEND and replaced
    with a fresh empty dict after each report.
    """
    report_at = time(21, 37)
    while True:
        now = datetime.now()
        next_run = datetime.combine(now.date(), report_at)
        if next_run < now:
            # Today's slot has already passed; aim for tomorrow's.
            next_run = next_run + timedelta(days=1)
        seconds_left = (next_run - now).total_seconds()
        await asyncio.sleep(seconds_left)

        send_message(counter.string_map, GROUP_ID_SEND)
        counter.string_map = {}
|
||
|
||
async def trigger_command(message_content, recipient, user_uuid=None):
    """Run the bot command contained in a message, if there is one.

    Args:
        message_content: Message payload dict; its text is read through
            ``message_message``.
        recipient: Where replies and images are sent.
        user_uuid: Sender id; required for voting commands to register.

    Behaviour:
        * Text that is missing, empty, or does not start with ``"!"`` is
          ignored. The prefix test uses ``startswith`` because indexing an
          empty string raised ``IndexError``, which the generic handler below
          then reported to the chat as an error.
        * Voting commands go to ``handle_voting_command`` and skip the
          cooldown here. Without a ``user_uuid`` they are dropped silently.
        * Other commands are limited to one per 10 seconds; during the
          cooldown a single warning is sent and further attempts are ignored.
        * Errors raised while handling are reported to ``recipient`` as a
          text message instead of propagating.
    """
    global last_command_time, warning_sent
    message_value = message_message(message_content)

    try:
        if not message_value or not message_value.startswith("!"):
            return

        current_time = datetime.now()

        # Handle voting commands separately (no cooldown for voting)
        if message_value in VOTING_COMMANDS:
            if user_uuid:
                await handle_voting_command(message_value, user_uuid, recipient)
            return

        if last_command_time and current_time - last_command_time < timedelta(seconds=10):
            if not warning_sent:
                send_message("BEEP BOOP POCZEKAJ 10 SEKUND.", recipient)
                warning_sent = True
            return

        for command_triggers, command_function in command_map.items():
            if message_value in command_triggers:
                await command_function(recipient)
                # Start the cooldown only after a command actually ran.
                last_command_time = current_time
                warning_sent = False
                break
    except TypeError:
        send_message(f"trigger_command, TypeError {message_content}", recipient)
    except Exception as e:
        send_message(f"trigger_command, unknown error {message_content}: {str(e)}", recipient)
|
||
|
||
|
||
async def handle_voting_command(command, user_uuid, recipient):
    """Register a vote for the trap image and send it once enough agree.

    Args:
        command: The trigger that was typed. Every voting trigger shares one
            ballot, so the value itself is not used.
        user_uuid: Id of the voting user; one live vote per user.
        recipient: Where status messages and the image are sent.

    Flow:
        * Below the threshold: acknowledge the vote and say how many more
          are needed.
        * Threshold reached during the 10-second command cooldown: warn once
          and keep the votes, so the next vote after the cooldown sends the
          image. Previously the passing vote was consumed here: the image
          was never sent and the next vote produced an "already sent" notice
          and reset the ballot.
        * Threshold reached otherwise: announce, send the image, reset the
          ballot and start the cooldown.
    """
    global last_command_time, warning_sent
    # Use a normalized command key for voting (both !trap and !traps map to same vote)
    vote_key = "traps"
    vote_count, _ = voting_system.add_vote(vote_key, user_uuid)
    required = VotingSystem.REQUIRED_VOTES
    timeout = VotingSystem.VOTE_TIMEOUT_MINUTES

    if vote_count < required:
        remaining_votes = required - vote_count
        send_message(
            f"Głos zapisany! ({vote_count}/{required}) "
            f"Potrzeba jeszcze {remaining_votes} głos(ów) w ciągu {timeout} minut.",
            recipient
        )
        return

    # Check cooldown only when threshold is reached
    current_time = datetime.now()
    if last_command_time and current_time - last_command_time < timedelta(seconds=10):
        if not warning_sent:
            send_message("BEEP BOOP POCZEKAJ 10 SEKUND.", recipient)
            warning_sent = True
        # Votes are intentionally left in place; see the docstring.
        return

    send_message(f"Głosowanie zakończone! ({vote_count}/{required}) Wysyłam obrazek...", recipient)
    await send_trap_image(recipient)
    voting_system.reset_votes(vote_key)
    last_command_time = current_time
    warning_sent = False
|
||
|
||
async def send_to_group(message_content, counter, message):
    """Process a message only if it belongs to the watched group.

    For a message whose group id equals GROUP_ID, the sender's message is
    counted and then any command in it is executed, with replies going to
    GROUP_ID_SEND. ``message`` is the raw JSON text; it is parsed again here
    to reach the envelope.
    """
    if message_group_id(message_content) != GROUP_ID:
        return

    envelope = json.loads(message).get('envelope', {})
    await count_messages(envelope, counter)
    sender = extract_source_uuid(envelope)
    await trigger_command(message_content, GROUP_ID_SEND, sender)
|
||
|
||
async def remove_attachment(attachment_id):
    """Send a DELETE for one attachment id and print the outcome.

    Status codes 200 and 204 are treated as success. Nothing is returned or
    raised for other answers.
    """
    response = requests.delete(REMOVE_ATTACHMENT_URL + attachment_id)
    if response.status_code not in (200, 204):
        print(f"Request remove_attachment failed with status code: {response.status_code}")
        print(response)
        return
    print("Request remove_attachment was successful.")
|
||
|
||
async def get_attachments():
    """List all attachments and delete each of them.

    Despite the name, every id returned by the listing request is passed to
    ``remove_attachment``. A listing answer other than 200 is only printed.
    """
    response = requests.get(REMOVE_ATTACHMENT_URL)
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response)
        return

    attachments = json.loads(response.content)
    print("attachments: ", attachments)
    for attachment_id in attachments:
        print("attachment: ", attachment_id)
        await remove_attachment(attachment_id)
|
||
|
||
def is_message_reaction(message):
    """Tell whether a raw JSON message is an emoji reaction.

    Args:
        message: JSON text of one received message.

    Returns:
        ``True`` when a non-empty ``reaction`` object is found under
        ``envelope.dataMessage``, ``envelope.syncMessage`` or
        ``envelope.syncMessage.sentMessage``; ``False`` otherwise.

    The ``sentMessage`` location is checked in addition to the original two
    because ``extract_message_content`` reads sync payloads from there, so a
    reaction sent from a linked device would otherwise be treated as a
    normal message. A ``null`` container or ``null`` reaction counts as "no
    reaction" rather than raising or returning ``True``.
    """
    envelope = json.loads(message).get('envelope') or {}
    data_message = envelope.get('dataMessage') or {}
    sync_message = envelope.get('syncMessage') or {}
    sent_message = sync_message.get('sentMessage') or {}

    candidates = (
        data_message.get('reaction'),
        sync_message.get('reaction'),
        sent_message.get('reaction'),
    )
    return any(candidates)
|
||
|
||
def should_count(message_content):
    """Decide whether a message should be added to the sender's tally.

    Returns ``False`` when ``dataMessage.sticker`` is anything other than an
    empty dict (i.e. the message carries a sticker) and ``True`` otherwise.
    Each step is traced with ``print``.
    """
    print("should_count triggered")
    # A destinationNumber != PHONE_NUMBER check used to sit here; it is disabled.
    data_message = message_content.get("dataMessage", {})
    sticker = data_message.get("sticker", {})
    print("sticker ", sticker)
    if sticker == {}:
        print("counting message: ", message_content)
        return True
    print("not counting because message has a sticker ", message_content)
    return False
|
||
|
||
async def listen_to_server(counter, reconnect_delay=5):
    """Receive messages over the websocket and dispatch them, reconnecting forever.

    Args:
        counter: Tally object handed through to ``send_to_group``.
        reconnect_delay: Seconds to wait before the next connection attempt.

    Previously the function returned as soon as the connection closed, so
    the bot stopped reacting to messages until the process was restarted.
    Now a failed connection attempt or a closed connection is printed and
    followed by a new attempt after ``reconnect_delay`` seconds.

    Exceptions raised while handling a message still propagate to the
    caller, as before. Connecting is done outside the processing ``try`` so
    that such errors are not mistaken for connection failures.
    """
    uri = f"ws://localhost:9922/v1/receive/{PHONE_NUMBER}?send_read_receipts=false"
    while True:
        try:
            websocket = await websockets.connect(uri)
        except OSError as e:
            print(f"Could not connect to signal server: {e}")
        else:
            print(f"Connected to signal server")
            try:
                async for message in websocket:
                    # Reactions are neither counted nor parsed for commands.
                    if not is_message_reaction(message):
                        print("message: ", message)
                        message_content = extract_message_content(message)
                        await send_to_group(message_content, counter, message)
            except websockets.ConnectionClosed as e:
                print(f"Connection closed: {e}")
            finally:
                await websocket.close()
        await asyncio.sleep(reconnect_delay)
|
||
|
||
# Endpoint to start the asyncio server tasks
|
||
@app.on_event("startup")
async def start_tasks():
    """Start the listener and the daily report as tasks sharing one counter.

    NOTE(review): the handler awaits both tasks, and ``scheduled_task`` loops
    forever, so this startup hook does not return while the bot is running.
    Confirm that this is intended before adding HTTP routes to the app.
    """
    counter = StringCounter()
    background = [
        asyncio.create_task(job(counter))
        for job in (listen_to_server, scheduled_task)
    ]
    await asyncio.gather(*background)
|