aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarko Zajc <marko@zajc.eu.org>2022-07-02 03:50:02 +0200
committerMarko Zajc <marko@zajc.eu.org>2022-07-02 03:50:02 +0200
commite3222461c75071392592b5b335b9d0685989b3a4 (patch)
treec60504fc66641d73785c3a649bcef7ba70f9fac1
parent18b901852bb3bdedb2bccf753c7daf360d1d8bc8 (diff)
Fix a race condition in EventWaiterListener
-rw-r--r--src/main/java/libot/Main.java24
-rw-r--r--src/main/java/libot/core/entities/BotContext.java12
-rw-r--r--src/main/java/libot/core/entities/CommandContext.java2
-rw-r--r--src/main/java/libot/core/listeners/EventLogListener.java (renamed from src/main/java/libot/listeners/EventLogListener.java)2
-rw-r--r--src/main/java/libot/core/listeners/EventWaiterListener.java (renamed from src/main/java/libot/listeners/EventWaiterListener.java)41
-rw-r--r--src/main/java/libot/core/listeners/MessageListener.java (renamed from src/main/java/libot/listeners/MessageListener.java)2
-rw-r--r--src/main/java/libot/core/listeners/ShredClashListener.java (renamed from src/main/java/libot/listeners/ShredClashListener.java)2
-rw-r--r--src/main/java/libot/utils/EventUtils.java20
8 files changed, 66 insertions, 39 deletions
diff --git a/src/main/java/libot/Main.java b/src/main/java/libot/Main.java
index a65c7c6..bb43869 100644
--- a/src/main/java/libot/Main.java
+++ b/src/main/java/libot/Main.java
@@ -4,6 +4,7 @@ import static java.lang.Integer.parseInt;
4import static java.lang.Runtime.getRuntime; 4import static java.lang.Runtime.getRuntime;
5import static java.lang.System.getenv; 5import static java.lang.System.getenv;
6import static java.util.concurrent.TimeUnit.MINUTES; 6import static java.util.concurrent.TimeUnit.MINUTES;
7import static java.util.stream.Stream.concat;
7import static libot.core.Constants.*; 8import static libot.core.Constants.*;
8import static libot.core.processes.ProcessManager.getProcesses; 9import static libot.core.processes.ProcessManager.getProcesses;
9import static libot.utils.ReflectionUtils.scanClasspath; 10import static libot.utils.ReflectionUtils.scanClasspath;
@@ -14,6 +15,7 @@ import static org.slf4j.LoggerFactory.getLogger;
14 15
15import java.io.IOException; 16import java.io.IOException;
16import java.util.*; 17import java.util.*;
18import java.util.stream.Stream;
17 19
18import javax.annotation.Nonnull; 20import javax.annotation.Nonnull;
19import javax.security.auth.login.LoginException; 21import javax.security.auth.login.LoginException;
@@ -25,6 +27,7 @@ import libot.core.commands.CommandManager;
25import libot.core.data.DataManagerFactory; 27import libot.core.data.DataManagerFactory;
26import libot.core.data.providers.ProviderManager; 28import libot.core.data.providers.ProviderManager;
27import libot.core.entities.BotContext; 29import libot.core.entities.BotContext;
30import libot.core.listeners.*;
28import libot.core.processes.ProcessManager; 31import libot.core.processes.ProcessManager;
29import libot.core.shred.Shredder; 32import libot.core.shred.Shredder;
30import libot.core.shred.Shredder.Shred; 33import libot.core.shred.Shredder.Shred;
@@ -51,11 +54,14 @@ public class Main {
51 System.exit(1); 54 System.exit(1);
52 }); 55 });
53 LOG.info("Creating shreds"); 56 LOG.info("Creating shreds");
57 var ewl = new EventWaiterListener();
58
54 var builder = 59 var builder =
55 JDABuilder.create(GUILD_MEMBERS, GUILD_EMOJIS, GUILD_VOICE_STATES, GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS) 60 JDABuilder.create(GUILD_MEMBERS, GUILD_EMOJIS, GUILD_VOICE_STATES, GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS)
56 .enableCache(VOICE_STATE, EMOTE, MEMBER_OVERRIDES) 61 .enableCache(VOICE_STATE, EMOTE, MEMBER_OVERRIDES)
57 .disableCache(ACTIVITY, CLIENT_STATUS, ONLINE_STATUS) 62 .disableCache(ACTIVITY, CLIENT_STATUS, ONLINE_STATUS)
58 .setChunkingFilter(ChunkingFilter.ALL) 63 .setChunkingFilter(ChunkingFilter.ALL)
64 .addEventListeners(ewl)
59 .setStatus(IDLE); 65 .setStatus(IDLE);
60 66
61 var shreds = startShreds(builder); 67 var shreds = startShreds(builder);
@@ -68,7 +74,7 @@ public class Main {
68 var providers = ProviderManager.fromClasspath(shredder, data); 74 var providers = ProviderManager.fromClasspath(shredder, data);
69 var config = BotConfiguration.fromEnvironment(); 75 var config = BotConfiguration.fromEnvironment();
70 var commands = CommandManager.fromClasspath(); 76 var commands = CommandManager.fromClasspath();
71 var bot = new BotContext(config, commands, data, shredder, providers); 77 var bot = new BotContext(config, commands, data, shredder, providers, ewl);
72 78
73 bot.cron().scheduleWithFixedDelay(providers::storeAll, 2, 2, MINUTES); 79 bot.cron().scheduleWithFixedDelay(providers::storeAll, 2, 2, MINUTES);
74 getRuntime().addShutdownHook(new Thread(() -> stop(bot), "libot-shutdown")); 80 getRuntime().addShutdownHook(new Thread(() -> stop(bot), "libot-shutdown"));
@@ -147,13 +153,15 @@ public class Main {
147 153
148 @SuppressWarnings("null") 154 @SuppressWarnings("null")
149 private static void loadEventListeners(@Nonnull Shredder shredder, @Nonnull BotContext bot) { 155 private static void loadEventListeners(@Nonnull Shredder shredder, @Nonnull BotContext bot) {
150 var listeners = scanClasspath(EventListener.class, libot.listeners.Anchor.class, c -> { 156
151 try { 157 var listeners = concat(Stream.of(new MessageListener(bot), new ShredClashListener(bot), new EventLogListener()),
152 return c.getDeclaredConstructor(BotContext.class).newInstance(bot); 158 scanClasspath(EventListener.class, libot.listeners.Anchor.class, c -> {
153 } catch (NoSuchMethodException e) { 159 try {
154 return c.getDeclaredConstructor().newInstance(); 160 return c.getDeclaredConstructor(BotContext.class).newInstance(bot);
155 } 161 } catch (NoSuchMethodException e) {
156 }).toArray(); 162 return c.getDeclaredConstructor().newInstance();
163 }
164 }).stream()).toArray();
157 shredder.getShreds().stream().map(Shred::jda).forEach(j -> j.addEventListener(listeners)); 165 shredder.getShreds().stream().map(Shred::jda).forEach(j -> j.addEventListener(listeners));
158 } 166 }
159 167
diff --git a/src/main/java/libot/core/entities/BotContext.java b/src/main/java/libot/core/entities/BotContext.java
index 25c526b..b211d2e 100644
--- a/src/main/java/libot/core/entities/BotContext.java
+++ b/src/main/java/libot/core/entities/BotContext.java
@@ -10,6 +10,7 @@ import libot.core.BotConfiguration;
10import libot.core.commands.CommandManager; 10import libot.core.commands.CommandManager;
11import libot.core.data.DataManager; 11import libot.core.data.DataManager;
12import libot.core.data.providers.*; 12import libot.core.data.providers.*;
13import libot.core.listeners.EventWaiterListener;
13import libot.core.shred.Shredder; 14import libot.core.shred.Shredder;
14 15
15public class BotContext { 16public class BotContext {
@@ -26,14 +27,18 @@ public class BotContext {
26 private final Shredder shredder; 27 private final Shredder shredder;
27 @Nonnull 28 @Nonnull
28 private final ProviderManager providers; 29 private final ProviderManager providers;
30 @Nonnull
31 private final EventWaiterListener ewl;
29 32
30 public BotContext(@Nonnull BotConfiguration config, @Nonnull CommandManager commands, @Nonnull DataManager data, 33 public BotContext(@Nonnull BotConfiguration config, @Nonnull CommandManager commands, @Nonnull DataManager data,
31 @Nonnull Shredder shredder, @Nonnull ProviderManager providers) { 34 @Nonnull Shredder shredder, @Nonnull ProviderManager providers,
35 @Nonnull EventWaiterListener ewl) {
32 this.config = config; 36 this.config = config;
33 this.commands = commands; 37 this.commands = commands;
34 this.data = data; 38 this.data = data;
35 this.shredder = shredder; 39 this.shredder = shredder;
36 this.providers = providers; 40 this.providers = providers;
41 this.ewl = ewl;
37 } 42 }
38 43
39 @Nonnull 44 @Nonnull
@@ -67,6 +72,11 @@ public class BotContext {
67 } 72 }
68 73
69 @Nonnull 74 @Nonnull
75 public EventWaiterListener ewl() {
76 return this.ewl;
77 }
78
79 @Nonnull
70 public <T extends Provider<?>> T provider(@Nonnull Class<T> clazz) { 80 public <T extends Provider<?>> T provider(@Nonnull Class<T> clazz) {
71 return this.providers.get(clazz); 81 return this.providers.get(clazz);
72 } 82 }
diff --git a/src/main/java/libot/core/entities/CommandContext.java b/src/main/java/libot/core/entities/CommandContext.java
index 4d286b9..042e9bb 100644
--- a/src/main/java/libot/core/entities/CommandContext.java
+++ b/src/main/java/libot/core/entities/CommandContext.java
@@ -787,7 +787,7 @@ public class CommandContext {
787 787
788 public EventUtils getWaiter() { 788 public EventUtils getWaiter() {
789 if (!isWaiterInited()) 789 if (!isWaiterInited())
790 this.waiter = new EventUtils(getUser(), getChannel()); 790 this.waiter = new EventUtils(getBotContext().ewl(), getUser(), getChannel());
791 return this.waiter; 791 return this.waiter;
792 } 792 }
793 793
diff --git a/src/main/java/libot/listeners/EventLogListener.java b/src/main/java/libot/core/listeners/EventLogListener.java
index dceb635..fa95138 100644
--- a/src/main/java/libot/listeners/EventLogListener.java
+++ b/src/main/java/libot/core/listeners/EventLogListener.java
@@ -1,4 +1,4 @@
1package libot.listeners; 1package libot.core.listeners;
2 2
3import static org.slf4j.LoggerFactory.getLogger; 3import static org.slf4j.LoggerFactory.getLogger;
4 4
diff --git a/src/main/java/libot/listeners/EventWaiterListener.java b/src/main/java/libot/core/listeners/EventWaiterListener.java
index a2e32e7..e4aad6c 100644
--- a/src/main/java/libot/listeners/EventWaiterListener.java
+++ b/src/main/java/libot/core/listeners/EventWaiterListener.java
@@ -1,4 +1,4 @@
1package libot.listeners; 1package libot.core.listeners;
2 2
3import static java.lang.String.format; 3import static java.lang.String.format;
4import static java.lang.Thread.interrupted; 4import static java.lang.Thread.interrupted;
@@ -23,17 +23,18 @@ public class EventWaiterListener implements EventListener {
23 private static final String FORMAT_WRONG_TYPE = "Something went wrong; onEvent() returned %s instead of %s"; 23 private static final String FORMAT_WRONG_TYPE = "Something went wrong; onEvent() returned %s instead of %s";
24 24
25 // @formatter:off 💀 25 // @formatter:off 💀
26 private static final MutableIntObjectMap< 26 private final MutableIntObjectMap<
27 Triple< 27 Triple<
28 Predicate<GenericEvent>, 28 Predicate<GenericEvent>,
29 MessageLock<GenericEvent>, 29 MessageLock<GenericEvent>,
30 Predicate<Void> 30 Predicate<Void>
31 > 31 >
32 > EVENT_WAITERS = 32 > eventWaiters =
33 // @formatter:on 33 // @formatter:on
34 IntObjectMaps.mutable.<Triple<Predicate<GenericEvent>, MessageLock<GenericEvent>, Predicate<Void>>>empty() 34 IntObjectMaps.mutable.<Triple<Predicate<GenericEvent>, MessageLock<GenericEvent>, Predicate<Void>>>empty(); // death
35 .asSynchronized(); // death 35
36 private static final AtomicInteger COUNTER = new AtomicInteger(); 36 private final AtomicInteger counter = new AtomicInteger();
37 private final Object mutex = new Object();
37 38
38 /** 39 /**
39 * Pauses the current thread and awaits a certain event.<br> 40 * Pauses the current thread and awaits a certain event.<br>
@@ -63,21 +64,25 @@ public class EventWaiterListener implements EventListener {
63 * @throws InterruptedException 64 * @throws InterruptedException
64 */ 65 */
65 @SuppressWarnings("unchecked") 66 @SuppressWarnings("unchecked")
66 public static <T extends GenericEvent> T awaitEvent(@Nonnull Predicate<GenericEvent> predicate, 67 public <T extends GenericEvent> T awaitEvent(@Nonnull Predicate<GenericEvent> predicate,
67 @Nullable Predicate<Void> nullableCleanupPredicate, int timeout, 68 @Nullable Predicate<Void> nullableCleanupPredicate, int timeout,
68 @Nullable TimeUnit timeoutUnit, 69 @Nullable TimeUnit timeoutUnit,
69 @Nonnull Class<T> eventClass) throws TimeoutException, 70 @Nonnull Class<T> eventClass) throws TimeoutException,
70 InterruptedException { 71 InterruptedException {
71 var cleanupPredicate = nullableCleanupPredicate; 72 var cleanupPredicate = nullableCleanupPredicate;
72 if (cleanupPredicate == null) 73 if (cleanupPredicate == null)
73 cleanupPredicate = p -> false; 74 cleanupPredicate = p -> false;
74 75
75 var lock = new MessageLock<GenericEvent>(); 76 var lock = new MessageLock<GenericEvent>();
76 int ticket = COUNTER.getAndIncrement(); 77 int ticket = this.counter.getAndIncrement();
77 Predicate<GenericEvent> isInstance = eventClass::isInstance; 78 Predicate<GenericEvent> isInstance = eventClass::isInstance;
78 EVENT_WAITERS.put(ticket, new ImmutableTriple<>(isInstance.and(predicate), lock, cleanupPredicate)); 79 synchronized (this.mutex) {
80 this.eventWaiters.put(ticket, new ImmutableTriple<>(isInstance.and(predicate), lock, cleanupPredicate));
81 }
79 GenericEvent awaited = lock.receive(timeout, timeoutUnit); 82 GenericEvent awaited = lock.receive(timeout, timeoutUnit);
80 EVENT_WAITERS.remove(ticket); 83 synchronized (this.mutex) {
84 this.eventWaiters.remove(ticket);
85 }
81 sanityCheck(eventClass, awaited); 86 sanityCheck(eventClass, awaited);
82 return (T) awaited; 87 return (T) awaited;
83 } 88 }
@@ -97,9 +102,11 @@ public class EventWaiterListener implements EventListener {
97 102
98 @Override 103 @Override
99 public void onEvent(GenericEvent event) { 104 public void onEvent(GenericEvent event) {
100 for (var e : EVENT_WAITERS.values()) { 105 synchronized (this.mutex) {
101 if (e.getLeft().test(event)) 106 for (var e : this.eventWaiters.values()) {
102 e.getMiddle().send(event); 107 if (e.getLeft().test(event))
108 e.getMiddle().send(event);
109 }
103 } 110 }
104 } 111 }
105 112
diff --git a/src/main/java/libot/listeners/MessageListener.java b/src/main/java/libot/core/listeners/MessageListener.java
index 00fa2e4..8bccece 100644
--- a/src/main/java/libot/listeners/MessageListener.java
+++ b/src/main/java/libot/core/listeners/MessageListener.java
@@ -1,4 +1,4 @@
1package libot.listeners; 1package libot.core.listeners;
2 2
3import static libot.utils.ParseUtils.parseCommandName; 3import static libot.utils.ParseUtils.parseCommandName;
4 4
diff --git a/src/main/java/libot/listeners/ShredClashListener.java b/src/main/java/libot/core/listeners/ShredClashListener.java
index 93b1c19..a54cbcb 100644
--- a/src/main/java/libot/listeners/ShredClashListener.java
+++ b/src/main/java/libot/core/listeners/ShredClashListener.java
@@ -1,4 +1,4 @@
1package libot.listeners; 1package libot.core.listeners;
2 2
3import javax.annotation.Nonnull; 3import javax.annotation.Nonnull;
4 4
diff --git a/src/main/java/libot/utils/EventUtils.java b/src/main/java/libot/utils/EventUtils.java
index a5b44a4..85687cb 100644
--- a/src/main/java/libot/utils/EventUtils.java
+++ b/src/main/java/libot/utils/EventUtils.java
@@ -5,7 +5,7 @@ import static libot.core.Constants.*;
5 5
6import javax.annotation.Nonnull; 6import javax.annotation.Nonnull;
7 7
8import libot.listeners.EventWaiterListener; 8import libot.core.listeners.EventWaiterListener;
9import net.dv8tion.jda.api.entities.*; 9import net.dv8tion.jda.api.entities.*;
10import net.dv8tion.jda.api.events.message.MessageReceivedEvent; 10import net.dv8tion.jda.api.events.message.MessageReceivedEvent;
11import net.dv8tion.jda.api.events.message.react.*; 11import net.dv8tion.jda.api.events.message.react.*;
@@ -15,20 +15,22 @@ public class EventUtils {
15 private int timeout; 15 private int timeout;
16 private final User user; 16 private final User user;
17 private final MessageChannel channel; 17 private final MessageChannel channel;
18 private final EventWaiterListener ewl;
18 19
19 public EventUtils(User user, MessageChannel channel) { 20 public EventUtils(EventWaiterListener ewl, User user, MessageChannel channel) {
20 this.user = user; 21 this.user = user;
21 this.channel = channel; 22 this.channel = channel;
23 this.ewl = ewl;
22 } 24 }
23 25
24 public EventUtils(User user, MessageChannel channel, int timeout) { 26 public EventUtils(EventWaiterListener ewl, User user, MessageChannel channel, int timeout) {
25 this(user, channel); 27 this(ewl, user, channel);
26 this.timeout = timeout; 28 this.timeout = timeout;
27 } 29 }
28 30
29 @Nonnull 31 @Nonnull
30 public MessageReaction getReaction(Message message) throws InterruptedException { 32 public MessageReaction getReaction(Message message) throws InterruptedException {
31 return EventWaiterListener.awaitEvent(p -> { 33 return this.ewl.awaitEvent(p -> {
32 GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; 34 GenericMessageReactionEvent e = (GenericMessageReactionEvent) p;
33 return e.getUserIdLong() == this.user.getIdLong() && e.getMessageIdLong() == message.getIdLong(); 35 return e.getUserIdLong() == this.user.getIdLong() && e.getMessageIdLong() == message.getIdLong();
34 }, p -> { 36 }, p -> {
@@ -41,10 +43,10 @@ public class EventUtils {
41 }, this.timeout, SECONDS, GenericMessageReactionEvent.class).getReaction(); 43 }, this.timeout, SECONDS, GenericMessageReactionEvent.class).getReaction();
42 } 44 }
43 45
44 @SuppressWarnings("null")
45 @Nonnull 46 @Nonnull
47 @SuppressWarnings("null")
46 public Message getMessage(String emoji) throws InterruptedException { 48 public Message getMessage(String emoji) throws InterruptedException {
47 GenericMessageReactionEvent event = EventWaiterListener.awaitEvent(p -> { 49 GenericMessageReactionEvent event = this.ewl.awaitEvent(p -> {
48 50
49 GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; 51 GenericMessageReactionEvent e = (GenericMessageReactionEvent) p;
50 52
@@ -58,7 +60,7 @@ public class EventUtils {
58 60
59 @Nonnull 61 @Nonnull
60 public Message awaitMessage(boolean ignoreBlank) throws InterruptedException { 62 public Message awaitMessage(boolean ignoreBlank) throws InterruptedException {
61 return EventWaiterListener.awaitEvent(p -> { 63 return this.ewl.awaitEvent(p -> {
62 64
63 var e = (MessageReceivedEvent) p; 65 var e = (MessageReceivedEvent) p;
64 66
@@ -70,7 +72,7 @@ public class EventUtils {
70 } 72 }
71 73
72 public boolean awaitBoolean(@Nonnull Message question) throws InterruptedException { 74 public boolean awaitBoolean(@Nonnull Message question) throws InterruptedException {
73 boolean result = ACCEPT_EMOJI.equals(EventWaiterListener.awaitEvent(p -> { 75 boolean result = ACCEPT_EMOJI.equals(this.ewl.awaitEvent(p -> {
74 MessageReactionAddEvent e = (MessageReactionAddEvent) p; 76 MessageReactionAddEvent e = (MessageReactionAddEvent) p;
75 String emote = e.getReactionEmote().getName(); 77 String emote = e.getReactionEmote().getName();
76 78