diff --git a/administrator/__init__.py b/administrator/__init__.py index 00a4ae1..5ef856f 100644 --- a/administrator/__init__.py +++ b/administrator/__init__.py @@ -6,7 +6,7 @@ import db from discord.ext import commands bot = commands.Bot(command_prefix=config.get("prefix"), intents=Intents.all()) -slash = SlashCommand(bot, auto_register=True) +slash = SlashCommand(bot, auto_register=True, auto_delete=True) import extensions diff --git a/extensions/extension.py b/extensions/extension.py index 653cf90..93a5cad 100644 --- a/extensions/extension.py +++ b/extensions/extension.py @@ -2,9 +2,12 @@ from traceback import format_exc from discord.ext import commands from discord import Embed, Guild -from discord.ext.commands import MissingPermissions, BadArgument, CommandError +from discord.ext.commands import BadArgument +from discord_slash import cog_ext, SlashContext, SlashCommandOptionType +from discord_slash.utils import manage_commands import db +from administrator import slash from administrator.logger import logger @@ -15,69 +18,66 @@ logger = logger.getChild(extension_name) class Extension(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot + slash.get_cog_commands(self) def description(self): return "Manage bot's extensions" - @commands.group("extension", pass_context=True) - async def extension(self, ctx: commands.Context): - if ctx.invoked_subcommand is None: - await ctx.invoke(self.extension_help) - - @extension.group("help", pass_context=True) - async def extension_help(self, ctx: commands.Context): - embed = Embed(title="Extension help") - for c, n, v in [[self.extension_list, "extension list", "List all enabled extensions"], - [self.extension_enable, "extension enable", "Enable an extensions"], - [self.extension_disable, "extension disable", "Disable an extensions"], - [self.extension_loaded, "extension loaded", "List all loaded extensions"], - [self.extension_load, "extension load ", "Load an extension"], - [self.extension_unload, "extension unload ", "Unload an extension"], - [self.extension_reload, "extension reload ", "Reload an extension"]]: - try: - if await c.can_run(ctx): - embed.add_field(name=n, value=v, inline=False) - except CommandError: - pass - - if not embed.fields: - raise MissingPermissions("") - await ctx.send(embed=embed) - - @extension.group("list", pass_context=True) + @cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions") @commands.has_guild_permissions(administrator=True) - async def extension_list(self, ctx: commands.Context): + async def extension_list(self, ctx: SlashContext): s = db.Session() embed = Embed(title="Extensions list") for es in s.query(db.ExtensionState).filter(db.ExtensionState.guild_id == ctx.guild.id): embed.add_field(name=es.extension_name, value="Enable" if es.state else "Disable") - await ctx.send(embed=embed) + s.close() + await ctx.send(embeds=[embed]) - @extension.group("enable", pass_context=True) + @cog_ext.cog_subcommand(base="extension", + name="enable", + description="Enable an extensions", + options=[manage_commands.create_option("extension", "The extension to enable", + SlashCommandOptionType.STRING, True)]) @commands.has_guild_permissions(administrator=True) - async def extension_enable(self, ctx: commands.Context, name: str): + async def extension_enable(self, ctx: SlashContext, name: str): s = db.Session() es = s.query(db.ExtensionState).get((name, ctx.guild.id)) - if not es or es.state: + if not es: raise BadArgument() - es.state = True - s.add(es) - s.commit() - s.close() - await ctx.message.add_reaction("\U0001f44d") + elif es.state: + message = "Extension already enabled" + else: + es.state = True + s.add(es) + s.commit() + s.close() + message = "\U0001f44d" + await ctx.send(content=message) - @extension.group("disable", pass_context=True) + @cog_ext.cog_subcommand(base="extension", + name="disable", + description="Disable an extensions", + options=[manage_commands.create_option("extension", "The extension to disable", + SlashCommandOptionType.STRING, True)]) @commands.has_guild_permissions(administrator=True) - async def extension_disable(self, ctx: commands.Context, name: str): + async def extension_disable(self, ctx: SlashContext, name: str): s = db.Session() es = s.query(db.ExtensionState).get((name, ctx.guild.id)) - if not es or not es.state: + if not es: raise BadArgument() - es.state = False - s.add(es) - s.commit() - s.close() - await ctx.message.add_reaction("\U0001f44d") + elif not es.state: + message = "Extension already disabled" + else: + es.state = False + s.add(es) + s.commit() + s.close() + message = "\U0001f44d" + await ctx.send(content=message) + + @commands.group("extension", pass_context=True) + async def extension(self, ctx: commands.Context): + pass @extension.group("loaded", pass_context=True) @commands.is_owner() @@ -152,6 +152,9 @@ class Extension(commands.Cog): s.commit() s.close() + def cog_unload(self): + slash.remove_cog_commands(self) + def setup(bot): logger.info(f"Loading...")