Updating check for slash commands
This commit is contained in:
parent
8c1b3af1eb
commit
1b5d8410ae
4 changed files with 84 additions and 16 deletions
|
@ -1,4 +1,8 @@
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from discord import Permissions
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from discord.ext.commands import NoPrivateMessage, NotOwner, MissingPermissions
|
||||||
|
|
||||||
import db
|
import db
|
||||||
|
|
||||||
|
@ -8,12 +12,61 @@ class ExtensionDisabled(commands.CheckFailure):
|
||||||
|
|
||||||
|
|
||||||
def is_enabled():
|
def is_enabled():
|
||||||
async def check(ctx: commands.Context):
|
def check(func):
|
||||||
if ctx.command.cog and ctx.guild:
|
@functools.wraps(func)
|
||||||
|
async def wrapped(*args):
|
||||||
|
ctx = args[1]
|
||||||
|
if ctx.guild:
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
es = s.query(db.ExtensionState).get((ctx.command.cog.qualified_name, ctx.guild.id))
|
es = s.query(db.ExtensionState).get((args[0].qualified_name, ctx.guild.id))
|
||||||
s.close()
|
s.close()
|
||||||
if es and not es.state:
|
if es and not es.state:
|
||||||
raise ExtensionDisabled()
|
raise ExtensionDisabled()
|
||||||
return True
|
return await func(*args)
|
||||||
return commands.check(check)
|
return wrapped
|
||||||
|
return check
|
||||||
|
|
||||||
|
|
||||||
|
def is_owner():
|
||||||
|
def check(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapped(*args):
|
||||||
|
ctx = args[1]
|
||||||
|
if not await ctx._discord.is_owner(ctx.author):
|
||||||
|
raise NotOwner('You do not own this bot.')
|
||||||
|
return await func(*args)
|
||||||
|
return wrapped
|
||||||
|
return check
|
||||||
|
|
||||||
|
|
||||||
|
def guild_only():
|
||||||
|
def check(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapped(*args):
|
||||||
|
if args[1].guild is None:
|
||||||
|
raise NoPrivateMessage()
|
||||||
|
return await func(*args)
|
||||||
|
return wrapped
|
||||||
|
return check
|
||||||
|
|
||||||
|
|
||||||
|
def has_permissions(**perms):
|
||||||
|
invalid = set(perms) - set(Permissions.VALID_FLAGS)
|
||||||
|
if invalid:
|
||||||
|
raise TypeError('Invalid permission(s): %s' % (', '.join(invalid)))
|
||||||
|
|
||||||
|
def check(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapped(*args):
|
||||||
|
ctx = args[1]
|
||||||
|
ch = ctx.channel
|
||||||
|
permissions = ch.permissions_for(ctx.author)
|
||||||
|
|
||||||
|
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
|
||||||
|
|
||||||
|
if not missing:
|
||||||
|
return await func(*args)
|
||||||
|
|
||||||
|
raise MissingPermissions(missing)
|
||||||
|
return wrapped
|
||||||
|
return check
|
||||||
|
|
|
@ -8,6 +8,7 @@ from discord_slash.utils import manage_commands
|
||||||
|
|
||||||
import db
|
import db
|
||||||
from administrator import slash
|
from administrator import slash
|
||||||
|
from administrator.check import has_permissions
|
||||||
from administrator.logger import logger
|
from administrator.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +25,7 @@ class Extension(commands.Cog):
|
||||||
return "Manage bot's extensions"
|
return "Manage bot's extensions"
|
||||||
|
|
||||||
@cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions")
|
@cog_ext.cog_subcommand(base="extension", name="list", description="List all enabled extensions")
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@has_permissions(administrator=True)
|
||||||
async def extension_list(self, ctx: SlashContext):
|
async def extension_list(self, ctx: SlashContext):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
embed = Embed(title="Extensions list")
|
embed = Embed(title="Extensions list")
|
||||||
|
@ -38,7 +39,7 @@ class Extension(commands.Cog):
|
||||||
description="Enable an extensions",
|
description="Enable an extensions",
|
||||||
options=[manage_commands.create_option("extension", "The extension to enable",
|
options=[manage_commands.create_option("extension", "The extension to enable",
|
||||||
SlashCommandOptionType.STRING, True)])
|
SlashCommandOptionType.STRING, True)])
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@has_permissions(administrator=True)
|
||||||
async def extension_enable(self, ctx: SlashContext, name: str):
|
async def extension_enable(self, ctx: SlashContext, name: str):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
||||||
|
@ -59,7 +60,7 @@ class Extension(commands.Cog):
|
||||||
description="Disable an extensions",
|
description="Disable an extensions",
|
||||||
options=[manage_commands.create_option("extension", "The extension to disable",
|
options=[manage_commands.create_option("extension", "The extension to disable",
|
||||||
SlashCommandOptionType.STRING, True)])
|
SlashCommandOptionType.STRING, True)])
|
||||||
@commands.has_guild_permissions(administrator=True)
|
@has_permissions(administrator=True)
|
||||||
async def extension_disable(self, ctx: SlashContext, name: str):
|
async def extension_disable(self, ctx: SlashContext, name: str):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
|
||||||
|
|
|
@ -3,6 +3,7 @@ from discord import Member, Embed, Forbidden
|
||||||
from discord_slash import cog_ext, SlashContext, SlashCommandOptionType
|
from discord_slash import cog_ext, SlashContext, SlashCommandOptionType
|
||||||
from discord_slash.utils import manage_commands
|
from discord_slash.utils import manage_commands
|
||||||
|
|
||||||
|
from administrator.check import is_enabled, guild_only, has_permissions
|
||||||
from administrator.logger import logger
|
from administrator.logger import logger
|
||||||
from administrator import db, slash
|
from administrator import db, slash
|
||||||
from administrator.utils import event_is_enabled
|
from administrator.utils import event_is_enabled
|
||||||
|
@ -20,8 +21,11 @@ class Greetings(commands.Cog):
|
||||||
def description(self):
|
def description(self):
|
||||||
return "Setup join and leave message"
|
return "Setup join and leave message"
|
||||||
|
|
||||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="help",
|
@cog_ext.cog_subcommand(base="greetings", name="help",
|
||||||
description="Help about greetings")
|
description="Help about greetings")
|
||||||
|
@is_enabled()
|
||||||
|
@guild_only()
|
||||||
|
@has_permissions(manage_guild=True)
|
||||||
async def greetings_help(self, ctx: SlashContext):
|
async def greetings_help(self, ctx: SlashContext):
|
||||||
embed = Embed(title="Greetings help")
|
embed = Embed(title="Greetings help")
|
||||||
embed.add_field(name="set <join/leave> <message>", value="Set the greetings message\n"
|
embed.add_field(name="set <join/leave> <message>", value="Set the greetings message\n"
|
||||||
|
@ -31,7 +35,7 @@ class Greetings(commands.Cog):
|
||||||
embed.add_field(name="toggle <join/leave>", value="Enable or disable the greetings message", inline=False)
|
embed.add_field(name="toggle <join/leave>", value="Enable or disable the greetings message", inline=False)
|
||||||
await ctx.send(embeds=[embed])
|
await ctx.send(embeds=[embed])
|
||||||
|
|
||||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="set",
|
@cog_ext.cog_subcommand(base="greetings", name="set",
|
||||||
description="Set the greetings message\n`{}` will be replace by the username",
|
description="Set the greetings message\n`{}` will be replace by the username",
|
||||||
options=[
|
options=[
|
||||||
manage_commands.create_option("type", "The join or leave message",
|
manage_commands.create_option("type", "The join or leave message",
|
||||||
|
@ -41,6 +45,9 @@ class Greetings(commands.Cog):
|
||||||
manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING,
|
manage_commands.create_option("message", "The message", SlashCommandOptionType.STRING,
|
||||||
True)
|
True)
|
||||||
])
|
])
|
||||||
|
@is_enabled()
|
||||||
|
@guild_only()
|
||||||
|
@has_permissions(manage_guild=True)
|
||||||
async def greetings_set(self, ctx: SlashContext, message_type: str, message: str):
|
async def greetings_set(self, ctx: SlashContext, message_type: str, message: str):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||||
|
@ -52,12 +59,15 @@ class Greetings(commands.Cog):
|
||||||
s.commit()
|
s.commit()
|
||||||
await ctx.send(content="\U0001f44d")
|
await ctx.send(content="\U0001f44d")
|
||||||
|
|
||||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="show",
|
@cog_ext.cog_subcommand(base="greetings", name="show",
|
||||||
description="Show the greetings message",
|
description="Show the greetings message",
|
||||||
options=[manage_commands.create_option("type", "The join or leave message",
|
options=[manage_commands.create_option("type", "The join or leave message",
|
||||||
SlashCommandOptionType.STRING, True,
|
SlashCommandOptionType.STRING, True,
|
||||||
[manage_commands.create_choice("join", "join"),
|
[manage_commands.create_choice("join", "join"),
|
||||||
manage_commands.create_choice("leave", "leave")])])
|
manage_commands.create_choice("leave", "leave")])])
|
||||||
|
@is_enabled()
|
||||||
|
@guild_only()
|
||||||
|
@has_permissions(manage_guild=True)
|
||||||
async def greetings_show(self, ctx: SlashContext, message_type: str):
|
async def greetings_show(self, ctx: SlashContext, message_type: str):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||||
|
@ -70,12 +80,15 @@ class Greetings(commands.Cog):
|
||||||
else:
|
else:
|
||||||
await ctx.send(content=m.leave_msg(str(ctx.author)))
|
await ctx.send(content=m.leave_msg(str(ctx.author)))
|
||||||
|
|
||||||
@cog_ext.cog_subcommand(base="greetings", guild_ids=[693108780434587708], name="toggle",
|
@cog_ext.cog_subcommand(base="greetings", name="toggle",
|
||||||
description="Enable or disable the greetings message",
|
description="Enable or disable the greetings message",
|
||||||
options=[manage_commands.create_option("type", "The join or leave message",
|
options=[manage_commands.create_option("type", "The join or leave message",
|
||||||
SlashCommandOptionType.STRING, True,
|
SlashCommandOptionType.STRING, True,
|
||||||
[manage_commands.create_choice("join", "join"),
|
[manage_commands.create_choice("join", "join"),
|
||||||
manage_commands.create_choice("leave", "leave")])])
|
manage_commands.create_choice("leave", "leave")])])
|
||||||
|
@is_enabled()
|
||||||
|
@guild_only()
|
||||||
|
@has_permissions(manage_guild=True)
|
||||||
async def greetings_toggle(self, ctx: SlashContext, message_type: str):
|
async def greetings_toggle(self, ctx: SlashContext, message_type: str):
|
||||||
s = db.Session()
|
s = db.Session()
|
||||||
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first()
|
||||||
|
|
|
@ -3,6 +3,7 @@ async-timeout==3.0.1
|
||||||
attrs==20.2.0
|
attrs==20.2.0
|
||||||
chardet==3.0.4
|
chardet==3.0.4
|
||||||
discord==1.0.1
|
discord==1.0.1
|
||||||
|
discord-py-slash-command==1.0.8.5
|
||||||
discord.py==1.5.1
|
discord.py==1.5.1
|
||||||
feedparser==6.0.2
|
feedparser==6.0.2
|
||||||
idna==2.10
|
idna==2.10
|
||||||
|
|
Reference in a new issue