1
0
Fork 0

Updating check for slash commands

This commit is contained in:
Ethanell 2021-02-03 14:46:22 +01:00
parent 8c1b3af1eb
commit 1b5d8410ae
4 changed files with 84 additions and 16 deletions

View file

@ -1,4 +1,8 @@
import functools
from discord import Permissions
from discord.ext import commands from discord.ext import commands
from discord.ext.commands import NoPrivateMessage, NotOwner, MissingPermissions
import db import db
@ -8,12 +12,61 @@ class ExtensionDisabled(commands.CheckFailure):
def is_enabled(): def is_enabled():
async def check(ctx: commands.Context): def check(func):
if ctx.command.cog and ctx.guild: @functools.wraps(func)
s = db.Session() async def wrapped(*args):
es = s.query(db.ExtensionState).get((ctx.command.cog.qualified_name, ctx.guild.id)) ctx = args[1]
s.close() if ctx.guild:
if es and not es.state: s = db.Session()
raise ExtensionDisabled() es = s.query(db.ExtensionState).get((args[0].qualified_name, ctx.guild.id))
return True s.close()
return commands.check(check) if es and not es.state:
raise ExtensionDisabled()
return await func(*args)
return wrapped
return check
def is_owner():
def check(func):
@functools.wraps(func)
async def wrapped(*args):
ctx = args[1]
if not await ctx._discord.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
return await func(*args)
return wrapped
return check
def guild_only():
def check(func):
@functools.wraps(func)
async def wrapped(*args):
if args[1].guild is None:
raise NoPrivateMessage()
return await func(*args)
return wrapped
return check
def has_permissions(**perms):
invalid = set(perms) - set(Permissions.VALID_FLAGS)
if invalid:
raise TypeError('Invalid permission(s): %s' % (', '.join(invalid)))
def check(func):
@functools.wraps(func)
async def wrapped(*args):
ctx = args[1]
ch = ctx.channel
permissions = ch.permissions_for(ctx.author)
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
if not missing:
return await func(*args)
raise MissingPermissions(missing)
return wrapped
return check

View file

@ -8,6 +8,7 @@ from discord_slash.utils import manage_commands
import db import db
from administrator import slash from administrator import slash
from administrator.check import has_permissions
from administrator.logger import logger from administrator.logger import logger
@ -24,7 +25,7 @@ class Extension(commands.Cog):
return "Manage bot's extensions" return "Manage bot's extensions"
@cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions") @cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions")
@commands.has_guild_permissions(administrator=True) @has_permissions(administrator=True)
async def extension_list(self, ctx: SlashContext): async def extension_list(self, ctx: SlashContext):
s = db.Session() s = db.Session()
embed = Embed(title="Extensions list") embed = Embed(title="Extensions list")
@ -38,7 +39,7 @@ class Extension(commands.Cog):
description="Enable an extensions", description="Enable an extensions",
options=[manage_commands.create_option("extension", "The extension to enable", options=[manage_commands.create_option("extension", "The extension to enable",
SlashCommandOptionType.STRING, True)]) SlashCommandOptionType.STRING, True)])
@commands.has_guild_permissions(administrator=True) @has_permissions(administrator=True)
async def extension_enable(self, ctx: SlashContext, name: str): async def extension_enable(self, ctx: SlashContext, name: str):
s = db.Session() s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id)) es = s.query(db.ExtensionState).get((name, ctx.guild.id))
@ -59,7 +60,7 @@ class Extension(commands.Cog):
description="Disable an extensions", description="Disable an extensions",
options=[manage_commands.create_option("extension", "The extension to disable", options=[manage_commands.create_option("extension", "The extension to disable",
SlashCommandOptionType.STRING, True)]) SlashCommandOptionType.STRING, True)])
@commands.has_guild_permissions(administrator=True) @has_permissions(administrator=True)
async def extension_disable(self, ctx: SlashContext, name: str): async def extension_disable(self, ctx: SlashContext, name: str):
s = db.Session() s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id)) es = s.query(db.ExtensionState).get((name, ctx.guild.id))

View file

@ -3,6 +3,7 @@ from discord import Member, Embed, Forbidden
from discord_slash import cog_ext, SlashContext, SlashCommandOptionType from discord_slash import cog_ext, SlashContext, SlashCommandOptionType
from discord_slash.utils import manage_commands from discord_slash.utils import manage_commands
from administrator.check import is_enabled, guild_only, has_permissions
from administrator.logger import logger from administrator.logger import logger
from administrator import db, slash from administrator import db, slash
from administrator.utils import event_is_enabled from administrator.utils import event_is_enabled
@ -20,8 +21,11 @@ class Greetings(commands.Cog):
def description(self): def description(self):
return "Setup join and leave message" return "Setup join and leave message"
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="help", @cog_ext.cog_subcommand(base="greetings", name="help",
description="Help about greetings") description="Help about greetings")
@is_enabled()
@guild_only()
@has_permissions(manage_guild=True)
async def greetings_help(self, ctx: SlashContext): async def greetings_help(self, ctx: SlashContext):
embed = Embed(title="Greetings help") embed = Embed(title="Greetings help")
embed.add_field(name="set <join/leave> <message>", value="Set the greetings message\n" embed.add_field(name="set <join/leave> <message>", value="Set the greetings message\n"
@ -31,7 +35,7 @@ class Greetings(commands.Cog):
embed.add_field(name="toggle <join/leave>", value="Enable or disable the greetings message", inline=False) embed.add_field(name="toggle <join/leave>", value="Enable or disable the greetings message", inline=False)
await ctx.send(embeds=[embed]) await ctx.send(embeds=[embed])
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="set", @cog_ext.cog_subcommand(base="greetings", name="set",
description="Set the greetings message\n`{}` will be replace by the username", description="Set the greetings message\n`{}` will be replace by the username",
options=[ options=[
manage_commands.create_option("type", "The join or leave message", manage_commands.create_option("type", "The join or leave message",
@ -41,6 +45,9 @@ class Greetings(commands.Cog):
manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING, manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING,
True) True)
]) ])
@is_enabled()
@guild_only()
@has_permissions(manage_guild=True)
async def greetings_set(self, ctx: SlashContext, message_type: str, message: str): async def greetings_set(self, ctx: SlashContext, message_type: str, message: str):
s = db.Session() s = db.Session()
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
@ -52,12 +59,15 @@ class Greetings(commands.Cog):
s.commit() s.commit()
await ctx.send(content="\U0001f44d") await ctx.send(content="\U0001f44d")
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="show", @cog_ext.cog_subcommand(base="greetings", name="show",
description="Show the greetings message", description="Show the greetings message",
options=[manage_commands.create_option("type", "The join or leave message", options=[manage_commands.create_option("type", "The join or leave message",
SlashCommandOptionType.STRING, True, SlashCommandOptionType.STRING, True,
[manage_commands.create_choice("join", "join"), [manage_commands.create_choice("join", "join"),
manage_commands.create_choice("leave", "leave")])]) manage_commands.create_choice("leave", "leave")])])
@is_enabled()
@guild_only()
@has_permissions(manage_guild=True)
async def greetings_show(self, ctx: SlashContext, message_type: str): async def greetings_show(self, ctx: SlashContext, message_type: str):
s = db.Session() s = db.Session()
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
@ -70,12 +80,15 @@ class Greetings(commands.Cog):
else: else:
await ctx.send(content=m.leave_msg(str(ctx.author))) await ctx.send(content=m.leave_msg(str(ctx.author)))
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="toggle", @cog_ext.cog_subcommand(base="greetings", name="toggle",
description="Enable or disable the greetings message", description="Enable or disable the greetings message",
options=[manage_commands.create_option("type", "The join or leave message", options=[manage_commands.create_option("type", "The join or leave message",
SlashCommandOptionType.STRING, True, SlashCommandOptionType.STRING, True,
[manage_commands.create_choice("join", "join"), [manage_commands.create_choice("join", "join"),
manage_commands.create_choice("leave", "leave")])]) manage_commands.create_choice("leave", "leave")])])
@is_enabled()
@guild_only()
@has_permissions(manage_guild=True)
async def greetings_toggle(self, ctx: SlashContext, message_type: str): async def greetings_toggle(self, ctx: SlashContext, message_type: str):
s = db.Session() s = db.Session()
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()

View file

@ -3,6 +3,7 @@ async-timeout==3.0.1
attrs==20.2.0 attrs==20.2.0
chardet==3.0.4 chardet==3.0.4
discord==1.0.1 discord==1.0.1
discord-py-slash-command==1.0.8.5
discord.py==1.5.1 discord.py==1.5.1
feedparser==6.0.2 feedparser==6.0.2
idna==2.10 idna==2.10