diff --git a/db/Task.py b/db/Task.py new file mode 100644 index 0000000..dae228c --- /dev/null +++ b/db/Task.py @@ -0,0 +1,19 @@ +from db import Base +from sqlalchemy import Column, Integer, String, BigInteger, Date +from datetime import datetime + + +class Task(Base): + __tablename__ = "tasks" + id = Column(Integer, primary_key=True) + message = Column(String, nullable=False) + user = Column(BigInteger, nullable=False) + channel = Column(BigInteger, nullable=False) + date = Column(Date, nullable=False) + creation_date = Column(Date, default=datetime.now()) + + def __init__(self, message: str, user: int, channel: int, date: datetime): + self.message = message + self.user = user + self.channel = channel + self.date = date diff --git a/db/__init__.py b/db/__init__.py index d414ccb..ca98567 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -5,5 +5,5 @@ from sqlalchemy.ext.declarative import declarative_base engine = create_engine(config.get("db")) Session = sessionmaker(bind=engine) Base = declarative_base() -#from db.foo import Barr +from db.Task import Task Base.metadata.create_all(engine) diff --git a/extensions/reminders.py b/extensions/reminders.py index 11fa4e6..f35c514 100644 --- a/extensions/reminders.py +++ b/extensions/reminders.py @@ -7,6 +7,7 @@ from discord.ext.commands import CommandNotFound, BadArgument, MissingRequiredAr from discord.ext import tasks from bot_bde.logger import logger +from bot_bde import db extension_name = "reminders" @@ -14,7 +15,8 @@ logger = logger.getChild(extension_name) def time_pars(s: str) -> timedelta: - match = re.fullmatch(r"(?:([0-9]+)W)*(?:([0-9]+)D)*(?:([0-9]+)H)*(?:([0-9]+)M)*(?:([0-9]+)S)*", s.upper().replace(" ", "").strip()) + match = re.fullmatch(r"(?:([0-9]+)W)*(?:([0-9]+)D)*(?:([0-9]+)H)*(?:([0-9]+)M)*(?:([0-9]+)S)*", + s.upper().replace(" ", "").strip()) if match: w, d, h, m, s = match.groups() if any([w, d, h, m, s]): @@ -26,7 +28,6 @@ def time_pars(s: str) -> timedelta: class Reminders(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot - self.tasks = [] @commands.group("reminder", pass_context=True) async def reminder(self, ctx: commands.Context): @@ -47,13 +48,10 @@ class Reminders(commands.Cog): async def reminder_add(self, ctx: commands.Context, message: str, time: str): time = time_pars(time) now = datetime.now() - self.tasks.append({ - "date": now + time, - "create": now, - "user": ctx.author.id, - "message": message, - "channel": ctx.channel.id, - }) + s = db.Session() + s.add(db.Task(message, ctx.author.id, ctx.channel.id, now + time)) + s.commit() + s.close() hours, seconds = divmod(time.seconds, 3600) minutes, seconds = divmod(seconds, 60) @@ -65,37 +63,44 @@ class Reminders(commands.Cog): @reminder.group("list", pass_context=True) async def reminder_list(self, ctx: commands.Context): embed = Embed(title="Tasks list") - for i, t in enumerate(self.tasks): - if t["user"] == ctx.author.id: - embed.add_field(name=t["date"], value=f"N°{i} | {t['message']}", inline=False) + s = db.Session() + for t in s.query(db.Task).filter(db.Task.user == ctx.author.id).all(): + embed.add_field(name=f"N°{t.id} | {t.date}", value=f"{t.message}", inline=False) + s.close() await ctx.send(embed=embed) @reminder.group("remove", pass_context=True) async def reminder_remove(self, ctx: commands.Context, n: int = None): - tasks =list(filter(lambda t: t["user"] == ctx.author.id, self.tasks)) if n is None: await ctx.invoke(self.reminder_list) - elif n >= len(tasks): - raise BadArgument() else: - del self.tasks[n] - await ctx.message.add_reaction("\U0001f44d") + s = db.Session() + t = s.query(db.Task).filter(db.Task.id == n).first() + if t and t.user == ctx.author.id: + s.delete(t) + s.commit() + s.close() + await ctx.message.add_reaction("\U0001f44d") + else: + s.close() + raise BadArgument() @tasks.loop(minutes=1) async def reminders_loop(self): - trash = [] - for t in self.tasks: - if t["date"] <= datetime.now(): - self.bot.loop.create_task(self.reminder_exec(t)) - trash.append(t) + s = db.Session() + for t in s.query(db.Task).filter(db.Task.date <= datetime.now()).all(): + self.bot.loop.create_task(self.reminder_exec(t)) + s.delete(t) - for t in trash: - del self.tasks[self.tasks.index(t)] + s.commit() + s.close() - async def reminder_exec(self, task: dict): + async def reminder_exec(self, task: db.Task): embed = Embed(title="You have a reminder !") - embed.add_field(name=task["date"], value=task["message"]) - await self.bot.get_channel(task["channel"]).send(f"<@{task['user']}>", embed=embed) + user = self.bot.get_user(task.user) + embed.set_author(name=f"{user.name}#{user.discriminator}", icon_url=user.avatar_url) + embed.add_field(name=str(task.date), value=task.message) + await (await self.bot.get_channel(task.channel).send(f"{user.mention}", embed=embed)).edit(content="") @commands.Cog.listener() async def on_command_error(self, ctx: commands.Context, error): @@ -126,6 +131,7 @@ def setup(bot): def teardown(bot): logger.info(f"Unloading...") try: + bot.get_cog("Reminders").reminders_loop.stop() bot.remove_cog("Reminders") except Exception as e: logger.error(f"Error unloading: {e}")