From 585f3e59d5fcb6a5caeb6a467e74b9dc17bd18d0 Mon Sep 17 00:00:00 2001 From: flifloo Date: Thu, 23 Jul 2020 21:42:30 +0200 Subject: [PATCH] Connection of poll extension to database to avoid lost polls --- db/Polls.py | 19 ++++++++++++++ db/__init__.py | 1 + extensions/poll.py | 63 ++++++++++++++++++++++++++-------------------- 3 files changed, 56 insertions(+), 27 deletions(-) create mode 100644 db/Polls.py diff --git a/db/Polls.py b/db/Polls.py new file mode 100644 index 0000000..41f4f0b --- /dev/null +++ b/db/Polls.py @@ -0,0 +1,19 @@ +from db import Base +from sqlalchemy import Column, Integer, BigInteger, String, Boolean + + +class Polls(Base): + __tablename__ = "polls" + id = Column(Integer, primary_key=True) + message = Column(BigInteger, nullable=False, unique=True) + channel = Column(BigInteger, nullable=False) + author = Column(BigInteger, nullable=False) + reactions = Column(String, nullable=False) + multi = Column(Boolean, nullable=False, default=False) + + def __init__(self, message: int, channel: int, author: int, reactions: [str], multi: bool = False): + self.message = message + self.channel = channel + self.author = author + self.reactions = str(reactions) + self.multi = multi diff --git a/db/__init__.py b/db/__init__.py index c6411db..50ebb32 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -9,4 +9,5 @@ from db.Task import Task from db.Greetings import Greetings from db.Presentation import Presentation from db.RoRec import RoRec +from db.Polls import Polls Base.metadata.create_all(engine) diff --git a/extensions/poll.py b/extensions/poll.py index d31486a..e127d15 100644 --- a/extensions/poll.py +++ b/extensions/poll.py @@ -1,8 +1,10 @@ from datetime import datetime from discord.ext import commands -from discord import Member, Embed, Reaction +from discord import Embed, RawReactionActionEvent +from discord.ext.commands import BadArgument +import db from administrator.logger import logger @@ -17,7 +19,6 @@ REACTIONS.append("\U0001F51F") class Poll(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot - self.polls = {} def description(self): return "Create poll with a simple command" @@ -33,7 +34,7 @@ class Poll(commands.Cog): multi = True choices = choices[1:] if len(choices) == 0 or len(choices) > 11: - await ctx.message.add_reaction("\u274C") + raise BadArgument() else: embed = Embed(title=f"Poll: {name}") embed.set_author(name=str(ctx.author), icon_url=ctx.author.avatar_url) @@ -44,8 +45,10 @@ class Poll(commands.Cog): reactions = REACTIONS[0:len(choices)] + ["\U0001F5D1"] for reaction in reactions: await message.add_reaction(reaction) - message = await message.channel.fetch_message(message.id) - self.polls[message.id] = {"multi": multi, "message": message, "author": ctx.message.author.id} + s = db.Session() + s.add(db.Polls(message.id, ctx.channel.id, ctx.message.author.id, reactions, multi)) + s.commit() + s.close() await ctx.message.delete() @poll.group("help", pass_context=True) @@ -59,30 +62,35 @@ class Poll(commands.Cog): await ctx.send(embed=embed) @commands.Cog.listener() - async def on_reaction_add(self, reaction: Reaction, user: Member): - if not user.bot and reaction.message.id in self.polls: - if reaction not in self.polls[reaction.message.id]["message"].reactions: - await reaction.remove(user) - elif str(reaction.emoji) == "\U0001F5D1": - if user.id != self.polls[reaction.message.id]["author"]: - await reaction.remove(user) - else: - await self.close_poll(reaction.message.id) - elif not self.polls[reaction.message.id]["multi"]: - f = False - for r in reaction.message.reactions: - if str(r.emoji) != str(reaction.emoji): - async for u in r.users(): - if u == user: - await r.remove(user) - f = True + async def on_raw_reaction_add(self, payload: RawReactionActionEvent): + if not payload.member.bot: + s = db.Session() + p = s.query(db.Polls).filter(db.Polls.message == payload.message_id).first() + if p: + message = await self.bot.get_channel(p.channel).fetch_message(p.message) + if str(payload.emoji) not in eval(p.reactions): + await message.remove_reaction(payload.emoji, payload.member) + elif str(payload.emoji) == "\U0001F5D1": + if payload.member.id != p.author: + await message.remove_reaction(payload.emoji, payload.member) + else: + await self.close_poll(s, p) + elif not p.multi: + f = False + for r in message.reactions: + if str(r.emoji) != str(payload.emoji): + async for u in r.users(): + if u == payload.member: + await r.remove(payload.member) + f = True + break + if f: break - if f: - break + s.close() - async def close_poll(self, id: int): + async def close_poll(self, session: db.Session, poll: db.Polls): time = datetime.now() - message = await self.polls[id]["message"].channel.fetch_message(id) + message = await self.bot.get_channel(poll.channel).fetch_message(poll.message) reactions = message.reactions await message.clear_reactions() embed = message.embeds[0] @@ -90,7 +98,8 @@ class Poll(commands.Cog): embed.set_field_at(i, name=f"{f.name} - {reactions[i].count-1}", value=f.value, inline=False) embed.set_footer(text=embed.footer.text + "\n" + f"Close: {time.strftime('%d/%m/%Y %H:%M')}") await message.edit(embed=embed) - del self.polls[id] + session.delete(poll) + session.commit() def setup(bot):