diff options
author | Marko Zajc <marko@zajc.eu.org> | 2022-07-02 03:50:02 +0200 |
---|---|---|
committer | Marko Zajc <marko@zajc.eu.org> | 2022-07-02 03:50:02 +0200 |
commit | e3222461c75071392592b5b335b9d0685989b3a4 (patch) | |
tree | c60504fc66641d73785c3a649bcef7ba70f9fac1 | |
parent | 18b901852bb3bdedb2bccf753c7daf360d1d8bc8 (diff) |
Fix a race condition in EventWaiterListener
-rw-r--r-- | src/main/java/libot/Main.java | 24 | ||||
-rw-r--r-- | src/main/java/libot/core/entities/BotContext.java | 12 | ||||
-rw-r--r-- | src/main/java/libot/core/entities/CommandContext.java | 2 | ||||
-rw-r--r-- | src/main/java/libot/core/listeners/EventLogListener.java (renamed from src/main/java/libot/listeners/EventLogListener.java) | 2 | ||||
-rw-r--r-- | src/main/java/libot/core/listeners/EventWaiterListener.java (renamed from src/main/java/libot/listeners/EventWaiterListener.java) | 41 | ||||
-rw-r--r-- | src/main/java/libot/core/listeners/MessageListener.java (renamed from src/main/java/libot/listeners/MessageListener.java) | 2 | ||||
-rw-r--r-- | src/main/java/libot/core/listeners/ShredClashListener.java (renamed from src/main/java/libot/listeners/ShredClashListener.java) | 2 | ||||
-rw-r--r-- | src/main/java/libot/utils/EventUtils.java | 20 |
8 files changed, 66 insertions, 39 deletions
diff --git a/src/main/java/libot/Main.java b/src/main/java/libot/Main.java index a65c7c6..bb43869 100644 --- a/src/main/java/libot/Main.java +++ b/src/main/java/libot/Main.java | |||
@@ -4,6 +4,7 @@ import static java.lang.Integer.parseInt; | |||
4 | import static java.lang.Runtime.getRuntime; | 4 | import static java.lang.Runtime.getRuntime; |
5 | import static java.lang.System.getenv; | 5 | import static java.lang.System.getenv; |
6 | import static java.util.concurrent.TimeUnit.MINUTES; | 6 | import static java.util.concurrent.TimeUnit.MINUTES; |
7 | import static java.util.stream.Stream.concat; | ||
7 | import static libot.core.Constants.*; | 8 | import static libot.core.Constants.*; |
8 | import static libot.core.processes.ProcessManager.getProcesses; | 9 | import static libot.core.processes.ProcessManager.getProcesses; |
9 | import static libot.utils.ReflectionUtils.scanClasspath; | 10 | import static libot.utils.ReflectionUtils.scanClasspath; |
@@ -14,6 +15,7 @@ import static org.slf4j.LoggerFactory.getLogger; | |||
14 | 15 | ||
15 | import java.io.IOException; | 16 | import java.io.IOException; |
16 | import java.util.*; | 17 | import java.util.*; |
18 | import java.util.stream.Stream; | ||
17 | 19 | ||
18 | import javax.annotation.Nonnull; | 20 | import javax.annotation.Nonnull; |
19 | import javax.security.auth.login.LoginException; | 21 | import javax.security.auth.login.LoginException; |
@@ -25,6 +27,7 @@ import libot.core.commands.CommandManager; | |||
25 | import libot.core.data.DataManagerFactory; | 27 | import libot.core.data.DataManagerFactory; |
26 | import libot.core.data.providers.ProviderManager; | 28 | import libot.core.data.providers.ProviderManager; |
27 | import libot.core.entities.BotContext; | 29 | import libot.core.entities.BotContext; |
30 | import libot.core.listeners.*; | ||
28 | import libot.core.processes.ProcessManager; | 31 | import libot.core.processes.ProcessManager; |
29 | import libot.core.shred.Shredder; | 32 | import libot.core.shred.Shredder; |
30 | import libot.core.shred.Shredder.Shred; | 33 | import libot.core.shred.Shredder.Shred; |
@@ -51,11 +54,14 @@ public class Main { | |||
51 | System.exit(1); | 54 | System.exit(1); |
52 | }); | 55 | }); |
53 | LOG.info("Creating shreds"); | 56 | LOG.info("Creating shreds"); |
57 | var ewl = new EventWaiterListener(); | ||
58 | |||
54 | var builder = | 59 | var builder = |
55 | JDABuilder.create(GUILD_MEMBERS, GUILD_EMOJIS, GUILD_VOICE_STATES, GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS) | 60 | JDABuilder.create(GUILD_MEMBERS, GUILD_EMOJIS, GUILD_VOICE_STATES, GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS) |
56 | .enableCache(VOICE_STATE, EMOTE, MEMBER_OVERRIDES) | 61 | .enableCache(VOICE_STATE, EMOTE, MEMBER_OVERRIDES) |
57 | .disableCache(ACTIVITY, CLIENT_STATUS, ONLINE_STATUS) | 62 | .disableCache(ACTIVITY, CLIENT_STATUS, ONLINE_STATUS) |
58 | .setChunkingFilter(ChunkingFilter.ALL) | 63 | .setChunkingFilter(ChunkingFilter.ALL) |
64 | .addEventListeners(ewl) | ||
59 | .setStatus(IDLE); | 65 | .setStatus(IDLE); |
60 | 66 | ||
61 | var shreds = startShreds(builder); | 67 | var shreds = startShreds(builder); |
@@ -68,7 +74,7 @@ public class Main { | |||
68 | var providers = ProviderManager.fromClasspath(shredder, data); | 74 | var providers = ProviderManager.fromClasspath(shredder, data); |
69 | var config = BotConfiguration.fromEnvironment(); | 75 | var config = BotConfiguration.fromEnvironment(); |
70 | var commands = CommandManager.fromClasspath(); | 76 | var commands = CommandManager.fromClasspath(); |
71 | var bot = new BotContext(config, commands, data, shredder, providers); | 77 | var bot = new BotContext(config, commands, data, shredder, providers, ewl); |
72 | 78 | ||
73 | bot.cron().scheduleWithFixedDelay(providers::storeAll, 2, 2, MINUTES); | 79 | bot.cron().scheduleWithFixedDelay(providers::storeAll, 2, 2, MINUTES); |
74 | getRuntime().addShutdownHook(new Thread(() -> stop(bot), "libot-shutdown")); | 80 | getRuntime().addShutdownHook(new Thread(() -> stop(bot), "libot-shutdown")); |
@@ -147,13 +153,15 @@ public class Main { | |||
147 | 153 | ||
148 | @SuppressWarnings("null") | 154 | @SuppressWarnings("null") |
149 | private static void loadEventListeners(@Nonnull Shredder shredder, @Nonnull BotContext bot) { | 155 | private static void loadEventListeners(@Nonnull Shredder shredder, @Nonnull BotContext bot) { |
150 | var listeners = scanClasspath(EventListener.class, libot.listeners.Anchor.class, c -> { | 156 | |
151 | try { | 157 | var listeners = concat(Stream.of(new MessageListener(bot), new ShredClashListener(bot), new EventLogListener()), |
152 | return c.getDeclaredConstructor(BotContext.class).newInstance(bot); | 158 | scanClasspath(EventListener.class, libot.listeners.Anchor.class, c -> { |
153 | } catch (NoSuchMethodException e) { | 159 | try { |
154 | return c.getDeclaredConstructor().newInstance(); | 160 | return c.getDeclaredConstructor(BotContext.class).newInstance(bot); |
155 | } | 161 | } catch (NoSuchMethodException e) { |
156 | }).toArray(); | 162 | return c.getDeclaredConstructor().newInstance(); |
163 | } | ||
164 | }).stream()).toArray(); | ||
157 | shredder.getShreds().stream().map(Shred::jda).forEach(j -> j.addEventListener(listeners)); | 165 | shredder.getShreds().stream().map(Shred::jda).forEach(j -> j.addEventListener(listeners)); |
158 | } | 166 | } |
159 | 167 | ||
diff --git a/src/main/java/libot/core/entities/BotContext.java b/src/main/java/libot/core/entities/BotContext.java index 25c526b..b211d2e 100644 --- a/src/main/java/libot/core/entities/BotContext.java +++ b/src/main/java/libot/core/entities/BotContext.java | |||
@@ -10,6 +10,7 @@ import libot.core.BotConfiguration; | |||
10 | import libot.core.commands.CommandManager; | 10 | import libot.core.commands.CommandManager; |
11 | import libot.core.data.DataManager; | 11 | import libot.core.data.DataManager; |
12 | import libot.core.data.providers.*; | 12 | import libot.core.data.providers.*; |
13 | import libot.core.listeners.EventWaiterListener; | ||
13 | import libot.core.shred.Shredder; | 14 | import libot.core.shred.Shredder; |
14 | 15 | ||
15 | public class BotContext { | 16 | public class BotContext { |
@@ -26,14 +27,18 @@ public class BotContext { | |||
26 | private final Shredder shredder; | 27 | private final Shredder shredder; |
27 | @Nonnull | 28 | @Nonnull |
28 | private final ProviderManager providers; | 29 | private final ProviderManager providers; |
30 | @Nonnull | ||
31 | private final EventWaiterListener ewl; | ||
29 | 32 | ||
30 | public BotContext(@Nonnull BotConfiguration config, @Nonnull CommandManager commands, @Nonnull DataManager data, | 33 | public BotContext(@Nonnull BotConfiguration config, @Nonnull CommandManager commands, @Nonnull DataManager data, |
31 | @Nonnull Shredder shredder, @Nonnull ProviderManager providers) { | 34 | @Nonnull Shredder shredder, @Nonnull ProviderManager providers, |
35 | @Nonnull EventWaiterListener ewl) { | ||
32 | this.config = config; | 36 | this.config = config; |
33 | this.commands = commands; | 37 | this.commands = commands; |
34 | this.data = data; | 38 | this.data = data; |
35 | this.shredder = shredder; | 39 | this.shredder = shredder; |
36 | this.providers = providers; | 40 | this.providers = providers; |
41 | this.ewl = ewl; | ||
37 | } | 42 | } |
38 | 43 | ||
39 | @Nonnull | 44 | @Nonnull |
@@ -67,6 +72,11 @@ public class BotContext { | |||
67 | } | 72 | } |
68 | 73 | ||
69 | @Nonnull | 74 | @Nonnull |
75 | public EventWaiterListener ewl() { | ||
76 | return this.ewl; | ||
77 | } | ||
78 | |||
79 | @Nonnull | ||
70 | public <T extends Provider<?>> T provider(@Nonnull Class<T> clazz) { | 80 | public <T extends Provider<?>> T provider(@Nonnull Class<T> clazz) { |
71 | return this.providers.get(clazz); | 81 | return this.providers.get(clazz); |
72 | } | 82 | } |
diff --git a/src/main/java/libot/core/entities/CommandContext.java b/src/main/java/libot/core/entities/CommandContext.java index 4d286b9..042e9bb 100644 --- a/src/main/java/libot/core/entities/CommandContext.java +++ b/src/main/java/libot/core/entities/CommandContext.java | |||
@@ -787,7 +787,7 @@ public class CommandContext { | |||
787 | 787 | ||
788 | public EventUtils getWaiter() { | 788 | public EventUtils getWaiter() { |
789 | if (!isWaiterInited()) | 789 | if (!isWaiterInited()) |
790 | this.waiter = new EventUtils(getUser(), getChannel()); | 790 | this.waiter = new EventUtils(getBotContext().ewl(), getUser(), getChannel()); |
791 | return this.waiter; | 791 | return this.waiter; |
792 | } | 792 | } |
793 | 793 | ||
diff --git a/src/main/java/libot/listeners/EventLogListener.java b/src/main/java/libot/core/listeners/EventLogListener.java index dceb635..fa95138 100644 --- a/src/main/java/libot/listeners/EventLogListener.java +++ b/src/main/java/libot/core/listeners/EventLogListener.java | |||
@@ -1,4 +1,4 @@ | |||
1 | package libot.listeners; | 1 | package libot.core.listeners; |
2 | 2 | ||
3 | import static org.slf4j.LoggerFactory.getLogger; | 3 | import static org.slf4j.LoggerFactory.getLogger; |
4 | 4 | ||
diff --git a/src/main/java/libot/listeners/EventWaiterListener.java b/src/main/java/libot/core/listeners/EventWaiterListener.java index a2e32e7..e4aad6c 100644 --- a/src/main/java/libot/listeners/EventWaiterListener.java +++ b/src/main/java/libot/core/listeners/EventWaiterListener.java | |||
@@ -1,4 +1,4 @@ | |||
1 | package libot.listeners; | 1 | package libot.core.listeners; |
2 | 2 | ||
3 | import static java.lang.String.format; | 3 | import static java.lang.String.format; |
4 | import static java.lang.Thread.interrupted; | 4 | import static java.lang.Thread.interrupted; |
@@ -23,17 +23,18 @@ public class EventWaiterListener implements EventListener { | |||
23 | private static final String FORMAT_WRONG_TYPE = "Something went wrong; onEvent() returned %s instead of %s"; | 23 | private static final String FORMAT_WRONG_TYPE = "Something went wrong; onEvent() returned %s instead of %s"; |
24 | 24 | ||
25 | // @formatter:off 💀 | 25 | // @formatter:off 💀 |
26 | private static final MutableIntObjectMap< | 26 | private final MutableIntObjectMap< |
27 | Triple< | 27 | Triple< |
28 | Predicate<GenericEvent>, | 28 | Predicate<GenericEvent>, |
29 | MessageLock<GenericEvent>, | 29 | MessageLock<GenericEvent>, |
30 | Predicate<Void> | 30 | Predicate<Void> |
31 | > | 31 | > |
32 | > EVENT_WAITERS = | 32 | > eventWaiters = |
33 | // @formatter:on | 33 | // @formatter:on |
34 | IntObjectMaps.mutable.<Triple<Predicate<GenericEvent>, MessageLock<GenericEvent>, Predicate<Void>>>empty() | 34 | IntObjectMaps.mutable.<Triple<Predicate<GenericEvent>, MessageLock<GenericEvent>, Predicate<Void>>>empty(); // death |
35 | .asSynchronized(); // death | 35 | |
36 | private static final AtomicInteger COUNTER = new AtomicInteger(); | 36 | private final AtomicInteger counter = new AtomicInteger(); |
37 | private final Object mutex = new Object(); | ||
37 | 38 | ||
38 | /** | 39 | /** |
39 | * Pauses the current thread and awaits a certain event.<br> | 40 | * Pauses the current thread and awaits a certain event.<br> |
@@ -63,21 +64,25 @@ public class EventWaiterListener implements EventListener { | |||
63 | * @throws InterruptedException | 64 | * @throws InterruptedException |
64 | */ | 65 | */ |
65 | @SuppressWarnings("unchecked") | 66 | @SuppressWarnings("unchecked") |
66 | public static <T extends GenericEvent> T awaitEvent(@Nonnull Predicate<GenericEvent> predicate, | 67 | public <T extends GenericEvent> T awaitEvent(@Nonnull Predicate<GenericEvent> predicate, |
67 | @Nullable Predicate<Void> nullableCleanupPredicate, int timeout, | 68 | @Nullable Predicate<Void> nullableCleanupPredicate, int timeout, |
68 | @Nullable TimeUnit timeoutUnit, | 69 | @Nullable TimeUnit timeoutUnit, |
69 | @Nonnull Class<T> eventClass) throws TimeoutException, | 70 | @Nonnull Class<T> eventClass) throws TimeoutException, |
70 | InterruptedException { | 71 | InterruptedException { |
71 | var cleanupPredicate = nullableCleanupPredicate; | 72 | var cleanupPredicate = nullableCleanupPredicate; |
72 | if (cleanupPredicate == null) | 73 | if (cleanupPredicate == null) |
73 | cleanupPredicate = p -> false; | 74 | cleanupPredicate = p -> false; |
74 | 75 | ||
75 | var lock = new MessageLock<GenericEvent>(); | 76 | var lock = new MessageLock<GenericEvent>(); |
76 | int ticket = COUNTER.getAndIncrement(); | 77 | int ticket = this.counter.getAndIncrement(); |
77 | Predicate<GenericEvent> isInstance = eventClass::isInstance; | 78 | Predicate<GenericEvent> isInstance = eventClass::isInstance; |
78 | EVENT_WAITERS.put(ticket, new ImmutableTriple<>(isInstance.and(predicate), lock, cleanupPredicate)); | 79 | synchronized (this.mutex) { |
80 | this.eventWaiters.put(ticket, new ImmutableTriple<>(isInstance.and(predicate), lock, cleanupPredicate)); | ||
81 | } | ||
79 | GenericEvent awaited = lock.receive(timeout, timeoutUnit); | 82 | GenericEvent awaited = lock.receive(timeout, timeoutUnit); |
80 | EVENT_WAITERS.remove(ticket); | 83 | synchronized (this.mutex) { |
84 | this.eventWaiters.remove(ticket); | ||
85 | } | ||
81 | sanityCheck(eventClass, awaited); | 86 | sanityCheck(eventClass, awaited); |
82 | return (T) awaited; | 87 | return (T) awaited; |
83 | } | 88 | } |
@@ -97,9 +102,11 @@ public class EventWaiterListener implements EventListener { | |||
97 | 102 | ||
98 | @Override | 103 | @Override |
99 | public void onEvent(GenericEvent event) { | 104 | public void onEvent(GenericEvent event) { |
100 | for (var e : EVENT_WAITERS.values()) { | 105 | synchronized (this.mutex) { |
101 | if (e.getLeft().test(event)) | 106 | for (var e : this.eventWaiters.values()) { |
102 | e.getMiddle().send(event); | 107 | if (e.getLeft().test(event)) |
108 | e.getMiddle().send(event); | ||
109 | } | ||
103 | } | 110 | } |
104 | } | 111 | } |
105 | 112 | ||
diff --git a/src/main/java/libot/listeners/MessageListener.java b/src/main/java/libot/core/listeners/MessageListener.java index 00fa2e4..8bccece 100644 --- a/src/main/java/libot/listeners/MessageListener.java +++ b/src/main/java/libot/core/listeners/MessageListener.java | |||
@@ -1,4 +1,4 @@ | |||
1 | package libot.listeners; | 1 | package libot.core.listeners; |
2 | 2 | ||
3 | import static libot.utils.ParseUtils.parseCommandName; | 3 | import static libot.utils.ParseUtils.parseCommandName; |
4 | 4 | ||
diff --git a/src/main/java/libot/listeners/ShredClashListener.java b/src/main/java/libot/core/listeners/ShredClashListener.java index 93b1c19..a54cbcb 100644 --- a/src/main/java/libot/listeners/ShredClashListener.java +++ b/src/main/java/libot/core/listeners/ShredClashListener.java | |||
@@ -1,4 +1,4 @@ | |||
1 | package libot.listeners; | 1 | package libot.core.listeners; |
2 | 2 | ||
3 | import javax.annotation.Nonnull; | 3 | import javax.annotation.Nonnull; |
4 | 4 | ||
diff --git a/src/main/java/libot/utils/EventUtils.java b/src/main/java/libot/utils/EventUtils.java index a5b44a4..85687cb 100644 --- a/src/main/java/libot/utils/EventUtils.java +++ b/src/main/java/libot/utils/EventUtils.java | |||
@@ -5,7 +5,7 @@ import static libot.core.Constants.*; | |||
5 | 5 | ||
6 | import javax.annotation.Nonnull; | 6 | import javax.annotation.Nonnull; |
7 | 7 | ||
8 | import libot.listeners.EventWaiterListener; | 8 | import libot.core.listeners.EventWaiterListener; |
9 | import net.dv8tion.jda.api.entities.*; | 9 | import net.dv8tion.jda.api.entities.*; |
10 | import net.dv8tion.jda.api.events.message.MessageReceivedEvent; | 10 | import net.dv8tion.jda.api.events.message.MessageReceivedEvent; |
11 | import net.dv8tion.jda.api.events.message.react.*; | 11 | import net.dv8tion.jda.api.events.message.react.*; |
@@ -15,20 +15,22 @@ public class EventUtils { | |||
15 | private int timeout; | 15 | private int timeout; |
16 | private final User user; | 16 | private final User user; |
17 | private final MessageChannel channel; | 17 | private final MessageChannel channel; |
18 | private final EventWaiterListener ewl; | ||
18 | 19 | ||
19 | public EventUtils(User user, MessageChannel channel) { | 20 | public EventUtils(EventWaiterListener ewl, User user, MessageChannel channel) { |
20 | this.user = user; | 21 | this.user = user; |
21 | this.channel = channel; | 22 | this.channel = channel; |
23 | this.ewl = ewl; | ||
22 | } | 24 | } |
23 | 25 | ||
24 | public EventUtils(User user, MessageChannel channel, int timeout) { | 26 | public EventUtils(EventWaiterListener ewl, User user, MessageChannel channel, int timeout) { |
25 | this(user, channel); | 27 | this(ewl, user, channel); |
26 | this.timeout = timeout; | 28 | this.timeout = timeout; |
27 | } | 29 | } |
28 | 30 | ||
29 | @Nonnull | 31 | @Nonnull |
30 | public MessageReaction getReaction(Message message) throws InterruptedException { | 32 | public MessageReaction getReaction(Message message) throws InterruptedException { |
31 | return EventWaiterListener.awaitEvent(p -> { | 33 | return this.ewl.awaitEvent(p -> { |
32 | GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; | 34 | GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; |
33 | return e.getUserIdLong() == this.user.getIdLong() && e.getMessageIdLong() == message.getIdLong(); | 35 | return e.getUserIdLong() == this.user.getIdLong() && e.getMessageIdLong() == message.getIdLong(); |
34 | }, p -> { | 36 | }, p -> { |
@@ -41,10 +43,10 @@ public class EventUtils { | |||
41 | }, this.timeout, SECONDS, GenericMessageReactionEvent.class).getReaction(); | 43 | }, this.timeout, SECONDS, GenericMessageReactionEvent.class).getReaction(); |
42 | } | 44 | } |
43 | 45 | ||
44 | @SuppressWarnings("null") | ||
45 | @Nonnull | 46 | @Nonnull |
47 | @SuppressWarnings("null") | ||
46 | public Message getMessage(String emoji) throws InterruptedException { | 48 | public Message getMessage(String emoji) throws InterruptedException { |
47 | GenericMessageReactionEvent event = EventWaiterListener.awaitEvent(p -> { | 49 | GenericMessageReactionEvent event = this.ewl.awaitEvent(p -> { |
48 | 50 | ||
49 | GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; | 51 | GenericMessageReactionEvent e = (GenericMessageReactionEvent) p; |
50 | 52 | ||
@@ -58,7 +60,7 @@ public class EventUtils { | |||
58 | 60 | ||
59 | @Nonnull | 61 | @Nonnull |
60 | public Message awaitMessage(boolean ignoreBlank) throws InterruptedException { | 62 | public Message awaitMessage(boolean ignoreBlank) throws InterruptedException { |
61 | return EventWaiterListener.awaitEvent(p -> { | 63 | return this.ewl.awaitEvent(p -> { |
62 | 64 | ||
63 | var e = (MessageReceivedEvent) p; | 65 | var e = (MessageReceivedEvent) p; |
64 | 66 | ||
@@ -70,7 +72,7 @@ public class EventUtils { | |||
70 | } | 72 | } |
71 | 73 | ||
72 | public boolean awaitBoolean(@Nonnull Message question) throws InterruptedException { | 74 | public boolean awaitBoolean(@Nonnull Message question) throws InterruptedException { |
73 | boolean result = ACCEPT_EMOJI.equals(EventWaiterListener.awaitEvent(p -> { | 75 | boolean result = ACCEPT_EMOJI.equals(this.ewl.awaitEvent(p -> { |
74 | MessageReactionAddEvent e = (MessageReactionAddEvent) p; | 76 | MessageReactionAddEvent e = (MessageReactionAddEvent) p; |
75 | String emote = e.getReactionEmote().getName(); | 77 | String emote = e.getReactionEmote().getName(); |
76 | 78 | ||