1
0
Fork 0

Merge pull request #48 from flifloo/extension

Extension
This commit is contained in:
Ethanell 2020-11-05 14:19:21 +01:00
commit 9ad885421e
6 changed files with 152 additions and 16 deletions

View file

@ -1,10 +1,19 @@
from discord.ext import commands from discord.ext import commands
from administrator import config
import db
class NotOwner(commands.CheckFailure): class ExtensionDisabled(commands.CheckFailure):
pass pass
async def is_owner(ctx: commands.Context): def is_enabled():
return ctx.author.id == config.get("admin_id") async def check(ctx: commands.Context):
if ctx.command.cog:
s = db.Session()
es = s.query(db.ExtensionState).get((ctx.command.cog.qualified_name, ctx.guild.id))
s.close()
if es and not es.state:
raise ExtensionDisabled()
return True
return commands.check(check)

View file

@ -1,5 +1,4 @@
{"prefix": "!", {"prefix": "!",
"token": "GOOD_BOT_TOKEN", "token": "GOOD_BOT_TOKEN",
"admin_id": 1234567890,
"db": "postgresql://usr:pass@localhost:5432/sqlalchemy" "db": "postgresql://usr:pass@localhost:5432/sqlalchemy"
} }

26
db/Extension.py Normal file
View file

@ -0,0 +1,26 @@
from db import Base
from sqlalchemy import Column, String, Boolean, ForeignKey, BigInteger
from sqlalchemy.orm import relationship
class Extension(Base):
__tablename__ = "extension"
name = Column(String, primary_key=True)
default_state = Column(Boolean, nullable=False, default=True)
extension_state = relationship("ExtensionState", backref="extension")
def __init__(self, name: int, default_state: bool = True):
self.name = name
self.default_state = default_state
class ExtensionState(Base):
__tablename__ = "extension_state"
extension_name = Column(String, ForeignKey("extension.name"), primary_key=True)
guild_id = Column(BigInteger, nullable=False, primary_key=True)
state = Column(Boolean, nullable=False, default=True)
def __init__(self, extension_name: str, guild_id: int, state: bool = True):
self.extension_name = extension_name
self.guild_id = guild_id
self.state = state

View file

@ -15,4 +15,5 @@ from db.WarnAction import WarnAction
from db.InviteRole import InviteRole from db.InviteRole import InviteRole
from db.Tomuss import Tomuss from db.Tomuss import Tomuss
from db.PCP import PCP from db.PCP import PCP
from db.Extension import Extension, ExtensionState
Base.metadata.create_all(engine) Base.metadata.create_all(engine)

View file

@ -1,6 +1,10 @@
from traceback import format_exc
from discord.ext import commands from discord.ext import commands
from discord import Embed from discord import Embed, Guild
from administrator.check import is_owner from discord.ext.commands import MissingPermissions, BadArgument
import db
from administrator.logger import logger from administrator.logger import logger
@ -16,45 +20,139 @@ class Extension(commands.Cog):
return "Manage bot's extensions" return "Manage bot's extensions"
@commands.group("extension", pass_context=True) @commands.group("extension", pass_context=True)
@commands.check(is_owner)
async def extension(self, ctx: commands.Context): async def extension(self, ctx: commands.Context):
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
embed = Embed(title="Extensions") await ctx.invoke(self.extension_help)
for extension in self.bot.extensions:
embed.add_field(name=extension, value="Loaded", inline=False) @extension.group("help", pass_context=True)
await ctx.send(embed=embed) async def extension_help(self, ctx: commands.Context):
embed = Embed(title="Extension help")
if await self.extension_list.can_run(ctx):
embed.add_field(name="extension list", value="List all enabled extensions", inline=False)
if await self.extension_enable.can_run(ctx):
embed.add_field(name="extension enable", value="Enable an extensions", inline=False)
if await self.extension_disable.can_run(ctx):
embed.add_field(name="extension disable", value="Disable an extensions", inline=False)
if await self.extension_load.can_run(ctx):
embed.add_field(name="extension loaded", value="List all loaded extensions", inline=False)
if await self.extension_load.can_run(ctx):
embed.add_field(name="extension load <name>", value="Load an extension", inline=False)
if await self.extension_unload.can_run(ctx):
embed.add_field(name="extension unload <name>", value="Unload an extension", inline=False)
if await self.extension_reload.can_run(ctx):
embed.add_field(name="extension reload <name>", value="Reload an extension", inline=False)
if not embed.fields:
raise MissingPermissions(None)
await ctx.send(embed=embed)
@extension.group("list", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_list(self, ctx: commands.Context):
s = db.Session()
embed = Embed(title="Extensions list")
for es in s.query(db.ExtensionState).filter(db.ExtensionState.guild_id == ctx.guild.id):
embed.add_field(name=es.extension_name, value="Enable" if es.state else "Disable")
await ctx.send(embed=embed)
@extension.group("enable", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_enable(self, ctx: commands.Context, name: str):
s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
if not es or es.state:
raise BadArgument()
es.state = True
s.add(es)
s.commit()
s.close()
await ctx.message.add_reaction("\U0001f44d")
@extension.group("disable", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_disable(self, ctx: commands.Context, name: str):
s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
if not es or not es.state:
raise BadArgument()
es.state = False
s.add(es)
s.commit()
s.close()
await ctx.message.add_reaction("\U0001f44d")
@extension.group("loaded", pass_context=True)
@commands.is_owner()
async def extension_loaded(self, ctx: commands.Context):
embed = Embed(title="Extensions loaded")
for extension in self.bot.extensions:
embed.add_field(name=extension, value="Loaded", inline=False)
await ctx.send(embed=embed)
@extension.group("load", pass_context=True) @extension.group("load", pass_context=True)
@commands.check(is_owner) @commands.is_owner()
async def extension_load(self, ctx: commands.Context, name: str): async def extension_load(self, ctx: commands.Context, name: str):
try: try:
self.bot.load_extension(name) self.bot.load_extension(name)
except Exception as e: except Exception as e:
await ctx.message.add_reaction("\u26a0") await ctx.message.add_reaction("\u26a0")
await ctx.send(f"{e.__class__.__name__}: {e}\n```{format_exc()}```")
else: else:
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@extension.group("unload", pass_context=True) @extension.group("unload", pass_context=True)
@commands.check(is_owner) @commands.is_owner()
async def extension_unload(self, ctx: commands.Context, name: str): async def extension_unload(self, ctx: commands.Context, name: str):
try: try:
self.bot.unload_extension(name) self.bot.unload_extension(name)
except Exception as e: except Exception as e:
await ctx.message.add_reaction("\u26a0") await ctx.message.add_reaction("\u26a0")
await ctx.send(f"{e.__class__.__name__}: {e}\n```{format_exc()}```")
else: else:
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@extension.group("reload", pass_context=True) @extension.group("reload", pass_context=True)
@commands.check(is_owner) @commands.is_owner()
async def extension_reload(self, ctx: commands.Context, name: str): async def extension_reload(self, ctx: commands.Context, name: str):
try: try:
self.bot.unload_extension(name) self.bot.unload_extension(name)
self.bot.load_extension(name) self.bot.load_extension(name)
except Exception as e: except Exception as e:
await ctx.message.add_reaction("\u26a0") await ctx.message.add_reaction("\u26a0")
await ctx.send(f"{e.__class__.__name__}: {e}\n```{format_exc()}```")
else: else:
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@commands.Cog.listener()
async def on_ready(self):
s = db.Session()
for guild in self.bot.guilds:
for extension in filter(lambda x: x not in ["Extension", "Help"], self.bot.cogs):
e = s.query(db.Extension).get(extension)
if not e:
s.add(db.Extension(extension))
s.commit()
es = s.query(db.ExtensionState).get((extension, guild.id))
if not es:
s.add(db.ExtensionState(extension, guild.id))
s.commit()
s.close()
@commands.Cog.listener()
async def on_guild_join(self, guild: Guild):
s = db.Session()
for extension in s.query(db.Extension).all():
s.add(db.ExtensionState(extension.name, guild.id))
s.commit()
s.close()
@commands.Cog.listener()
async def on_guild_remove(self, guild: Guild):
s = db.Session()
for es in s.query(db.ExtensionState).filter(db.ExtensionState.guild_id == guild.id):
s.delete(es)
s.commit()
s.close()
def setup(bot): def setup(bot):
logger.info(f"Loading...") logger.info(f"Loading...")

View file

@ -4,6 +4,7 @@ from discord.ext.commands import CommandNotFound, MissingRequiredArgument, BadAr
NoPrivateMessage, CommandError, NotOwner NoPrivateMessage, CommandError, NotOwner
from administrator import config from administrator import config
from administrator.check import ExtensionDisabled
from administrator.logger import logger from administrator.logger import logger
@ -43,7 +44,9 @@ class Help(commands.Cog):
await ctx.message.add_reaction("\u274C") await ctx.message.add_reaction("\u274C")
elif isinstance(error, NotOwner) or isinstance(error, MissingPermissions)\ elif isinstance(error, NotOwner) or isinstance(error, MissingPermissions)\
or isinstance(error, NoPrivateMessage): or isinstance(error, NoPrivateMessage):
await ctx.message.add_reaction("\u274C") await ctx.message.add_reaction("\U000026D4")
elif isinstance(error, ExtensionDisabled):
await ctx.message.add_reaction("\U0001F6AB")
else: else:
await ctx.send("An error occurred !") await ctx.send("An error occurred !")
raise error raise error