From 3e3e0369fbd455958d0fb3dc7a23f75092b35c41 Mon Sep 17 00:00:00 2001 From: flifloo Date: Thu, 28 May 2020 15:55:37 +0200 Subject: [PATCH] Add permission check on calendar commands --- extensions/calendar.py | 50 +++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/extensions/calendar.py b/extensions/calendar.py index 8281b96..410cb5e 100644 --- a/extensions/calendar.py +++ b/extensions/calendar.py @@ -7,11 +7,12 @@ import requests from discord import Embed, DMChannel, TextChannel from discord.ext import commands from discord.ext import tasks -from discord.ext.commands import CommandNotFound, BadArgument, MissingRequiredArgument +from discord.ext.commands import CommandNotFound, BadArgument, MissingRequiredArgument, MissingPermissions from bot_bde import db from bot_bde.logger import logger + extension_name = "calendar" logger = logger.getChild(extension_name) url_re = re.compile(r"http:\/\/adelb\.univ-lyon1\.fr\/jsp\/custom\/modules\/plannings\/anonymous_cal\.jsp\?resources=" @@ -27,24 +28,18 @@ def query_calendar(name: str, guild: int) -> db.Calendar: return c -async def get_one_mention(ctx: commands.Context): - if ctx.message.channel_mentions and ctx.message.mentions: - raise BadArgument() - elif ctx.message.channel_mentions: - if len(ctx.message.channel_mentions) > 1: +async def get_one_text_channel(ctx: commands.Context): + if ctx.message.channel_mentions: + if not ctx.channel.permissions_for(ctx.author).manage_channels: + raise MissingPermissions(["manage_channels"]) + elif len(ctx.message.channel_mentions) > 1: raise BadArgument() else: m = ctx.message.channel_mentions[0].id - elif ctx.message.mentions: - if len(ctx.message.mentions) > 1: - raise BadArgument() - else: - m = ctx.message.mentions[0] - if not m.dm_channel: - await m.create_dm() - m = m.dm_channel.id else: - m = ctx.channel.id + if not ctx.author.dm_channel: + await ctx.author.create_dm() + m = ctx.author.dm_channel.id return m @@ -71,6 +66,7 @@ class Calendar(commands.Cog): @calendar.group("define", pass_context=True) @commands.guild_only() + @commands.has_permissions(manage_channels=True) async def calendar_define(self, ctx: commands.Context, name: str, url: str): try: ics.Calendar(requests.get(url).text) @@ -101,6 +97,7 @@ class Calendar(commands.Cog): @calendar.group("remove", pass_context=True) @commands.guild_only() + @commands.has_permissions(manage_channels=True) async def calendar_remove(self, ctx: commands.Context, name: str = None): if name is None: await ctx.invoke(self.calendar_list) @@ -164,10 +161,10 @@ class Calendar(commands.Cog): @calendar_notify.group("help", pass_context=True) async def calendar_notify_help(self, ctx: commands.Context): embed = Embed(title="Calendar notify help") - embed.add_field(name="calendar notify add [#channel|@user]", - value="Notify the current channel or the giver channel/user of calendar events", inline=False) - embed.add_field(name="calendar notify remove [#channel|@user]", - value="Remove the calendar notify of the current channel or the given channel/user", + embed.add_field(name="calendar notify add [#channel]", + value="Notify you or the giver channel of calendar events", inline=False) + embed.add_field(name="calendar notify remove [#channel]", + value="Remove the calendar notify of the current user or the given channel", inline=False) embed.add_field(name="calendar notify list [name]", value="List all notify of all calendar or the given one", inline=False) @@ -176,7 +173,7 @@ class Calendar(commands.Cog): @calendar_notify.group("add", pass_context=True) @commands.guild_only() async def calendar_notify_set(self, ctx: commands.Context, name: str): - m = await get_one_mention(ctx) + m = await get_one_text_channel(ctx) s = db.Session() c = query_calendar(name, ctx.guild.id) n = s.query(db.CalendarNotify).filter(db.CalendarNotify.channel == m) \ @@ -192,8 +189,9 @@ class Calendar(commands.Cog): await ctx.message.add_reaction("\U0001f44d") @calendar_notify.group("remove", pass_context=True) + @commands.guild_only() async def calendar_notify_remove(self, ctx: commands.Context, name: str): - m = await get_one_mention(ctx) + m = await get_one_text_channel(ctx) s = db.Session() c = query_calendar(name, ctx.guild.id) n = s.query(db.CalendarNotify).filter(db.CalendarNotify.channel == m) \ @@ -228,13 +226,6 @@ class Calendar(commands.Cog): embed.add_field(name=c.name, value="\n".join(notify) or "Nothing here", inline=False) await ctx.send(embed=embed) - @calendar_notify.group("trigger", pass_context=True) - @commands.guild_only() - async def calendar_notify_trigger(self, ctx: commands.Context, name: str): - c = query_calendar(name, ctx.guild.id) - now = datetime.now() - await c.notify(self.bot, c.events(now, now)[0]) - @tasks.loop(minutes=1) async def calendar_notify_loop(self): s = db.Session() @@ -253,6 +244,9 @@ class Calendar(commands.Cog): or isinstance(error, MissingRequiredArgument): await ctx.message.add_reaction("\u2753") await ctx.message.delete(delay=30) + elif isinstance(error, MissingPermissions): + await ctx.message.add_reaction("\u274c") + await ctx.message.delete(delay=30) else: await ctx.send("An error occurred !") raise error