diff --git a/db/Polls.py b/db/Polls.py index 41f4f0b..07ee45f 100644 --- a/db/Polls.py +++ b/db/Polls.py @@ -7,13 +7,15 @@ class Polls(Base): id = Column(Integer, primary_key=True) message = Column(BigInteger, nullable=False, unique=True) channel = Column(BigInteger, nullable=False) + guild = Column(BigInteger, nullable=False) author = Column(BigInteger, nullable=False) reactions = Column(String, nullable=False) multi = Column(Boolean, nullable=False, default=False) - def __init__(self, message: int, channel: int, author: int, reactions: [str], multi: bool = False): + def __init__(self, message: int, channel: int, guild: int, author: int, reactions: [str], multi: bool = False): self.message = message self.channel = channel + self.guild = guild self.author = author self.reactions = str(reactions) self.multi = multi diff --git a/extensions/poll.py b/extensions/poll.py index e127d15..39ba857 100644 --- a/extensions/poll.py +++ b/extensions/poll.py @@ -1,7 +1,8 @@ from datetime import datetime +from discord.abc import GuildChannel from discord.ext import commands -from discord import Embed, RawReactionActionEvent +from discord import Embed, RawReactionActionEvent, RawMessageDeleteEvent, RawBulkMessageDeleteEvent, TextChannel, Guild from discord.ext.commands import BadArgument import db @@ -46,13 +47,12 @@ class Poll(commands.Cog): for reaction in reactions: await message.add_reaction(reaction) s = db.Session() - s.add(db.Polls(message.id, ctx.channel.id, ctx.message.author.id, reactions, multi)) + s.add(db.Polls(message.id, ctx.channel.id, ctx.guild.id, ctx.message.author.id, reactions, multi)) s.commit() s.close() await ctx.message.delete() @poll.group("help", pass_context=True) - @commands.guild_only() async def poll_help(self, ctx: commands.Context): embed = Embed(title="Poll help") embed.add_field(name="poll [multi|m] ... ", @@ -101,6 +101,40 @@ class Poll(commands.Cog): session.delete(poll) session.commit() + @commands.Cog.listener() + async def on_raw_message_delete(self, message: RawMessageDeleteEvent): + s = db.Session() + p = s.query(db.Polls).filter(db.Polls.message == message.message_id).first() + if p: + s.delete(p) + s.commit() + s.close() + + @commands.Cog.listener() + async def on_raw_bulk_message_delete(self, messages: RawBulkMessageDeleteEvent): + s = db.Session() + for p in s.query(db.Polls).filter(db.Polls.message.in_(messages.message_ids)).all(): + s.delete(p) + s.commit() + s.close() + + @commands.Cog.listener() + async def on_guild_channel_delete(self, channel: GuildChannel): + if isinstance(channel, TextChannel): + s = db.Session() + for p in s.query(db.Polls).filter(db.Polls.channel == channel.id).all(): + s.delete(p) + s.commit() + s.close() + + @commands.Cog.listener() + async def on_guild_remove(self, guild: Guild): + s = db.Session() + for p in s.query(db.Polls).filter(db.Polls.guild == guild.id).all(): + s.delete(p) + s.commit() + s.close() + def setup(bot): logger.info(f"Loading...")