Fix anomaly detectors

This commit is contained in:
kacperlo 2025-06-10 15:49:31 +02:00
parent 767de2e643
commit 06f79923bb

View File

@ -21,10 +21,16 @@ import org.apache.flink.streaming.api.windowing.assigners.SlidingProcessingTimeW
import org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector; import org.apache.flink.util.Collector;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import java.util.*; import java.util.*;
import java.time.Instant; import java.time.Instant;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable;
public class AnomalyDetector { public class AnomalyDetector {
@ -81,13 +87,13 @@ public class AnomalyDetector {
// 1. Amount anomaly - sudden high-value transactions // 1. Amount anomaly - sudden high-value transactions
DataStream<TransactionAlert> amountAlerts = transactionStream DataStream<TransactionAlert> amountAlerts = transactionStream
.keyBy(Transaction::getCardId) .keyBy(Transaction::getCardId)
.window(SlidingProcessingTimeWindows.of(Time.minutes(10), Time.minutes(1))) .window(SlidingProcessingTimeWindows.of(Time.minutes(5), Time.minutes(1)))
.process(new AmountAnomalyDetector()); .process(new AmountAnomalyDetector());
// 2. Location anomaly - sudden change in location // 2. Location anomaly - sudden change in location
DataStream<TransactionAlert> locationAlerts = transactionStream DataStream<TransactionAlert> locationAlerts = transactionStream
.keyBy(Transaction::getCardId) .keyBy(Transaction::getCardId)
.window(SlidingProcessingTimeWindows.of(Time.minutes(10), Time.minutes(1))) .window(SlidingProcessingTimeWindows.of(Time.minutes(5), Time.minutes(1)))
.process(new LocationAnomalyDetector()); .process(new LocationAnomalyDetector());
// 3. Frequency anomaly - unusual number of transactions in short time // 3. Frequency anomaly - unusual number of transactions in short time
@ -120,28 +126,28 @@ public class AnomalyDetector {
} }
// Detector for unusual transaction amounts // Detector for unusual transaction amounts
public static class AmountAnomalyDetector public static class AmountAnomalyDetector
extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> { extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> {
@Override @Override
public void process(String cardId, Context context, Iterable<Transaction> transactions, public void process(String cardId, Context context, Iterable<Transaction> transactions,
Collector<TransactionAlert> out) { Collector<TransactionAlert> out) {
List<Transaction> transactionList = new ArrayList<>(); List<Transaction> transactionList = new ArrayList<>();
transactions.forEach(transactionList::add); transactions.forEach(transactionList::add);
if (transactionList.isEmpty()) return; if (transactionList.isEmpty()) return;
// Calculate statistics // Calculate statistics
double averageAmount = transactionList.stream() double averageAmount = transactionList.stream()
.mapToDouble(Transaction::getAmount) .mapToDouble(Transaction::getAmount)
.average() .average()
.orElse(0); .orElse(0);
double stdDeviation = calculateStdDeviation(transactionList, averageAmount); double stdDeviation = calculateStdDeviation(transactionList, averageAmount);
// Check for anomalies (transactions that are more than 3 standard deviations from mean) // Check for anomalies (transactions that are more than 1.7 standard deviations from mean)
for (Transaction transaction : transactionList) { for (Transaction transaction : transactionList) {
if (stdDeviation > 0 && Math.abs(transaction.getAmount() - averageAmount) > 3 * stdDeviation) { if (stdDeviation > 0 && Math.abs(transaction.getAmount() - averageAmount) > 2 * stdDeviation && transaction.getAmount() > averageAmount && transaction.getAmount() > 1000) {
out.collect(new TransactionAlert( out.collect(new TransactionAlert(
"AMOUNT_ANOMALY", "AMOUNT_ANOMALY",
transaction.getCardId(), transaction.getCardId(),
@ -150,13 +156,13 @@ public class AnomalyDetector {
transaction.getLatitude(), transaction.getLatitude(),
transaction.getLongitude(), transaction.getLongitude(),
transaction.getTimestamp(), transaction.getTimestamp(),
"Unusual transaction amount detected: $" + transaction.getAmount() + "Unusual transaction amount detected: $" + transaction.getAmount() +
" (Average: $" + String.format("%.2f", averageAmount) + ")" " (Average: $" + String.format("%.2f", averageAmount) + ")"
)); ));
} }
} }
} }
private double calculateStdDeviation(List<Transaction> transactions, double mean) { private double calculateStdDeviation(List<Transaction> transactions, double mean) {
return Math.sqrt(transactions.stream() return Math.sqrt(transactions.stream()
.mapToDouble(t -> Math.pow(t.getAmount() - mean, 2)) .mapToDouble(t -> Math.pow(t.getAmount() - mean, 2))
@ -166,87 +172,153 @@ public class AnomalyDetector {
} }
// Detector for unusual transaction locations // Detector for unusual transaction locations
public static class LocationAnomalyDetector public static class LocationAnomalyDetector
extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> { extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> {
// Map to store frequent locations for each card private transient MapState<String, Set<LocationPoint>> knownLocations;
private final Map<String, Set<LocationPoint>> cardLocations = new HashMap<>(); private static final int MAX_KNOWN_LOCATIONS = 5; // Limit known locations to avoid memory issues
private static final double ANOMALY_DISTANCE_THRESHOLD = 50.0; // Threshold in km
private static final int MIN_LOCATIONS_FOR_DETECTION = 3; // Minimum known locations before detecting anomalies
@Override @Override
public void process(String cardId, Context context, Iterable<Transaction> transactions, public void open(Configuration parameters) throws Exception {
Collector<TransactionAlert> out) { MapStateDescriptor<String, Set<LocationPoint>> descriptor =
new MapStateDescriptor<>(
"knownLocations",
TypeInformation.of(String.class),
TypeInformation.of(new TypeHint<Set<LocationPoint>>() {})
);
knownLocations = getRuntimeContext().getMapState(descriptor);
}
@Override
public void process(String cardId, Context context, Iterable<Transaction> transactions,
Collector<TransactionAlert> out) throws Exception {
List<Transaction> transactionList = new ArrayList<>(); List<Transaction> transactionList = new ArrayList<>();
transactions.forEach(transactionList::add); transactions.forEach(transactionList::add);
if (transactionList.isEmpty()) return; if (transactionList.isEmpty()) return;
// Get or create location set for this card // Get or create location set for this card
Set<LocationPoint> frequentLocations = cardLocations.computeIfAbsent(cardId, k -> new HashSet<>()); Set<LocationPoint> cardKnownLocations;
if (knownLocations.contains(cardId)) {
cardKnownLocations = knownLocations.get(cardId);
System.out.println("Card " + cardId + " has " + cardKnownLocations.size() + " known locations");
} else {
cardKnownLocations = new HashSet<>();
System.out.println("New card detected: " + cardId + ", initializing known locations");
}
// Process each transaction // Process each transaction
for (Transaction transaction : transactionList) { for (Transaction transaction : transactionList) {
LocationPoint currentPoint = new LocationPoint(transaction.getLatitude(), transaction.getLongitude()); LocationPoint currentPoint = new LocationPoint(transaction.getLatitude(), transaction.getLongitude());
// If we have at least 3 frequent locations for this card // First few transactions establish the baseline locations
if (frequentLocations.size() >= 3) { if (cardKnownLocations.size() < MIN_LOCATIONS_FOR_DETECTION) {
boolean isNearKnownLocation = false; System.out.println("Building baseline for card " + cardId + ", adding location #" +
(cardKnownLocations.size() + 1) + " to known locations");
// Check if current location is near any known frequent location
for (LocationPoint knownPoint : frequentLocations) { // Check if this location is already very close to a known location before adding
if (calculateDistance(currentPoint, knownPoint) < 50) { // Less than 50km boolean isVeryCloseToKnown = false;
isNearKnownLocation = true; for (LocationPoint knownPoint : cardKnownLocations) {
if (calculateDistance(currentPoint, knownPoint) < 2.0) { // Within 2km = same area
isVeryCloseToKnown = true;
System.out.println("Location is very close to existing baseline location, not adding duplicate");
break; break;
} }
} }
// If not near any known location, it might be an anomaly // Only add distinct baseline locations
if (!isNearKnownLocation) { if (!isVeryCloseToKnown) {
out.collect(new TransactionAlert( cardKnownLocations.add(currentPoint);
"LOCATION_ANOMALY", }
transaction.getCardId(),
transaction.getUserId(), // We're still building the baseline, don't check for anomalies yet
transaction.getAmount(), continue;
transaction.getLatitude(), }
transaction.getLongitude(),
transaction.getTimestamp(), // Check distance to known locations
"Unusual transaction location detected at: " + double closestDistance = Double.MAX_VALUE;
transaction.getLatitude() + ", " + transaction.getLongitude() LocationPoint closestPoint = null;
));
for (LocationPoint knownPoint : cardKnownLocations) {
double distance = calculateDistance(currentPoint, knownPoint);
if (distance < closestDistance) {
closestDistance = distance;
closestPoint = knownPoint;
} }
} }
// Add current location to frequent locations (max 10 locations per card) System.out.println("CARD " + cardId + ": Transaction at " + currentPoint + ", closest known location: " +
if (frequentLocations.size() < 10) { closestPoint + " (" + String.format("%.2f", closestDistance) + " km)");
frequentLocations.add(currentPoint);
// Detect anomaly if transaction is far from all known locations
if (closestDistance > ANOMALY_DISTANCE_THRESHOLD) {
System.out.println("⚠️ LOCATION ANOMALY DETECTED: Distance " +
String.format("%.2f", closestDistance) + "km exceeds threshold of " +
ANOMALY_DISTANCE_THRESHOLD + "km");
out.collect(new TransactionAlert(
"LOCATION_ANOMALY",
transaction.getCardId(),
transaction.getUserId(),
transaction.getAmount(),
transaction.getLatitude(),
transaction.getLongitude(),
transaction.getTimestamp(),
"Unusual transaction location: " + String.format("%.2f", closestDistance) +
"km from nearest known location"
));
// Don't automatically add anomalous locations to known locations
} else {
// Check if this location is already very close to a known location
boolean isVeryCloseToKnown = false;
for (LocationPoint knownPoint : cardKnownLocations) {
if (calculateDistance(currentPoint, knownPoint) < 2.0) { // Within 2km = same area
isVeryCloseToKnown = true;
break;
}
}
// Only add distinct new locations, up to our maximum
if (!isVeryCloseToKnown && cardKnownLocations.size() < MAX_KNOWN_LOCATIONS) {
cardKnownLocations.add(currentPoint);
System.out.println("Added new location to known locations: " + currentPoint);
}
} }
} }
// Update the state
knownLocations.put(cardId, cardKnownLocations);
} }
// Calculate distance between two points using Haversine formula (in km) // Calculate distance between two points using Haversine formula (in km)
private double calculateDistance(LocationPoint p1, LocationPoint p2) { private double calculateDistance(LocationPoint p1, LocationPoint p2) {
final int R = 6371; // Earth radius in km final int R = 6371; // Earth radius in km
double latDistance = Math.toRadians(p2.latitude - p1.latitude); double latDistance = Math.toRadians(p2.latitude - p1.latitude);
double lonDistance = Math.toRadians(p2.longitude - p1.longitude); double lonDistance = Math.toRadians(p2.longitude - p1.longitude);
double a = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) double a = Math.sin(latDistance / 2) * Math.sin(latDistance / 2)
+ Math.cos(Math.toRadians(p1.latitude)) * Math.cos(Math.toRadians(p2.latitude)) + Math.cos(Math.toRadians(p1.latitude)) * Math.cos(Math.toRadians(p2.latitude))
* Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2); * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2);
double c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a)); double c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a));
return R * c; return R * c;
} }
private static class LocationPoint { private static class LocationPoint implements Serializable {
private static final long serialVersionUID = 1L;
private final double latitude; private final double latitude;
private final double longitude; private final double longitude;
public LocationPoint(double latitude, double longitude) { public LocationPoint(double latitude, double longitude) {
this.latitude = latitude; this.latitude = latitude;
this.longitude = longitude; this.longitude = longitude;
} }
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
@ -255,35 +327,43 @@ public class AnomalyDetector {
return Double.compare(that.latitude, latitude) == 0 && return Double.compare(that.latitude, latitude) == 0 &&
Double.compare(that.longitude, longitude) == 0; Double.compare(that.longitude, longitude) == 0;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(latitude, longitude); return Objects.hash(latitude, longitude);
} }
@Override
public String toString() {
return "LocationPoint{" +
"lat=" + latitude +
", lon=" + longitude +
'}';
}
} }
} }
// Detector for unusual transaction frequency // Detector for unusual transaction frequency
public static class FrequencyAnomalyDetector public static class FrequencyAnomalyDetector
extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> { extends ProcessWindowFunction<Transaction, TransactionAlert, String, TimeWindow> {
@Override @Override
public void process(String cardId, Context context, Iterable<Transaction> transactions, public void process(String cardId, Context context, Iterable<Transaction> transactions,
Collector<TransactionAlert> out) { Collector<TransactionAlert> out) {
List<Transaction> transactionList = new ArrayList<>(); List<Transaction> transactionList = new ArrayList<>();
transactions.forEach(transactionList::add); transactions.forEach(transactionList::add);
// Get window info // Get window info
long windowStart = context.window().getStart(); long windowStart = context.window().getStart();
long windowEnd = context.window().getEnd(); long windowEnd = context.window().getEnd();
long windowSizeMinutes = (windowEnd - windowStart) / (1000 * 60); long windowSizeMinutes = (windowEnd - windowStart) / (1000 * 60);
// If there are more than 5 transactions in 5 minutes for the same card, flag it // If there are more than 5 transactions in 5 minutes for the same card, flag it
if (transactionList.size() > 5) { if (transactionList.size() > 5) {
Transaction latestTransaction = transactionList.stream() Transaction latestTransaction = transactionList.stream()
.max(Comparator.comparing(Transaction::getTimestamp)) .max(Comparator.comparing(Transaction::getTimestamp))
.orElse(transactionList.get(0)); .orElse(transactionList.get(0));
out.collect(new TransactionAlert( out.collect(new TransactionAlert(
"FREQUENCY_ANOMALY", "FREQUENCY_ANOMALY",
latestTransaction.getCardId(), latestTransaction.getCardId(),
@ -292,7 +372,7 @@ public class AnomalyDetector {
latestTransaction.getLatitude(), latestTransaction.getLatitude(),
latestTransaction.getLongitude(), latestTransaction.getLongitude(),
latestTransaction.getTimestamp(), latestTransaction.getTimestamp(),
"Unusual transaction frequency detected: " + transactionList.size() + "Unusual transaction frequency detected: " + transactionList.size() +
" transactions in " + windowSizeMinutes + " minutes" " transactions in " + windowSizeMinutes + " minutes"
)); ));
} }