1
0
Fork 0

Add list, enable and disable extension commands

This commit is contained in:
Ethanell 2020-11-05 13:56:00 +01:00
parent f818205ad2
commit 05237dbe7b
4 changed files with 117 additions and 7 deletions

View file

@ -1,5 +1,4 @@
{"prefix": "!", {"prefix": "!",
"token": "GOOD_BOT_TOKEN", "token": "GOOD_BOT_TOKEN",
"admin_id": 1234567890,
"db": "postgresql://usr:pass@localhost:5432/sqlalchemy" "db": "postgresql://usr:pass@localhost:5432/sqlalchemy"
} }

26
db/Extension.py Normal file
View file

@ -0,0 +1,26 @@
from db import Base
from sqlalchemy import Column, String, Boolean, ForeignKey, BigInteger
from sqlalchemy.orm import relationship
class Extension(Base):
__tablename__ = "extension"
name = Column(String, primary_key=True)
default_state = Column(Boolean, nullable=False, default=True)
extension_state = relationship("ExtensionState", backref="extension")
def __init__(self, name: int, default_state: bool = True):
self.name = name
self.default_state = default_state
class ExtensionState(Base):
__tablename__ = "extension_state"
extension_name = Column(String, ForeignKey("extension.name"), primary_key=True)
guild_id = Column(BigInteger, nullable=False, primary_key=True)
state = Column(Boolean, nullable=False, default=True)
def __init__(self, extension_name: str, guild_id: int, state: bool = True):
self.extension_name = extension_name
self.guild_id = guild_id
self.state = state

View file

@ -15,4 +15,5 @@ from db.WarnAction import WarnAction
from db.InviteRole import InviteRole from db.InviteRole import InviteRole
from db.Tomuss import Tomuss from db.Tomuss import Tomuss
from db.PCP import PCP from db.PCP import PCP
from db.Extension import Extension, ExtensionState
Base.metadata.create_all(engine) Base.metadata.create_all(engine)

View file

@ -1,7 +1,10 @@
from traceback import format_exc from traceback import format_exc
from discord.ext import commands from discord.ext import commands
from discord import Embed from discord import Embed, Guild
from discord.ext.commands import MissingPermissions, BadArgument
import db
from administrator.logger import logger from administrator.logger import logger
@ -17,7 +20,6 @@ class Extension(commands.Cog):
return "Manage bot's extensions" return "Manage bot's extensions"
@commands.group("extension", pass_context=True) @commands.group("extension", pass_context=True)
@commands.is_owner()
async def extension(self, ctx: commands.Context): async def extension(self, ctx: commands.Context):
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
await ctx.invoke(self.extension_help) await ctx.invoke(self.extension_help)
@ -25,20 +27,69 @@ class Extension(commands.Cog):
@extension.group("help", pass_context=True) @extension.group("help", pass_context=True)
async def extension_help(self, ctx: commands.Context): async def extension_help(self, ctx: commands.Context):
embed = Embed(title="Extension help") embed = Embed(title="Extension help")
embed.add_field(name="extension list", value="List all loaded extensions", inline=False) if await self.extension_list.can_run(ctx):
embed.add_field(name="extension list", value="List all enabled extensions", inline=False)
if await self.extension_enable.can_run(ctx):
embed.add_field(name="extension enable", value="Enable an extensions", inline=False)
if await self.extension_disable.can_run(ctx):
embed.add_field(name="extension disable", value="Disable an extensions", inline=False)
if await self.extension_load.can_run(ctx):
embed.add_field(name="extension loaded", value="List all loaded extensions", inline=False)
if await self.extension_load.can_run(ctx):
embed.add_field(name="extension load <name>", value="Load an extension", inline=False) embed.add_field(name="extension load <name>", value="Load an extension", inline=False)
if await self.extension_unload.can_run(ctx):
embed.add_field(name="extension unload <name>", value="Unload an extension", inline=False) embed.add_field(name="extension unload <name>", value="Unload an extension", inline=False)
if await self.extension_reload.can_run(ctx):
embed.add_field(name="extension reload <name>", value="Reload an extension", inline=False) embed.add_field(name="extension reload <name>", value="Reload an extension", inline=False)
if not embed.fields:
raise MissingPermissions(None)
await ctx.send(embed=embed) await ctx.send(embed=embed)
@extension.group("list", pass_context=True) @extension.group("list", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_list(self, ctx: commands.Context): async def extension_list(self, ctx: commands.Context):
s = db.Session()
embed = Embed(title="Extensions list") embed = Embed(title="Extensions list")
for es in s.query(db.ExtensionState).filter(db.ExtensionState.guild_id == ctx.guild.id):
embed.add_field(name=es.extension_name, value="Enable" if es.state else "Disable")
await ctx.send(embed=embed)
@extension.group("enable", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_enable(self, ctx: commands.Context, name: str):
s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
if not es or es.state:
raise BadArgument()
es.state = True
s.add(es)
s.commit()
s.close()
await ctx.message.add_reaction("\U0001f44d")
@extension.group("disable", pass_context=True)
@commands.has_guild_permissions(administrator=True)
async def extension_disable(self, ctx: commands.Context, name: str):
s = db.Session()
es = s.query(db.ExtensionState).get((name, ctx.guild.id))
if not es or not es.state:
raise BadArgument()
es.state = False
s.add(es)
s.commit()
s.close()
await ctx.message.add_reaction("\U0001f44d")
@extension.group("loaded", pass_context=True)
@commands.is_owner()
async def extension_loaded(self, ctx: commands.Context):
embed = Embed(title="Extensions loaded")
for extension in self.bot.extensions: for extension in self.bot.extensions:
embed.add_field(name=extension, value="Loaded", inline=False) embed.add_field(name=extension, value="Loaded", inline=False)
await ctx.send(embed=embed) await ctx.send(embed=embed)
@extension.group("load", pass_context=True) @extension.group("load", pass_context=True)
@commands.is_owner()
async def extension_load(self, ctx: commands.Context, name: str): async def extension_load(self, ctx: commands.Context, name: str):
try: try:
self.bot.load_extension(name) self.bot.load_extension(name)
@ -49,6 +100,7 @@ class Extension(commands.Cog):
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@extension.group("unload", pass_context=True) @extension.group("unload", pass_context=True)
@commands.is_owner()
async def extension_unload(self, ctx: commands.Context, name: str): async def extension_unload(self, ctx: commands.Context, name: str):
try: try:
self.bot.unload_extension(name) self.bot.unload_extension(name)
@ -59,6 +111,7 @@ class Extension(commands.Cog):
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@extension.group("reload", pass_context=True) @extension.group("reload", pass_context=True)
@commands.is_owner()
async def extension_reload(self, ctx: commands.Context, name: str): async def extension_reload(self, ctx: commands.Context, name: str):
try: try:
self.bot.unload_extension(name) self.bot.unload_extension(name)
@ -69,6 +122,37 @@ class Extension(commands.Cog):
else: else:
await ctx.message.add_reaction("\U0001f44d") await ctx.message.add_reaction("\U0001f44d")
@commands.Cog.listener()
async def on_ready(self):
s = db.Session()
for guild in self.bot.guilds:
for extension in filter(lambda x: x not in ["Extension", "Help"], self.bot.cogs):
e = s.query(db.Extension).get(extension)
if not e:
s.add(db.Extension(extension))
s.commit()
es = s.query(db.ExtensionState).get((extension, guild.id))
if not es:
s.add(db.ExtensionState(extension, guild.id))
s.commit()
s.close()
@commands.Cog.listener()
async def on_guild_join(self, guild: Guild):
s = db.Session()
for extension in s.query(db.Extension).all():
s.add(db.ExtensionState(extension.name, guild.id))
s.commit()
s.close()
@commands.Cog.listener()
async def on_guild_remove(self, guild: Guild):
s = db.Session()
for es in s.query(db.ExtensionState).filter(db.ExtensionState.guild_id == guild.id):
s.delete(es)
s.commit()
s.close()
def setup(bot): def setup(bot):
logger.info(f"Loading...") logger.info(f"Loading...")