Updating check for slash commands
This commit is contained in:
parent
8c1b3af1eb
commit
1b5d8410ae
4 changed files with 84 additions and 16 deletions
|
@ -1,4 +1,8 @@
|
|||
import functools
|
||||
|
||||
from discord import Permissions
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import NoPrivateMessage, NotOwner, MissingPermissions
|
||||
|
||||
import db
|
||||
|
||||
|
@ -8,12 +12,61 @@ class ExtensionDisabled(commands.CheckFailure):
|
|||
|
||||
|
||||
def is_enabled():
|
||||
async def check(ctx: commands.Context):
|
||||
if ctx.command.cog and ctx.guild:
|
||||
def check(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args):
|
||||
ctx = args[1]
|
||||
if ctx.guild:
|
||||
s = db.Session()
|
||||
es = s.query(db.ExtensionState).get((ctx.command.cog.qualified_name, ctx.guild.id))
|
||||
es = s.query(db.ExtensionState).get((args[0].qualified_name, ctx.guild.id))
|
||||
s.close()
|
||||
if es and not es.state:
|
||||
raise ExtensionDisabled()
|
||||
return True
|
||||
return commands.check(check)
|
||||
return await func(*args)
|
||||
return wrapped
|
||||
return check
|
||||
|
||||
|
||||
def is_owner():
|
||||
def check(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args):
|
||||
ctx = args[1]
|
||||
if not await ctx._discord.is_owner(ctx.author):
|
||||
raise NotOwner('You do not own this bot.')
|
||||
return await func(*args)
|
||||
return wrapped
|
||||
return check
|
||||
|
||||
|
||||
def guild_only():
|
||||
def check(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args):
|
||||
if args[1].guild is None:
|
||||
raise NoPrivateMessage()
|
||||
return await func(*args)
|
||||
return wrapped
|
||||
return check
|
||||
|
||||
|
||||
def has_permissions(**perms):
|
||||
invalid = set(perms) - set(Permissions.VALID_FLAGS)
|
||||
if invalid:
|
||||
raise TypeError('Invalid permission(s): %s' % (', '.join(invalid)))
|
||||
|
||||
def check(func):
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args):
|
||||
ctx = args[1]
|
||||
ch = ctx.channel
|
||||
permissions = ch.permissions_for(ctx.author)
|
||||
|
||||
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
|
||||
|
||||
if not missing:
|
||||
return await func(*args)
|
||||
|
||||
raise MissingPermissions(missing)
|
||||
return wrapped
|
||||
return check
|
||||
|
|
|
@ -8,6 +8,7 @@ from discord_slash.utils import manage_commands
|
|||
|
||||
import db
|
||||
from administrator import slash
|
||||
from administrator.check import has_permissions
|
||||
from administrator.logger import logger
|
||||
|
||||
|
||||
|
@ -24,7 +25,7 @@ class Extension(commands.Cog):
|
|||
return "Manage bot's extensions"
|
||||
|
||||
@cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions")
|
||||
@commands.has_guild_permissions(administrator=True)
|
||||
@has_permissions(administrator=True)
|
||||
async def extension_list(self, ctx: SlashContext):
|
||||
s = db.Session()
|
||||
embed = Embed(title="Extensions list")
|
||||
|
@ -38,7 +39,7 @@ class Extension(commands.Cog):
|
|||
description="Enable an extensions",
|
||||
options=[manage_commands.create_option("extension", "The extension to enable",
|
||||
SlashCommandOptionType.STRING, True)])
|
||||
@commands.has_guild_permissions(administrator=True)
|
||||
@has_permissions(administrator=True)
|
||||
async def extension_enable(self, ctx: SlashContext, name: str):
|
||||
s = db.Session()
|
||||
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
||||
|
@ -59,7 +60,7 @@ class Extension(commands.Cog):
|
|||
description="Disable an extensions",
|
||||
options=[manage_commands.create_option("extension", "The extension to disable",
|
||||
SlashCommandOptionType.STRING, True)])
|
||||
@commands.has_guild_permissions(administrator=True)
|
||||
@has_permissions(administrator=True)
|
||||
async def extension_disable(self, ctx: SlashContext, name: str):
|
||||
s = db.Session()
|
||||
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
||||
|
|
|
@ -3,6 +3,7 @@ from discord import Member, Embed, Forbidden
|
|||
from discord_slash import cog_ext, SlashContext, SlashCommandOptionType
|
||||
from discord_slash.utils import manage_commands
|
||||
|
||||
from administrator.check import is_enabled, guild_only, has_permissions
|
||||
from administrator.logger import logger
|
||||
from administrator import db, slash
|
||||
from administrator.utils import event_is_enabled
|
||||
|
@ -20,8 +21,11 @@ class Greetings(commands.Cog):
|
|||
def description(self):
|
||||
return "Setup join and leave message"
|
||||
|
||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="help",
|
||||
@cog_ext.cog_subcommand(base="greetings", name="help",
|
||||
description="Help about greetings")
|
||||
@is_enabled()
|
||||
@guild_only()
|
||||
@has_permissions(manage_guild=True)
|
||||
async def greetings_help(self, ctx: SlashContext):
|
||||
embed = Embed(title="Greetings help")
|
||||
embed.add_field(name="set <join/leave> <message>", value="Set the greetings message\n"
|
||||
|
@ -31,7 +35,7 @@ class Greetings(commands.Cog):
|
|||
embed.add_field(name="toggle <join/leave>", value="Enable or disable the greetings message", inline=False)
|
||||
await ctx.send(embeds=[embed])
|
||||
|
||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="set",
|
||||
@cog_ext.cog_subcommand(base="greetings", name="set",
|
||||
description="Set the greetings message\n`{}` will be replace by the username",
|
||||
options=[
|
||||
manage_commands.create_option("type", "The join or leave message",
|
||||
|
@ -41,6 +45,9 @@ class Greetings(commands.Cog):
|
|||
manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING,
|
||||
True)
|
||||
])
|
||||
@is_enabled()
|
||||
@guild_only()
|
||||
@has_permissions(manage_guild=True)
|
||||
async def greetings_set(self, ctx: SlashContext, message_type: str, message: str):
|
||||
s = db.Session()
|
||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||
|
@ -52,12 +59,15 @@ class Greetings(commands.Cog):
|
|||
s.commit()
|
||||
await ctx.send(content="\U0001f44d")
|
||||
|
||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="show",
|
||||
@cog_ext.cog_subcommand(base="greetings", name="show",
|
||||
description="Show the greetings message",
|
||||
options=[manage_commands.create_option("type", "The join or leave message",
|
||||
SlashCommandOptionType.STRING, True,
|
||||
[manage_commands.create_choice("join", "join"),
|
||||
manage_commands.create_choice("leave", "leave")])])
|
||||
@is_enabled()
|
||||
@guild_only()
|
||||
@has_permissions(manage_guild=True)
|
||||
async def greetings_show(self, ctx: SlashContext, message_type: str):
|
||||
s = db.Session()
|
||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||
|
@ -70,12 +80,15 @@ class Greetings(commands.Cog):
|
|||
else:
|
||||
await ctx.send(content=m.leave_msg(str(ctx.author)))
|
||||
|
||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="toggle",
|
||||
@cog_ext.cog_subcommand(base="greetings", name="toggle",
|
||||
description="Enable or disable the greetings message",
|
||||
options=[manage_commands.create_option("type", "The join or leave message",
|
||||
SlashCommandOptionType.STRING, True,
|
||||
[manage_commands.create_choice("join", "join"),
|
||||
manage_commands.create_choice("leave", "leave")])])
|
||||
@is_enabled()
|
||||
@guild_only()
|
||||
@has_permissions(manage_guild=True)
|
||||
async def greetings_toggle(self, ctx: SlashContext, message_type: str):
|
||||
s = db.Session()
|
||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||
|
|
|
@ -3,6 +3,7 @@ async-timeout==3.0.1
|
|||
attrs==20.2.0
|
||||
chardet==3.0.4
|
||||
discord==1.0.1
|
||||
discord-py-slash-command==1.0.8.5
|
||||
discord.py==1.5.1
|
||||
feedparser==6.0.2
|
||||
idna==2.10
|
||||
|
|
Reference in a new issue