From 1b5d8410aeb28edf9b778a815b7c8a51eb032b09 Mon Sep 17 00:00:00 2001 From: flifloo Date: Wed, 3 Feb 2021 14:46:22 +0100 Subject: [PATCH] Updating check for slash commands --- administrator/check.py | 71 +++++++++++++++++++++++++++++++++++------ extensions/extension.py | 7 ++-- extensions/greetings.py | 21 +++++++++--- requirements.txt | 1 + 4 files changed, 84 insertions(+), 16 deletions(-) diff --git a/administrator/check.py b/administrator/check.py index 3d1cb6a..0e8d902 100644 --- a/administrator/check.py +++ b/administrator/check.py @@ -1,4 +1,8 @@ +import functools + +from discord import Permissions from discord.ext import commands +from discord.ext.commands import NoPrivateMessage, NotOwner, MissingPermissions import db @@ -8,12 +12,61 @@ class ExtensionDisabled(commands.CheckFailure): def is_enabled(): - async def check(ctx: commands.Context): - if ctx.command.cog and ctx.guild: - s = db.Session() - es = s.query(db.ExtensionState).get((ctx.command.cog.qualified_name, ctx.guild.id)) - s.close() - if es and not es.state: - raise ExtensionDisabled() - return True - return commands.check(check) + def check(func): + @functools.wraps(func) + async def wrapped(*args): + ctx = args[1] + if ctx.guild: + s = db.Session() + es = s.query(db.ExtensionState).get((args[0].qualified_name, ctx.guild.id)) + s.close() + if es and not es.state: + raise ExtensionDisabled() + return await func(*args) + return wrapped + return check + + +def is_owner(): + def check(func): + @functools.wraps(func) + async def wrapped(*args): + ctx = args[1] + if not await ctx._discord.is_owner(ctx.author): + raise NotOwner('You do not own this bot.') + return await func(*args) + return wrapped + return check + + +def guild_only(): + def check(func): + @functools.wraps(func) + async def wrapped(*args): + if args[1].guild is None: + raise NoPrivateMessage() + return await func(*args) + return wrapped + return check + + +def has_permissions(**perms): + invalid = set(perms) - set(Permissions.VALID_FLAGS) + if invalid: + raise TypeError('Invalid permission(s): %s' % (', '.join(invalid))) + + def check(func): + @functools.wraps(func) + async def wrapped(*args): + ctx = args[1] + ch = ctx.channel + permissions = ch.permissions_for(ctx.author) + + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return await func(*args) + + raise MissingPermissions(missing) + return wrapped + return check diff --git a/extensions/extension.py b/extensions/extension.py index 93a5cad..f6635cb 100644 --- a/extensions/extension.py +++ b/extensions/extension.py @@ -8,6 +8,7 @@ from discord_slash.utils import manage_commands import db from administrator import slash +from administrator.check import has_permissions from administrator.logger import logger @@ -24,7 +25,7 @@ class Extension(commands.Cog): return "Manage bot's extensions" @cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions") - @commands.has_guild_permissions(administrator=True) + @has_permissions(administrator=True) async def extension_list(self, ctx: SlashContext): s = db.Session() embed = Embed(title="Extensions list") @@ -38,7 +39,7 @@ class Extension(commands.Cog): description="Enable an extensions", options=[manage_commands.create_option("extension", "The extension to enable", SlashCommandOptionType.STRING, True)]) - @commands.has_guild_permissions(administrator=True) + @has_permissions(administrator=True) async def extension_enable(self, ctx: SlashContext, name: str): s = db.Session() es = s.query(db.ExtensionState).get((name, ctx.guild.id)) @@ -59,7 +60,7 @@ class Extension(commands.Cog): description="Disable an extensions", options=[manage_commands.create_option("extension", "The extension to disable", SlashCommandOptionType.STRING, True)]) - @commands.has_guild_permissions(administrator=True) + @has_permissions(administrator=True) async def extension_disable(self, ctx: SlashContext, name: str): s = db.Session() es = s.query(db.ExtensionState).get((name, ctx.guild.id)) diff --git a/extensions/greetings.py b/extensions/greetings.py index 20c554e..ad59907 100644 --- a/extensions/greetings.py +++ b/extensions/greetings.py @@ -3,6 +3,7 @@ from discord import Member, Embed, Forbidden from discord_slash import cog_ext, SlashContext, SlashCommandOptionType from discord_slash.utils import manage_commands +from administrator.check import is_enabled, guild_only, has_permissions from administrator.logger import logger from administrator import db, slash from administrator.utils import event_is_enabled @@ -20,8 +21,11 @@ class Greetings(commands.Cog): def description(self): return "Setup join and leave message" - @cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="help", + @cog_ext.cog_subcommand(base="greetings", name="help", description="Help about greetings") + @is_enabled() + @guild_only() + @has_permissions(manage_guild=True) async def greetings_help(self, ctx: SlashContext): embed = Embed(title="Greetings help") embed.add_field(name="set ", value="Set the greetings message\n" @@ -31,7 +35,7 @@ class Greetings(commands.Cog): embed.add_field(name="toggle ", value="Enable or disable the greetings message", inline=False) await ctx.send(embeds=[embed]) - @cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="set", + @cog_ext.cog_subcommand(base="greetings", name="set", description="Set the greetings message\n`{}` will be replace by the username", options=[ manage_commands.create_option("type", "The join or leave message", @@ -41,6 +45,9 @@ class Greetings(commands.Cog): manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING, True) ]) + @is_enabled() + @guild_only() + @has_permissions(manage_guild=True) async def greetings_set(self, ctx: SlashContext, message_type: str, message: str): s = db.Session() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() @@ -52,12 +59,15 @@ class Greetings(commands.Cog): s.commit() await ctx.send(content="\U0001f44d") - @cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="show", + @cog_ext.cog_subcommand(base="greetings", name="show", description="Show the greetings message", options=[manage_commands.create_option("type", "The join or leave message", SlashCommandOptionType.STRING, True, [manage_commands.create_choice("join", "join"), manage_commands.create_choice("leave", "leave")])]) + @is_enabled() + @guild_only() + @has_permissions(manage_guild=True) async def greetings_show(self, ctx: SlashContext, message_type: str): s = db.Session() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() @@ -70,12 +80,15 @@ class Greetings(commands.Cog): else: await ctx.send(content=m.leave_msg(str(ctx.author))) - @cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="toggle", + @cog_ext.cog_subcommand(base="greetings", name="toggle", description="Enable or disable the greetings message", options=[manage_commands.create_option("type", "The join or leave message", SlashCommandOptionType.STRING, True, [manage_commands.create_choice("join", "join"), manage_commands.create_choice("leave", "leave")])]) + @is_enabled() + @guild_only() + @has_permissions(manage_guild=True) async def greetings_toggle(self, ctx: SlashContext, message_type: str): s = db.Session() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() diff --git a/requirements.txt b/requirements.txt index 56d1dbc..8f95633 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ async-timeout==3.0.1 attrs==20.2.0 chardet==3.0.4 discord==1.0.1 +discord-py-slash-command==1.0.8.5 discord.py==1.5.1 feedparser==6.0.2 idna==2.10