diff --git a/db/Greetings.py b/db/Greetings.py new file mode 100644 index 0000000..87b528a --- /dev/null +++ b/db/Greetings.py @@ -0,0 +1,25 @@ +from discord import Embed + +from db import Base +from sqlalchemy import Column, Integer, Text, Boolean, BigInteger + + +class Greetings(Base): + __tablename__ = "greetings" + id = Column(Integer, primary_key=True) + join_message = Column(Text, nullable=False, default="") + join_enable = Column(Boolean, nullable=False, default=False) + leave_message = Column(Text, nullable=False, default="") + leave_enable = Column(Boolean, nullable=False, default=False) + guild = Column(BigInteger, nullable=False) + + def __init__(self, guild: int): + self.guild = guild + + def join_embed(self, guild_name: str, user: str): + embed = Embed() + embed.add_field(name=f"Welcome to {guild_name} !", value=self.join_message.format(user)) + return embed + + def leave_msg(self, user: str): + return self.leave_message.format(user) diff --git a/db/__init__.py b/db/__init__.py index 9c2c364..a3a0181 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -6,4 +6,5 @@ engine = create_engine(config.get("db")) Session = sessionmaker(bind=engine) Base = declarative_base() from db.Task import Task +from db.Greetings import Greetings Base.metadata.create_all(engine) diff --git a/extensions/__init__.py b/extensions/__init__.py index 08c7d6f..e3d5bb2 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -5,3 +5,4 @@ bot.load_extension("extensions.extension") bot.load_extension("extensions.purge") bot.load_extension("extensions.poll") bot.load_extension("extensions.reminders") +bot.load_extension("extensions.greetings") diff --git a/extensions/greetings.py b/extensions/greetings.py new file mode 100644 index 0000000..fad6c64 --- /dev/null +++ b/extensions/greetings.py @@ -0,0 +1,115 @@ +from discord.ext import commands +from discord import Member, Embed +from discord.ext.commands import BadArgument + +from administrator.logger import logger +from administrator import db, config + + +def check_greetings_message_type(message_type): + if message_type not in ["join", "leave"]: + raise BadArgument() + + +extension_name = "greetings" +logger = logger.getChild(extension_name) + + +class Greetings(commands.Cog): + def __init__(self, bot: commands.Bot): + self.bot = bot + + @commands.group("greetings", pass_context=True) + @commands.guild_only() + @commands.has_permissions(manage_guild=True) + async def greetings(self, ctx: commands.Context): + if ctx.invoked_subcommand is None: + await ctx.invoke(self.greetings_help) + + @greetings.group("help", pass_context=True) + async def greetings_help(self, ctx: commands.Context): + embed = Embed(title="greetings help") + embed.add_field(name="set ", value="Set the greetings message\n" + "`{}` will be replace by the username", + inline=False) + embed.add_field(name="show ", value="Show the greetings message", inline=False) + embed.add_field(name="toggle ", value="Enable or disable the greetings message", inline=False) + await ctx.send(embed=embed) + + @greetings.group("set", pass_context=True) + async def greetings_set(self, ctx: commands.Context, message_type: str): + check_greetings_message_type(message_type) + message = ctx.message.content.replace(config.get("prefix")+"greetings set " + message_type, "").strip() + s = db.Session() + m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() + if not m: + m = db.Greetings(ctx.guild.id) + s.add(m) + setattr(m, message_type+"_enable", True) + setattr(m, message_type+"_message", message) + s.commit() + await ctx.message.add_reaction("\U0001f44d") + + @greetings.group("show", pass_context=True) + async def greetings_show(self, ctx: commands.Context, message_type: str): + check_greetings_message_type(message_type) + s = db.Session() + m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() + s.close() + if not m: + await ctx.send(f"No {message_type} message set !") + else: + if message_type == "join": + await ctx.send(embed=m.join_embed(ctx.guild.name, str(ctx.message.author))) + else: + await ctx.send(m.leave_msg(str(ctx.message.author))) + + @greetings.group("toggle", pass_context=True) + async def greetings_toggle(self, ctx: commands.Context, message_type: str): + check_greetings_message_type(message_type) + s = db.Session() + m = s.query(db.Greetings).filter(db.Greetings.guild == ctx.guild.id).first() + if not m: + await ctx.send(f"No {message_type} message set !") + else: + setattr(m, message_type+"_enable", not getattr(m, message_type+"_enable")) + s.commit() + await ctx.send(f"{message_type.title()} message is " + + ("enable" if getattr(m, message_type+"_enable") else "disable")) + s.close() + + @commands.Cog.listener() + async def on_member_join(self, member: Member): + s = db.Session() + m = s.query(db.Greetings).filter(db.Greetings.guild == member.guild.id).first() + s.close() + if m and m.join_enable: + await member.send(embed=m.join_embed(member.guild.name, str(member))) + + @commands.Cog.listener() + async def on_member_remove(self, member: Member): + s = db.Session() + m = s.query(db.Greetings).filter(db.Greetings.guild == member.guild.id).first() + s.close() + if m and m.leave_enable: + await member.guild.system_channel.send(m.leave_msg(str(member))) + + +def setup(bot): + logger.info(f"Loading...") + try: + bot.add_cog(Greetings(bot)) + except Exception as e: + logger.error(f"Error loading: {e}") + else: + logger.info(f"Load successful") + + +def teardown(bot): + logger.info(f"Unloading...") + try: + bot.remove_cog("Greetings") + except Exception as e: + logger.error(f"Error unloading: {e}") + else: + logger.info(f"Unload successful") diff --git a/extensions/help.py b/extensions/help.py index 01988c6..2e71cbf 100644 --- a/extensions/help.py +++ b/extensions/help.py @@ -1,5 +1,5 @@ from discord.ext import commands -from discord.ext.commands import CommandNotFound, MissingRequiredArgument, BadArgument +from discord.ext.commands import CommandNotFound, MissingRequiredArgument, BadArgument, MissingPermissions from administrator.logger import logger from administrator.check import NotOwner @@ -24,7 +24,7 @@ class Help(commands.Cog): await ctx.message.add_reaction("\u2753") elif isinstance(error, MissingRequiredArgument) or isinstance(error, BadArgument): await ctx.message.add_reaction("\u274C") - elif isinstance(error, NotOwner): + elif isinstance(error, NotOwner) or isinstance(error, MissingPermissions): await ctx.message.add_reaction("\u274C") else: await ctx.send("An error occurred !") diff --git a/extensions/reminders.py b/extensions/reminders.py index 245f82e..58e2b67 100644 --- a/extensions/reminders.py +++ b/extensions/reminders.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from discord.ext import commands from discord import Embed -from discord.ext.commands import CommandNotFound, BadArgument +from discord.ext.commands import BadArgument from discord.ext import tasks from administrator.logger import logger @@ -32,7 +32,7 @@ class Reminders(commands.Cog): @commands.group("reminder", pass_context=True) async def reminder(self, ctx: commands.Context): if ctx.invoked_subcommand is None: - raise CommandNotFound() + await ctx.invoke(self.reminder_help) @reminder.group("help", pass_context=True) async def reminder_help(self, ctx: commands.Context):