Connection of poll extension to database to avoid lost polls
This commit is contained in:
parent
0dc456bfa8
commit
585f3e59d5
3 changed files with 56 additions and 27 deletions
19
db/Polls.py
Normal file
19
db/Polls.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from db import Base
|
||||
from sqlalchemy import Column, Integer, BigInteger, String, Boolean
|
||||
|
||||
|
||||
class Polls(Base):
|
||||
__tablename__ = "polls"
|
||||
id = Column(Integer, primary_key=True)
|
||||
message = Column(BigInteger, nullable=False, unique=True)
|
||||
channel = Column(BigInteger, nullable=False)
|
||||
author = Column(BigInteger, nullable=False)
|
||||
reactions = Column(String, nullable=False)
|
||||
multi = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
def __init__(self, message: int, channel: int, author: int, reactions: [str], multi: bool = False):
|
||||
self.message = message
|
||||
self.channel = channel
|
||||
self.author = author
|
||||
self.reactions = str(reactions)
|
||||
self.multi = multi
|
|
@ -9,4 +9,5 @@ from db.Task import Task
|
|||
from db.Greetings import Greetings
|
||||
from db.Presentation import Presentation
|
||||
from db.RoRec import RoRec
|
||||
from db.Polls import Polls
|
||||
Base.metadata.create_all(engine)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from datetime import datetime
|
||||
|
||||
from discord.ext import commands
|
||||
from discord import Member, Embed, Reaction
|
||||
from discord import Embed, RawReactionActionEvent
|
||||
from discord.ext.commands import BadArgument
|
||||
|
||||
import db
|
||||
from administrator.logger import logger
|
||||
|
||||
|
||||
|
@ -17,7 +19,6 @@ REACTIONS.append("\U0001F51F")
|
|||
class Poll(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.polls = {}
|
||||
|
||||
def description(self):
|
||||
return "Create poll with a simple command"
|
||||
|
@ -33,7 +34,7 @@ class Poll(commands.Cog):
|
|||
multi = True
|
||||
choices = choices[1:]
|
||||
if len(choices) == 0 or len(choices) > 11:
|
||||
await ctx.message.add_reaction("\u274C")
|
||||
raise BadArgument()
|
||||
else:
|
||||
embed = Embed(title=f"Poll: {name}")
|
||||
embed.set_author(name=str(ctx.author), icon_url=ctx.author.avatar_url)
|
||||
|
@ -44,8 +45,10 @@ class Poll(commands.Cog):
|
|||
reactions = REACTIONS[0:len(choices)] + ["\U0001F5D1"]
|
||||
for reaction in reactions:
|
||||
await message.add_reaction(reaction)
|
||||
message = await message.channel.fetch_message(message.id)
|
||||
self.polls[message.id] = {"multi": multi, "message": message, "author": ctx.message.author.id}
|
||||
s = db.Session()
|
||||
s.add(db.Polls(message.id, ctx.channel.id, ctx.message.author.id, reactions, multi))
|
||||
s.commit()
|
||||
s.close()
|
||||
await ctx.message.delete()
|
||||
|
||||
@poll.group("help", pass_context=True)
|
||||
|
@ -59,30 +62,35 @@ class Poll(commands.Cog):
|
|||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_reaction_add(self, reaction: Reaction, user: Member):
|
||||
if not user.bot and reaction.message.id in self.polls:
|
||||
if reaction not in self.polls[reaction.message.id]["message"].reactions:
|
||||
await reaction.remove(user)
|
||||
elif str(reaction.emoji) == "\U0001F5D1":
|
||||
if user.id != self.polls[reaction.message.id]["author"]:
|
||||
await reaction.remove(user)
|
||||
else:
|
||||
await self.close_poll(reaction.message.id)
|
||||
elif not self.polls[reaction.message.id]["multi"]:
|
||||
f = False
|
||||
for r in reaction.message.reactions:
|
||||
if str(r.emoji) != str(reaction.emoji):
|
||||
async for u in r.users():
|
||||
if u == user:
|
||||
await r.remove(user)
|
||||
f = True
|
||||
async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
|
||||
if not payload.member.bot:
|
||||
s = db.Session()
|
||||
p = s.query(db.Polls).filter(db.Polls.message == payload.message_id).first()
|
||||
if p:
|
||||
message = await self.bot.get_channel(p.channel).fetch_message(p.message)
|
||||
if str(payload.emoji) not in eval(p.reactions):
|
||||
await message.remove_reaction(payload.emoji, payload.member)
|
||||
elif str(payload.emoji) == "\U0001F5D1":
|
||||
if payload.member.id != p.author:
|
||||
await message.remove_reaction(payload.emoji, payload.member)
|
||||
else:
|
||||
await self.close_poll(s, p)
|
||||
elif not p.multi:
|
||||
f = False
|
||||
for r in message.reactions:
|
||||
if str(r.emoji) != str(payload.emoji):
|
||||
async for u in r.users():
|
||||
if u == payload.member:
|
||||
await r.remove(payload.member)
|
||||
f = True
|
||||
break
|
||||
if f:
|
||||
break
|
||||
if f:
|
||||
break
|
||||
s.close()
|
||||
|
||||
async def close_poll(self, id: int):
|
||||
async def close_poll(self, session: db.Session, poll: db.Polls):
|
||||
time = datetime.now()
|
||||
message = await self.polls[id]["message"].channel.fetch_message(id)
|
||||
message = await self.bot.get_channel(poll.channel).fetch_message(poll.message)
|
||||
reactions = message.reactions
|
||||
await message.clear_reactions()
|
||||
embed = message.embeds[0]
|
||||
|
@ -90,7 +98,8 @@ class Poll(commands.Cog):
|
|||
embed.set_field_at(i, name=f"{f.name} - {reactions[i].count-1}", value=f.value, inline=False)
|
||||
embed.set_footer(text=embed.footer.text + "\n" + f"Close: {time.strftime('%d/%m/%Y %H:%M')}")
|
||||
await message.edit(embed=embed)
|
||||
del self.polls[id]
|
||||
session.delete(poll)
|
||||
session.commit()
|
||||
|
||||
|
||||
def setup(bot):
|
||||
|
|
Reference in a new issue