signal-bot/main.go
copilot-swe-agent[bot] 096a5cb4df Address code review feedback and fix security issues
Co-authored-by: kuhyx <147418882+kuhyx@users.noreply.github.com>
2025-12-01 15:19:29 +00:00

578 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Service configuration. init fills in everything except sendURL: most
// values come from environment variables, the endpoint URLs are built locally.
var (
	phoneNumber string // PHONE_NUMBER: the account number used for receiving and sending
	groupID     string // GROUP_ID: the group whose messages are counted and handled
	groupIDSend string // GROUP_ID_SEND: recipient for command replies and daily summaries
	catAPI      string // CAT_API: presumably a TheCatAPI key (optional)

	// NOTE(review): receiveURL is not read by any code visible in this file.
	receiveURL          string
	removeAttachmentURL string
	sendURL             = "http://localhost:9922/v2/send"
)

// Command rate-limiting state; triggerCommand reads and writes these while
// holding commandMutex.
var (
	commandMutex    sync.Mutex
	lastCommandTime time.Time
	warningSent     bool
)
// StringCounter tracks per-user message counts keyed by user UUID.
// Its methods are safe for concurrent use; touching StringMap directly
// bypasses the lock and is not.
type StringCounter struct {
	StringMap map[string]UserCount
	mu        sync.RWMutex
}

// UserCount stores a user's display name and how many messages they sent.
type UserCount struct {
	CommonName string `json:"common_name"`
	Count      int    `json:"count"`
}

// NewStringCounter returns an empty, ready-to-use StringCounter.
func NewStringCounter() *StringCounter {
	return &StringCounter{StringMap: make(map[string]UserCount)}
}

// UpdateStringMap increments the count for key. The commonName is recorded
// only the first time key is seen; later names are ignored.
//
// It returns a snapshot copy of all counts, so callers can inspect the
// result without holding the lock and cannot corrupt the counter by
// modifying it.
func (sc *StringCounter) UpdateStringMap(key, commonName string) map[string]UserCount {
	sc.mu.Lock()
	defer sc.mu.Unlock()
	// Lazily allocate so a zero-value StringCounter does not panic on write.
	if sc.StringMap == nil {
		sc.StringMap = make(map[string]UserCount)
	}
	uc, exists := sc.StringMap[key]
	if !exists {
		uc.CommonName = commonName
	}
	uc.Count++
	sc.StringMap[key] = uc

	snapshot := make(map[string]UserCount, len(sc.StringMap))
	for k, v := range sc.StringMap {
		snapshot[k] = v
	}
	return snapshot
}

// GetCommonName returns the name recorded for key, or "" if key is unknown.
func (sc *StringCounter) GetCommonName(key string) string {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	if uc, exists := sc.StringMap[key]; exists {
		return uc.CommonName
	}
	return ""
}

// Reset discards all counts by replacing the map with a fresh one.
func (sc *StringCounter) Reset() {
	sc.mu.Lock()
	defer sc.mu.Unlock()
	sc.StringMap = make(map[string]UserCount)
}

// GetStringMapJSON returns the counts as a JSON object keyed by UUID
// (keys sorted by encoding/json). It returns "{}" if the map is empty,
// nil, or cannot be marshalled.
func (sc *StringCounter) GetStringMapJSON() string {
	sc.mu.RLock()
	defer sc.mu.RUnlock()
	if sc.StringMap == nil {
		return "{}"
	}
	data, err := json.Marshal(sc.StringMap)
	if err != nil {
		return "{}"
	}
	return string(data)
}
// Wire types: the JSON events decoded from the receive websocket, and the
// request body posted to the send endpoint.
//
// NOTE(review): encoding/json ignores `omitempty` on struct-valued fields;
// the tags are kept unchanged so the types stay byte-compatible.
type (
	// SignalMessage is the top-level object of one received event.
	SignalMessage struct {
		Envelope Envelope `json:"envelope"`
		Account  string   `json:"account"`
	}

	// Envelope identifies the sender and carries either a data message or
	// a sync message (absent parts decode to zero values).
	Envelope struct {
		Source       string      `json:"source"`
		SourceNumber string      `json:"sourceNumber"`
		SourceUuid   string      `json:"sourceUuid"`
		SourceName   string      `json:"sourceName"`
		SourceDevice int         `json:"sourceDevice"`
		Timestamp    int64       `json:"timestamp"`
		DataMessage  DataMessage `json:"dataMessage,omitempty"`
		SyncMessage  SyncMessage `json:"syncMessage,omitempty"`
	}

	// DataMessage is the content of a message received from someone.
	DataMessage struct {
		Timestamp        int64     `json:"timestamp"`
		Message          string    `json:"message"`
		ExpiresInSeconds int       `json:"expiresInSeconds"`
		ViewOnce         bool      `json:"viewOnce"`
		GroupInfo        GroupInfo `json:"groupInfo,omitempty"`
		Sticker          Sticker   `json:"sticker,omitempty"`
		Reaction         Reaction  `json:"reaction,omitempty"`
	}

	// SyncMessage wraps events synchronised from another device.
	SyncMessage struct {
		SentMessage SentMessage `json:"sentMessage,omitempty"`
		Reaction    Reaction    `json:"reaction,omitempty"`
	}

	// SentMessage is a message sent by this account, as reported by sync.
	SentMessage struct {
		Destination       string    `json:"destination"`
		DestinationNumber string    `json:"destinationNumber"`
		DestinationUuid   string    `json:"destinationUuid"`
		Timestamp         int64     `json:"timestamp"`
		Message           string    `json:"message"`
		ExpiresInSeconds  int       `json:"expiresInSeconds"`
		ViewOnce          bool      `json:"viewOnce"`
		Sticker           Sticker   `json:"sticker,omitempty"`
		GroupInfo         GroupInfo `json:"groupInfo,omitempty"`
	}

	// GroupInfo names the group a message belongs to ("" when not a group message).
	GroupInfo struct {
		GroupID string `json:"groupId"`
		Type    string `json:"type"`
	}

	// Sticker identifies a sticker; PackID is non-empty when one is present.
	Sticker struct {
		PackID    string `json:"packId,omitempty"`
		StickerID int    `json:"stickerId,omitempty"`
	}

	// Reaction is an emoji reaction; Emoji is non-empty when one is present.
	Reaction struct {
		Emoji           string `json:"emoji,omitempty"`
		TargetAuthor    string `json:"targetAuthor,omitempty"`
		TargetTimestamp int64  `json:"targetTimestamp,omitempty"`
	}

	// SendMessageRequest is the JSON body posted to the send endpoint:
	// either text (Message) or attachments, from Number to Recipients.
	SendMessageRequest struct {
		Message           string   `json:"message,omitempty"`
		Base64Attachments []string `json:"base64_attachments,omitempty"`
		Number            string   `json:"number"`
		Recipients        []string `json:"recipients"`
	}
)
// Command triggers: Polish and English spellings of the cat (catCommands)
// and dog (dogCommands) picture commands.
//
// Besides plain ASCII, the lists contain Unicode look-alike spellings (small
// capitals, styled mathematical-alphanumeric letters, and letters decorated
// with combining marks) so that text produced by "fancy font" keyboards or
// generators is also recognised, matching the command variations of the
// original Python implementation.
//
// NOTE(review): entries are compared with plain string equality, so each one
// must stay byte-identical to what clients send. Editors and linters may flag
// these characters as ambiguous/confusable; that is intentional, do not
// "normalise" them.
var catCommands = []string{"!kot", "!koty", "!kots", "!cat", "!cats", "!meow", "!miau", "!ᴋᴏᴛ", "!𝓴𝓸𝓽", "!𝗸𝗼𝘁"}
var dogCommands = []string{"!pies", "!psy", "!dog", "!dogs", "!woof", "!szczek", "!𝗽𝗶𝗲𝘀", "!͓̽p͓̽i͓̽e͓̽s͓̽"}
// init resolves runtime configuration from the environment before main runs.
// The receive URL is derived from the phone number; the attachments URL is fixed.
func init() {
	groupID = getEnv("GROUP_ID", "")
	groupIDSend = getEnv("GROUP_ID_SEND", "")
	catAPI = getEnv("CAT_API", "")

	phoneNumber = getEnv("PHONE_NUMBER", "1234567890")
	receiveURL = "http://localhost:9922/v1/receive/" + phoneNumber
	removeAttachmentURL = "http://localhost:9922/v1/attachments/"
}
// getEnv returns the value of environment variable key, falling back to
// defaultValue when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
// downloadImage fetches imageURL over HTTP and returns the response body
// encoded as standard (padded) base64. Any status other than 200 OK is an
// error.
//
// The request is bounded by a timeout and the body by a size cap. The URL
// comes from a third-party API response, and the command path calls this
// while holding commandMutex, so a slow or huge response must not stall the
// bot or exhaust memory.
func downloadImage(imageURL string) (string, error) {
	const (
		downloadTimeout = 30 * time.Second
		maxImageBytes   = 20 << 20 // 20 MiB
	)
	client := &http.Client{Timeout: downloadTimeout}
	resp, err := client.Get(imageURL)
	if err != nil {
		return "", fmt.Errorf("failed to download image: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("image download failed with status: %d", resp.StatusCode)
	}
	// Read one byte past the cap so an oversized body is reported instead of
	// being silently truncated into a corrupt image.
	data, err := io.ReadAll(io.LimitReader(resp.Body, maxImageBytes+1))
	if err != nil {
		return "", fmt.Errorf("failed to read image data: %w", err)
	}
	if len(data) > maxImageBytes {
		return "", fmt.Errorf("image exceeds %d bytes", maxImageBytes)
	}
	return base64.StdEncoding.EncodeToString(data), nil
}
// fetchCatImage fetches a random cat image from TheCatAPI
func fetchCatImage() (string, error) {
resp, err := http.Get("https://api.thecatapi.com/v1/images/search")
if err != nil {
return "", fmt.Errorf("failed to fetch cat API: %w", err)
}
defer resp.Body.Close()
var result []map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode cat API response: %w", err)
}
if len(result) == 0 {
return "", fmt.Errorf("no cat images found")
}
imageURL, ok := result[0]["url"].(string)
if !ok {
return "", fmt.Errorf("invalid cat image URL")
}
return downloadImage(imageURL)
}
// fetchDogImage asks the Dog CEO API for one random image URL and returns
// that image base64-encoded (via downloadImage). The lookup has a timeout and
// must return 200 OK.
func fetchDogImage() (string, error) {
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Get("https://dog.ceo/api/breeds/image/random")
	if err != nil {
		return "", fmt.Errorf("failed to fetch dog API: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("dog API request failed with status: %d", resp.StatusCode)
	}
	// Presumably shaped like {"message": "<image url>", "status": "success"};
	// on failure "message" holds an error text instead of a URL.
	var result struct {
		Message string `json:"message"`
		Status  string `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("failed to decode dog API response: %w", err)
	}
	if result.Status != "" && result.Status != "success" {
		return "", fmt.Errorf("dog API returned status %q: %s", result.Status, result.Message)
	}
	if result.Message == "" {
		return "", fmt.Errorf("invalid dog image URL")
	}
	return downloadImage(result.Message)
}
// sendImage sends a base64-encoded image attachment from phoneNumber to
// recipient through the REST send endpoint (sendURL).
func sendImage(base64Data, recipient string) error {
	err := postSendRequest(SendMessageRequest{
		Base64Attachments: []string{base64Data},
		Number:            phoneNumber,
		Recipients:        []string{recipient},
	}, "image")
	if err != nil {
		return err
	}
	log.Println("Image sent successfully.")
	return nil
}

// sendMessage sends a plain-text message from phoneNumber to recipient
// through the REST send endpoint (sendURL).
func sendMessage(messageContent, recipient string) error {
	err := postSendRequest(SendMessageRequest{
		Message:    messageContent,
		Number:     phoneNumber,
		Recipients: []string{recipient},
	}, "message")
	if err != nil {
		return err
	}
	log.Println("Message sent successfully.")
	return nil
}

// postSendRequest marshals reqBody and POSTs it to sendURL with a timeout.
// Both 200 OK and 201 Created count as success for text and images alike
// (the send endpoint answers 201 on success, which text messages used to
// misreport as a failure). kind ("image" or "message") only labels errors.
func postSendRequest(reqBody SendMessageRequest, kind string) error {
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	// Generous timeout: attachments can be several MB and the backend may be slow.
	client := &http.Client{Timeout: 90 * time.Second}
	resp, err := client.Post(sendURL, "application/json", bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("failed to send %s: %w", kind, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
		// Drain so the keep-alive connection can be reused.
		_, _ = io.Copy(io.Discard, resp.Body)
		return nil
	}
	// Best effort: the body only enriches the error message.
	body, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("send %s failed with status %d: %s", kind, resp.StatusCode, string(body))
}
// extractMessageContent picks the message payload out of an envelope.
//
// A direct dataMessage wins when it has a timestamp or text; a pointer into
// msg itself is returned in that case. Otherwise, if syncMessage.sentMessage
// has a timestamp (presumably a message this account sent from a linked
// device), its fields are copied into a freshly allocated DataMessage.
// Returns nil when neither is present.
func extractMessageContent(msg *SignalMessage) *DataMessage {
	env := &msg.Envelope
	if direct := &env.DataMessage; direct.Timestamp != 0 || direct.Message != "" {
		return direct
	}

	sent := env.SyncMessage.SentMessage
	if sent.Timestamp == 0 {
		return nil
	}
	// Reshape the sync copy as a DataMessage; Reaction is left zero.
	return &DataMessage{
		Timestamp:        sent.Timestamp,
		Message:          sent.Message,
		ExpiresInSeconds: sent.ExpiresInSeconds,
		ViewOnce:         sent.ViewOnce,
		Sticker:          sent.Sticker,
		GroupInfo:        sent.GroupInfo,
	}
}
// isMessageReaction reports whether the envelope carries an emoji reaction,
// either on its data message or on its sync message.
func isMessageReaction(msg *SignalMessage) bool {
	env := &msg.Envelope
	return env.DataMessage.Reaction.Emoji != "" || env.SyncMessage.Reaction.Emoji != ""
}
// shouldCount reports whether msg should increment its sender's counter.
// Messages carrying a sticker (on the data message or the synced sent
// message) are excluded. The decision and the full message are logged.
func shouldCount(msg *SignalMessage) bool {
	log.Println("shouldCount triggered")

	hasSticker := msg.Envelope.DataMessage.Sticker.PackID != "" ||
		msg.Envelope.SyncMessage.SentMessage.Sticker.PackID != ""
	if hasSticker {
		log.Printf("not counting because message has a sticker: %+v\n", msg)
		return false
	}

	log.Printf("counting message: %+v\n", msg)
	return true
}
// containsString reports whether str equals any element of slice, using an
// exact (case-sensitive, byte-wise) comparison. A nil or empty slice
// contains nothing.
func containsString(slice []string, str string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == str {
			return true
		}
	}
	return false
}
// triggerCommand answers a recognised "!" command in messageContent by
// sending a cat or dog picture to recipient. Unknown commands and
// non-command text are ignored without touching the rate-limit state.
//
// Replies are rate limited: within 10 seconds of the last successful reply,
// the first command gets a single "wait" warning and later ones are dropped,
// until a command succeeds again. Fetch or send failures are logged and
// reported to recipient as text; they do not start a new rate-limit window.
//
// The rate-limited section runs under commandMutex, so concurrent callers
// are serialised, including their network I/O.
func triggerCommand(messageContent *DataMessage, recipient string) {
	if messageContent == nil {
		return
	}
	// Tolerate stray surrounding whitespace such as a trailing space/newline.
	message := strings.TrimSpace(messageContent.Message)
	if !strings.HasPrefix(message, "!") {
		return
	}

	// Resolve the command before rate limiting, so "!something-else" inside
	// the window does not provoke the warning.
	var fetch func() (string, error)
	var kind string
	switch {
	case containsString(catCommands, message):
		fetch, kind = fetchCatImage, "cat"
	case containsString(dogCommands, message):
		fetch, kind = fetchDogImage, "dog"
	default:
		return // Unknown command
	}

	commandMutex.Lock()
	defer commandMutex.Unlock()

	currentTime := time.Now()
	if !lastCommandTime.IsZero() && currentTime.Sub(lastCommandTime) < 10*time.Second {
		if !warningSent {
			if err := sendMessage("BEEP BOOP POCZEKAJ 10 SEKUND.", recipient); err != nil {
				log.Printf("Failed to send warning: %v\n", err)
			}
			warningSent = true
		}
		return
	}

	base64Data, err := fetch()
	if err != nil {
		log.Printf("Failed to fetch %s image: %v\n", kind, err)
		reportCommandError(err, recipient)
		return
	}
	if err := sendImage(base64Data, recipient); err != nil {
		log.Printf("Failed to send image: %v\n", err)
		reportCommandError(err, recipient)
		return
	}
	lastCommandTime = currentTime
	warningSent = false
}

// reportCommandError relays a command failure to recipient as a text
// message, logging if that send fails as well.
func reportCommandError(err error, recipient string) {
	if sendErr := sendMessage(fmt.Sprintf("trigger_command, error: %v", err), recipient); sendErr != nil {
		log.Printf("Failed to send error message: %v\n", sendErr)
	}
}
// countMessages increments the sender's count (keyed by source UUID, named
// by source name) when shouldCount accepts msg. It then sends the whole
// counter as JSON to the bot's own number (phoneNumber), logging send errors.
func countMessages(msg *SignalMessage, counter *StringCounter) {
	if !shouldCount(msg) {
		return
	}
	env := &msg.Envelope
	counter.UpdateStringMap(env.SourceUuid, env.SourceName)

	snapshot := counter.GetStringMapJSON()
	if err := sendMessage(snapshot, phoneNumber); err != nil {
		log.Printf("Failed to send count message: %v\n", err)
	}
}
// sendToGroup handles a message that belongs to the configured group
// (groupID). It updates the message counts and runs any chat command, with
// replies going to groupIDSend. Messages from anywhere else are ignored.
//
// An empty groupID (GROUP_ID unset) matches nothing. Without that guard,
// every non-group message, whose GroupInfo.GroupID decodes as "", would be
// treated as a message from the configured group.
func sendToGroup(msg *SignalMessage, messageContent *DataMessage, counter *StringCounter) {
	if messageContent == nil || groupID == "" || messageContent.GroupInfo.GroupID != groupID {
		return
	}
	countMessages(msg, counter)
	triggerCommand(messageContent, groupIDSend)
}
// removeAttachment deletes a single stored attachment by sending
// DELETE removeAttachmentURL+attachmentID. 200 OK and 204 No Content count
// as success; any other status is returned as an error.
//
// IDs that are empty, "." or "..", or that contain URL-structural characters
// are rejected up front. Appended verbatim, they would address a different
// path (e.g. the attachment collection itself) or alter the query string.
func removeAttachment(attachmentID string) error {
	if attachmentID == "" || attachmentID == "." || attachmentID == ".." ||
		strings.ContainsAny(attachmentID, "/?#%\\") {
		return fmt.Errorf("invalid attachment ID %q", attachmentID)
	}
	req, err := http.NewRequest(http.MethodDelete, removeAttachmentURL+attachmentID, nil)
	if err != nil {
		return fmt.Errorf("failed to create delete request: %w", err)
	}
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to delete attachment: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNoContent {
		log.Println("Attachment removed successfully.")
		return nil
	}
	return fmt.Errorf("remove attachment failed with status: %d", resp.StatusCode)
}
// getAttachments lists the stored attachment IDs (GET removeAttachmentURL,
// expecting a JSON array of strings) and deletes each one with
// removeAttachment. Individual delete failures are logged and skipped; only
// listing/decoding problems are returned.
//
// NOTE(review): not called from any code visible in this file.
func getAttachments() error {
	resp, err := http.Get(removeAttachmentURL)
	if err != nil {
		return fmt.Errorf("failed to get attachments: %w", err)
	}
	defer resp.Body.Close()
	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("get attachments failed with status: %d", status)
	}

	var ids []string
	if decodeErr := json.NewDecoder(resp.Body).Decode(&ids); decodeErr != nil {
		return fmt.Errorf("failed to decode attachments: %w", decodeErr)
	}
	log.Printf("attachments: %v\n", ids)

	for i := range ids {
		id := ids[i]
		log.Printf("attachment: %s\n", id)
		if rmErr := removeAttachment(id); rmErr != nil {
			log.Printf("Failed to remove attachment %s: %v\n", id, rmErr)
		}
	}
	return nil
}
// scheduledTask sends the collected message counts (as JSON) to groupIDSend
// every day at 21:37 local time. It then clears the counter, even if the
// send failed. It loops forever and never returns.
//
// Run times are built from the calendar date with time.Date rather than by
// adding 24h, so the run stays at 21:37 across DST transitions. Each new
// target is computed strictly after the previous one, so waking slightly
// early (e.g. after a wall-clock adjustment) cannot trigger a second run.
func scheduledTask(counter *StringCounter) {
	var lastRun time.Time
	for {
		now := time.Now()
		from := now
		if from.Before(lastRun) {
			from = lastRun
		}
		targetTime := nextDailyRun(from, 21, 37)
		waitDuration := targetTime.Sub(now)
		log.Printf("Scheduled task will run in %v at %v\n", waitDuration, targetTime)
		time.Sleep(waitDuration)
		if err := sendMessage(counter.GetStringMapJSON(), groupIDSend); err != nil {
			log.Printf("Failed to send scheduled message: %v\n", err)
		}
		counter.Reset()
		lastRun = targetTime
	}
}

// nextDailyRun returns the first hour:minute wall-clock instant in from's
// location that is strictly after from.
func nextDailyRun(from time.Time, hour, minute int) time.Time {
	loc := from.Location()
	next := time.Date(from.Year(), from.Month(), from.Day(), hour, minute, 0, 0, loc)
	if !next.After(from) {
		// time.Date normalises Day()+1 across month/year ends and keeps the
		// requested wall-clock time even when the day is 23h or 25h long.
		next = time.Date(from.Year(), from.Month(), from.Day()+1, hour, minute, 0, 0, loc)
	}
	return next
}
// listenToServer keeps a websocket connection to the local Signal REST API
// receive endpoint (with read receipts disabled) open forever. It reconnects
// 5 seconds after any dial or read failure and passes every frame to
// handleServerMessage.
func listenToServer(counter *StringCounter) {
	endpoint := "ws://localhost:9922/v1/receive/" + phoneNumber + "?send_read_receipts=false"
	for {
		log.Println("Attempting to connect to Signal server...")
		conn, _, dialErr := websocket.DefaultDialer.Dial(endpoint, nil)
		if dialErr != nil {
			log.Printf("Failed to connect to WebSocket: %v. Retrying in 5 seconds...\n", dialErr)
			time.Sleep(5 * time.Second)
			continue
		}
		log.Println("Connected to Signal server")

		for {
			_, frame, readErr := conn.ReadMessage()
			if readErr != nil {
				log.Printf("WebSocket read error: %v\n", readErr)
				break
			}
			handleServerMessage(frame, counter)
		}

		conn.Close()
		log.Println("Connection closed. Reconnecting in 5 seconds...")
		time.Sleep(5 * time.Second)
	}
}

// handleServerMessage decodes one websocket frame and dispatches it. Frames
// that fail to parse are logged and dropped and reactions are skipped
// silently. Everything else is logged raw and, if it carries message
// content, passed to sendToGroup.
func handleServerMessage(frame []byte, counter *StringCounter) {
	var msg SignalMessage
	if err := json.Unmarshal(frame, &msg); err != nil {
		log.Printf("Failed to parse message: %v\n", err)
		return
	}
	if isMessageReaction(&msg) {
		return
	}
	log.Printf("message: %s\n", string(frame))
	if content := extractMessageContent(&msg); content != nil {
		sendToGroup(&msg, content, counter)
	}
}
// main logs the effective configuration, starts the daily summary job in
// the background, and then blocks forever in the websocket listener.
func main() {
	log.Println("Starting Signal Bot (Go version)...")
	log.Printf("Phone Number: %s\n", phoneNumber)
	log.Printf("Group ID: %s\n", groupID)
	log.Printf("Group ID Send: %s\n", groupIDSend)

	stats := NewStringCounter()
	go scheduledTask(stats) // daily summary runs alongside the listener
	listenToServer(stats)   // never returns
}