diff --git a/.gitignore b/.gitignore index b6e4761..ed7db69 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# db +data/moderation_log.db \ No newline at end of file diff --git a/LICENSE b/LICENSE index c5448b1..147a872 100644 --- a/LICENSE +++ b/LICENSE @@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +This notice applies to all files except for cogs/admin.py and cogs/utils/context.py, which are under the +MPL 2.0, which can be found at http://mozilla.org/MPL/2.0/. diff --git a/bot.py b/bot.py index a4904e0..375482b 100644 --- a/bot.py +++ b/bot.py @@ -8,6 +8,8 @@ from pretty_help import PrettyHelp from dotenv import load_dotenv + +from cogs.utils import context from logger import getLogger, set_global_logging_level from curation_validator import get_launch_commands_bluebot, validate_curation, CurationType @@ -28,10 +30,16 @@ PENDING_FIXES_CHANNEL = int(os.getenv('PENDING_FIXES_CHANNEL')) NOTIFY_ME_CHANNEL = int(os.getenv('NOTIFY_ME_CHANNEL')) GOD_USER = int(os.getenv('GOD_USER')) -NOTIFICATION_SQUAD_ID = int(os.getenv('NOTIFICATION_SQUAD_ID')) BOT_GUY = int(os.getenv('BOT_GUY')) - -bot = commands.Bot(command_prefix="-", help_command=PrettyHelp(color=discord.Color.red())) +NOTIFICATION_SQUAD_ID = int(os.getenv('NOTIFICATION_SQUAD_ID')) +TIMEOUT_ID = int(os.getenv('TIMEOUT_ID')) + +intents = discord.Intents.default() +intents.members = True +bot = commands.Bot(command_prefix="-", + help_command=PrettyHelp(color=discord.Color.red()), + case_insensitive=False, + intents=intents) COOL_CRAB = "<:cool_crab:587188729362513930>" EXTREME_EMOJI_ID = 778145279714918400 @@ -43,12 +51,16 @@ async def on_ready(): @bot.event async def on_message(message: discord.Message): - await bot.process_commands(message) + await process_commands(message) await forward_ping(message) await notify_me(message) await check_curation_in_message(message, dry_run=False) +async def process_commands(message): + ctx = await bot.get_context(message, cls=context.Context) + await bot.invoke(ctx) + @bot.event async def on_command_error(ctx: discord.ext.commands.Context, error: Exception): if isinstance(error, commands.MaxConcurrencyReached): @@ -58,6 +70,16 @@ async def on_command_error(ctx: discord.ext.commands.Context, error: Exception): await ctx.channel.send("Insufficient permissions.") return elif isinstance(error, commands.CommandNotFound): + await ctx.channel.send(f"Command {ctx.invoked_with} not found.") + return + elif isinstance(error, commands.UserInputError): + await ctx.send("Invalid input.") + return + elif isinstance(error, commands.NoPrivateMessage): + try: + await ctx.author.send('This command cannot be used in direct messages.') + except discord.Forbidden: + pass return else: reply_channel: discord.TextChannel = bot.get_channel(BOT_TESTING_CHANNEL) @@ -75,14 +97,15 @@ async def forward_ping(message: discord.Message): async def notify_me(message: discord.Message): - notification_squad = message.guild.get_role(NOTIFICATION_SQUAD_ID) - if message.channel is bot.get_channel(NOTIFY_ME_CHANNEL): - if "unnotify me" in message.content.lower(): - l.debug(f"Removed role from {message.author.id}") - await message.author.remove_roles(notification_squad) - elif "notify me" in message.content.lower(): - l.debug(f"Gave role to {message.author.id}") - await message.author.add_roles(notification_squad) + if message.guild is not None: + notification_squad = message.guild.get_role(NOTIFICATION_SQUAD_ID) + if message.channel is bot.get_channel(NOTIFY_ME_CHANNEL): + if "unnotify me" in message.content.lower(): + l.debug(f"Removed role from {message.author.id}") + await message.author.remove_roles(notification_squad) + elif "notify me" in message.content.lower(): + l.debug(f"Gave role to {message.author.id}") + await message.author.add_roles(notification_squad) async def check_curation_in_message(message: discord.Message, dry_run: bool = True): @@ -216,5 +239,8 @@ async def predicate(ctx): bot.load_extension('cogs.curation') bot.load_extension('cogs.info') bot.load_extension('cogs.utilities') +bot.load_extension('cogs.moderation') +bot.load_extension('cogs.admin') + l.info(f"starting the bot...") bot.run(TOKEN) diff --git a/cogs/admin.py b/cogs/admin.py new file mode 100644 index 0000000..283f914 --- /dev/null +++ b/cogs/admin.py @@ -0,0 +1,142 @@ +from discord.ext import commands +import asyncio +import importlib +import os +import re +import sys +import subprocess + +from discord.utils import get + +from bot import BOT_GUY, l + +"""This code is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/.""" + + +class Admin(commands.Cog): + """Admin-only commands that make the bot dynamic.""" + + def __init__(self, bot): + self.bot = bot + self._last_result = None + self.sessions = set() + + async def run_process(self, command): + try: + process = await asyncio.create_subprocess_shell(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = await process.communicate() + except NotImplementedError: + process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + result = await self.bot.loop.run_in_executor(None, process.communicate) + + return [output.decode() for output in result] + + async def cog_check(self, ctx: commands.Context): + return ctx.author.id == BOT_GUY or get(ctx.author.roles, name='Administrator') + + @commands.command(hidden=True) + async def load(self, ctx, *, module): + """Loads a module.""" + try: + self.bot.load_extension(module) + except commands.ExtensionError as e: + await ctx.send(f'{e.__class__.__name__}: {e}') + else: + await ctx.send('\N{OK HAND SIGN}') + + @commands.command(hidden=True) + async def unload(self, ctx, *, module): + """Unloads a module.""" + try: + self.bot.unload_extension(module) + except commands.ExtensionError as e: + await ctx.send(f'{e.__class__.__name__}: {e}') + else: + await ctx.send('\N{OK HAND SIGN}') + + @commands.group(name='reload', hidden=True, invoke_without_command=True) + async def _reload(self, ctx, *, module): + l.debug("reload command issued") + """Reloads a module.""" + try: + self.bot.reload_extension(module) + except commands.ExtensionError as e: + await ctx.send(f'{e.__class__.__name__}: {e}') + else: + await ctx.send('\N{OK HAND SIGN}') + + _GIT_PULL_REGEX = re.compile(r'\s*(?P.+?)\s*\|\s*[0-9]+\s*[+-]+') + + def find_modules_from_git(self, output): + files = self._GIT_PULL_REGEX.findall(output) + ret = [] + for file in files: + root, ext = os.path.splitext(file) + if ext != '.py': + continue + + if root.startswith('cogs/'): + # A submodule is a directory inside the main cog directory for + # my purposes + ret.append((root.count('/') - 1, root.replace('/', '.'))) + + # For reload order, the submodules should be reloaded first + ret.sort(reverse=True) + return ret + + def reload_or_load_extension(self, module): + try: + self.bot.reload_extension(module) + except commands.ExtensionNotLoaded: + self.bot.load_extension(module) + + @_reload.command(name='all', hidden=True) + async def _reload_all(self, ctx): + """Reloads all modules, while pulling from git.""" + + async with ctx.typing(): + stdout, stderr = await self.run_process('git pull') + + # progress and stuff is redirected to stderr in git pull + # however, things like "fast forward" and files + # along with the text "already up-to-date" are in stdout + + if stdout.startswith('Already up to date.'): + return await ctx.send(stdout) + + modules = self.find_modules_from_git(stdout) + mods_text = '\n'.join(f'{index}. `{module}`' for index, (_, module) in enumerate(modules, start=1)) + prompt_text = f'This will update the following modules, are you sure?\n{mods_text}' + confirm = await ctx.prompt(prompt_text, reacquire=False) + if not confirm: + return await ctx.send('Aborting.') + + statuses = [] + for is_submodule, module in modules: + if is_submodule: + try: + actual_module = sys.modules[module] + except KeyError: + statuses.append((ctx.tick(None), module)) + else: + try: + importlib.reload(actual_module) + except Exception as e: + statuses.append((ctx.tick(False), module)) + else: + statuses.append((ctx.tick(True), module)) + else: + try: + self.reload_or_load_extension(module) + except commands.ExtensionError: + statuses.append((ctx.tick(False), module)) + else: + statuses.append((ctx.tick(True), module)) + + await ctx.send('\n'.join(f'{status}: `{module}`' for status, module in statuses)) + + +def setup(bot): + bot.add_cog(Admin(bot)) diff --git a/cogs/moderation.py b/cogs/moderation.py new file mode 100644 index 0000000..74aa425 --- /dev/null +++ b/cogs/moderation.py @@ -0,0 +1,378 @@ +import datetime +import sqlite3 +from typing import Optional, Union + +import discord +from discord import Forbidden, HTTPException, NotFound +from discord.ext import commands, tasks + +from pygicord import Paginator + +from bot import TIMEOUT_ID +from logger import getLogger +from util import TimeDeltaConverter + +l = getLogger("main") + + +class Moderation(commands.Cog, description="Moderation tools."): + + def __init__(self, bot): + self.bot: discord.ext.commands.Bot = bot + self.do_temp_unbans.start() + self.db_path = 'data/moderation_log.db' + + @commands.command(name="ban", brief="Ban a user.", description="Ban a user, and optionally give a reason.") + @commands.has_role("Moderator") + async def ban_command(self, ctx: discord.ext.commands.Context, member: discord.Member, *, reason: Optional[str]): + l.debug(f"ban command issued by {ctx.author.id} on user {member.id}") + await self.ban(member, reason) + await ctx.send(f"{member.display_name} was banned.") + + @commands.command(name="kick", brief="Kick a user.", description="Kick a user, and optionally give a reason.") + @commands.has_role("Moderator") + async def kick_command(self, ctx: discord.ext.commands.Context, member: discord.Member, *, reason: Optional[str]): + l.debug(f"kick command issued by {ctx.author.id} on user {member.id}") + await self.kick(member, reason) + await ctx.send(f"{member.display_name} was kicked.") + + @commands.command(name="warn", brief="Warn a user.", + description="Warn a user and give a reason.") + @commands.has_role("Moderator") + async def warn_command(self, ctx: discord.ext.commands.Context, member: discord.Member, *, reason: Optional[str]): + l.debug(f"warn command issued by {ctx.author.id} on user {member.id}") + await self.warn(member, reason) + await ctx.send(f"{member.display_name} was formally warned.") + + @commands.command(name="tempban", brief="Tempban a user.", + description="Temporarily ban a user, and optionally give a reason. " + "Dates should be formatted as [minutes]m[hours]h[days]d[weeks]w, " + "for example 1m3h or 3h1m for an hour and 3 minutes.") + @commands.has_role("Moderator") + async def tempban_command(self, ctx: discord.ext.commands.Context, member: discord.Member, + duration: TimeDeltaConverter, *, reason: Optional[str]): + l.debug(f"tempban command issued by {ctx.author.id} on user {member.id}") + # The type checker can't understand converters, so we have to do this. + # noinspection PyTypeChecker + await self.tempban(member, duration, reason) + await ctx.send(f"{member.display_name} was banned for {duration}.") + + @commands.command(name="timeout", brief="Timeout a user.", + description="Timeout a user, and optionally give a reason. " + "Dates should be formatted as [minutes]m[hours]h[days]d[weeks]w, " + "for example 1m3h or 3h1m for an hour and 3 minutes.") + @commands.has_role("Moderator") + async def timeout_command(self, ctx: discord.ext.commands.Context, member: discord.Member, + duration: TimeDeltaConverter, *, reason: Optional[str]): + l.debug(f"timeout command issued by {ctx.author.id} on user {member.id}") + # The type checker can't understand converters, so we have to do this. + # noinspection PyTypeChecker + await self.timeout(member, duration, reason) + await ctx.send(f"{member.display_name} was given the timeout role for {duration}.") + + @commands.command(name="untimeout", aliases=["remove-timeout", "undo-timeout"], brief="Unban a user.", + description="Unban a user, and optionally give a reason.") + @commands.has_role("Moderator") + async def untimeout_command(self, ctx: discord.ext.commands.Context, member: discord.Member, *, + reason: Optional[str]): + l.debug(f"untimeout command issued by {ctx.author.id} on user {member.id}") + await self.untimeout(member, reason) + await ctx.send(f"{member.display_name} had their timeout removed.") + + @commands.command(name="unban", brief="Unban a user.", description="Unban a user, and optionally give a reason.") + @commands.has_role("Moderator") + async def unban_command(self, ctx: discord.ext.commands.Context, member: Union[discord.User, discord.Member], *, + reason: Optional[str]): + l.debug(f"unban command issued by {ctx.author.id} on user {member.id}") + await self.unban(member, member.guild, reason) + await ctx.send(f"{member.display_name} was unbanned.") + + @commands.command(name="log", brief="Gives a log of all moderator actions done to a user.", + description="Gives a log of all moderator actions done." + "May need full username/mention.") + @commands.has_role("Moderator") + async def log(self, ctx: discord.ext.commands.Context, user: Optional[Union[discord.User, discord.Member]]): + if user is not None: + l.debug(f"log command issued by {ctx.author.id} on user {user.id}") + else: + l.debug(f"log command issued by {ctx.author.id}") + # We're parsing timestamps, so we need the detect-types part + connection = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES) + c = connection.cursor() + if user is not None: + try: + c.execute("SELECT action, reason, action_date FROM log WHERE user_id = ? AND guild_id = ?", + (user.id, ctx.guild.id)) + events: list[tuple[str, str, datetime]] = c.fetchall() + c.close() + finally: + connection.close() + else: + try: + c.execute("SELECT action, reason, action_date, user_id FROM log") + events: list[tuple[str, str, datetime, int]] = c.fetchall() + c.close() + finally: + connection.close() + if any(x[0] == "Ban" for x in events): + embed_color = discord.Color.red() + elif any(x[0] == "Kick" for x in events): + embed_color = discord.Color.orange() + elif any(x[0] == "Warn" for x in events): + embed_color = discord.Color.gold() + elif any(x[0] == "Timeout" for x in events): + embed_color = discord.Color.blue() + else: + embed_color = discord.Color.green() + pages: list[discord.Embed] + + pages = [] + embed = discord.Embed(color=embed_color) + if user is not None: + embed.set_author(name=user.name, icon_url=user.avatar_url) + max_embeds = 8 + else: + max_embeds = 5 + if events: + for event in events: + if event[0] == "Ban": + event_prefix = '🚫' + elif event[0] == "Unban" or event[0] == "Untimeout": + event_prefix = '↩' + elif event[0] == "Kick": + event_prefix = '👢' + elif event[0] == "Warn": + event_prefix = '⚠️' + elif event[0] == "Timeout": + event_prefix = '🕒' + else: + event_prefix = '' + if len(embed.fields) >= max_embeds: + pages.append(embed) + embed = discord.Embed(color=embed_color) + if user is not None: + embed.set_author(name=user.name, icon_url=user.avatar_url) + time_str = event[2].strftime("%Y-%m-%d %H:%M:%S") + if user is not None: + embed.add_field(name=event_prefix + ' ' + event[0], + value=f"Date: {time_str}\n" + f"Reason: {event[1]}", + inline=False) + else: + temp_user = await self.bot.fetch_user(event[3]) + embed.add_field(name=event_prefix + ' ' + event[0], + value=f"User: {temp_user.name}\n" + f"Date: {time_str}\n" + f"Reason: {event[1]}", + inline=False) + else: + embed.title = "No events found." + + pages.append(embed) + paginator = Paginator(pages=pages) + await paginator.start(ctx) + + # for each tempban in the database, if it's before now, unban by id. + @tasks.loop(seconds=30.0) + async def do_temp_unbans(self): + l.debug("checking for unbans") + connection = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES) + c = connection.cursor() + try: + c.execute( + "SELECT user_id, guild_id, action " + "FROM log " + "WHERE undone = '0' " + "AND unban_date < datetime('now')") + expired_tempbans: list[tuple[int, int, str]] = c.fetchall() + for expired_tempban in expired_tempbans: + guild: discord.Guild = self.bot.get_guild(expired_tempban[1]) + action = expired_tempban[2] + user_id = expired_tempban[0] + if action == "Ban": + try: + user: discord.User = await self.bot.fetch_user(user_id) + await self.unban(user, guild, "Tempban expired") + except NotFound: + self.log_unban(user_id, guild, "Ban") + except HTTPException: + pass + elif action == "Timeout": + member: discord.Member = guild.get_member(user_id) + if member is not None: + await self.untimeout(member, "Timeout expired") + else: + self.log_unban(user_id, guild, "Timeout") + + finally: + c.close() + connection.close() + + @do_temp_unbans.before_loop + async def before_start_unbans(self): + await self.bot.wait_until_ready() + + # This detects if a member rejoins while a timeout would still be applied to them. + @commands.Cog.listener() + async def on_member_join(self, member: discord.Member): + connection = sqlite3.connect(self.db_path) + c = connection.cursor() + try: + c.execute("SELECT EXISTS(SELECT 1 " + "FROM log " + "WHERE user_id = ? " + "AND guild_id = ? " + "AND action = 'Timeout' " + "AND undone = 0)", + (member.id, member.guild.id)) + record = c.fetchone() + if record[0] == 1: + timeout_role = member.guild.get_role(TIMEOUT_ID) + await member.add_roles(timeout_role) + c.close() + finally: + connection.close() + + async def cog_command_error(self, ctx, error): + if isinstance(error, commands.BadUnionArgument): + await ctx.send("Could not get user.") + elif isinstance(error, commands.UserNotFound): + await ctx.send("Could not get user.") + elif isinstance(error, commands.MemberNotFound): + await ctx.send("Could not get member.") + elif isinstance(error, commands.BadArgument): + await ctx.send("Invalid argument.") + + async def ban(self, member: discord.Member, reason: str, dry_run=False): + self.log_user_event("Ban", member, member.guild, reason) + if not dry_run: + await try_dm(member, "You have been permanently banned from the Flashpoint discord server.\n" + f"Reason: {reason}") + await member.ban(reason=reason) + + async def kick(self, member: discord.Member, reason: str, dry_run=False): + self.log_user_event("Kick", member, member.guild, reason) + if not dry_run: + await try_dm(member, "You have been kicked from the Flashpoint discord server.\n" + f"Reason: {reason}") + await member.kick(reason=reason) + + async def warn(self, member: discord.Member, reason: str, dry_run=False): + self.log_user_event("Warn", member, member.guild, reason) + if not dry_run: + await try_dm(member, "You have been formally warned by the moderators of the Flashpoint discord server." + "Another infraction will have steeper consequences.\n" + f"Reason: {reason}") + + async def timeout(self, member: discord.Member, duration: datetime.timedelta, reason: str, dry_run=False): + timeout_role = member.guild.get_role(TIMEOUT_ID) + self.log_tempban("Timeout", member, duration, reason) + if not dry_run: + await try_dm(member, f"You have been put in timeout from the Flashpoint discord server for {duration}." + f"You will not be able to interact any channels.\n" + f"Reason: {reason}") + await member.add_roles(timeout_role) + + async def tempban(self, member: discord.Member, duration: datetime.timedelta, reason: str, dry_run=False): + # The type checker doesn't understand how converters work, so I suppressed the warning here. + # noinspection PyTypeChecker + self.log_tempban("Ban", member, duration, reason) + if not dry_run: + await try_dm(member, f"You have been banned from the Flashpoint discord server for {duration}.\n" + f"Reason: {reason}") + await member.ban(reason=reason) + + async def unban(self, user: Union[discord.User, discord.Member], guild: discord.Guild, reason: str, dry_run=False): + self.log_user_event("Unban", user, guild, reason) + self.log_unban(user.id, guild, "Ban") + if not dry_run: + await try_dm(user, "You have been unbanned from the Flashpoint discord server.\n" + f"Reason: {reason}") + await guild.unban(user, reason=reason) + + async def untimeout(self, member: discord.Member, reason: str, dry_run=False): + timeout_role = member.guild.get_role(TIMEOUT_ID) + self.log_user_event("Untimeout", member, member.guild, reason) + self.log_unban(member.id, member.guild, "Timeout") + if not dry_run: + await try_dm(member, f"Your timeout is over, you can now interact with all channels freely.\n" + f"Reason: {reason}") + await member.remove_roles(timeout_role) + + def log_tempban(self, action: str, member: discord.Member, duration: datetime.timedelta, reason: str): + connection = sqlite3.connect(self.db_path) + c = connection.cursor() + try: + utc_now: datetime.datetime = datetime.datetime.now(tz=datetime.timezone.utc) + c.execute( + "INSERT INTO log (user_id, guild_id, action, reason, action_date, unban_date, undone) " + "VALUES (?, ?, ?, ?, ?, ? , 0)", + (member.id, member.guild.id, action, reason, utc_now, utc_now + duration)) + connection.commit() + finally: + c.close() + connection.close() + + def log_user_event(self, action: str, user: Union[discord.User, discord.Member], guild: discord.Guild, reason: str): + connection = sqlite3.connect(self.db_path) + c = connection.cursor() + try: + utc_now: datetime.datetime = datetime.datetime.now(tz=datetime.timezone.utc) + c.execute("INSERT INTO log (user_id, guild_id, action, reason, action_date)" + "VALUES (?, ?, ?, ?, ?)", + (user.id, guild.id, action, reason, utc_now)) + connection.commit() + finally: + c.close() + connection.close() + + def log_unban(self, user_id: int, guild: discord.Guild, action: str): + connection = sqlite3.connect(self.db_path) + c = connection.cursor() + try: + c.execute("UPDATE log " + "SET undone = 1 " + "WHERE undone = 0 " + "AND user_id = ? " + "AND guild_id = ?" + "AND action = ?", + (user_id, guild.id, action)) + connection.commit() + finally: + c.close() + connection.close() + + def create_moderation_log(self) -> None: + connection = sqlite3.connect(self.db_path) + c = connection.cursor() + try: + c.execute("CREATE TABLE IF NOT EXISTS log (" + "user_id integer NOT NULL," + "guild_id integer NOT NULL," + "action text NOT NULL," + "reason text," + "action_date timestamp NOT NULL," + "unban_date timestamp," + "undone integer);") + finally: + c.close() + connection.close() + return + + +async def try_dm(user: Union[discord.User, discord.Member], message): + try: + await user.send(message) + except Forbidden: + l.debug(f"Not allowed to send message to {user.id}") + except HTTPException: + l.debug(f"Failed to send message to {user.id}") + + +def setup(bot: commands.Bot): + cog = Moderation(bot) + try: + Moderation.create_moderation_log(cog) + bot.add_cog(cog) + except Exception as e: + l.error(f"Error {e} when trying to set up moderation, will not be initialized.") diff --git a/cogs/utils/context.py b/cogs/utils/context.py new file mode 100644 index 0000000..be49166 --- /dev/null +++ b/cogs/utils/context.py @@ -0,0 +1,185 @@ +from discord.ext import commands +import asyncio +import discord +import io + + +"""This code is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/.""" + + +class _ContextDBAcquire: + __slots__ = ('ctx', 'timeout') + + def __init__(self, ctx, timeout): + self.ctx = ctx + self.timeout = timeout + + +class Context(commands.Context): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._db = None + + async def entry_to_code(self, entries): + width = max(len(a) for a, b in entries) + output = ['```'] + for name, entry in entries: + output.append(f'{name:<{width}}: {entry}') + output.append('```') + await self.send('\n'.join(output)) + + async def indented_entry_to_code(self, entries): + width = max(len(a) for a, b in entries) + output = ['```'] + for name, entry in entries: + output.append(f'\u200b{name:>{width}}: {entry}') + output.append('```') + await self.send('\n'.join(output)) + + def __repr__(self): + # we need this for our cache key strategy + return '' + + @property + def session(self): + return self.bot.session + + @discord.utils.cached_property + def replied_reference(self): + ref = self.message.reference + if ref and isinstance(ref.resolved, discord.Message): + return ref.resolved.to_reference() + return None + + async def disambiguate(self, matches, entry): + if len(matches) == 0: + raise ValueError('No results found.') + + if len(matches) == 1: + return matches[0] + + await self.send('There are too many matches... Which one did you mean? **Only say the number**.') + await self.send('\n'.join(f'{index}: {entry(item)}' for index, item in enumerate(matches, 1))) + + def check(m): + return m.content.isdigit() and m.author.id == self.author.id and m.channel.id == self.channel.id + + # only give them 3 tries. + for i in range(3): + try: + message = await self.bot.wait_for('message', check=check, timeout=30.0) + except asyncio.TimeoutError: + raise ValueError('Took too long. Goodbye.') + + index = int(message.content) + try: + return matches[index - 1] + except: + await self.send(f'Please give me a valid number. {2 - i} tries remaining...') + + raise ValueError('Too many tries. Goodbye.') + + + async def prompt(self, message, *, timeout=60.0, delete_after=True, reacquire=True, author_id=None): + """An interactive reaction confirmation dialog. + Parameters + ----------- + message: str + The message to show along with the prompt. + timeout: float + How long to wait before returning. + delete_after: bool + Whether to delete the confirmation message after we're done. + reacquire: bool + Whether to release the database connection and then acquire it + again when we're done. + author_id: Optional[int] + The member who should respond to the prompt. Defaults to the author of the + Context's message. + Returns + -------- + Optional[bool] + ``True`` if explicit confirm, + ``False`` if explicit deny, + ``None`` if deny due to timeout + """ + + if not self.channel.permissions_for(self.me).add_reactions: + raise RuntimeError('Bot does not have Add Reactions permission.') + + fmt = f'{message}\n\nReact with \N{WHITE HEAVY CHECK MARK} to confirm or \N{CROSS MARK} to deny.' + + author_id = author_id or self.author.id + msg = await self.send(fmt) + + confirm = None + + def check(payload): + nonlocal confirm + + if payload.message_id != msg.id or payload.user_id != author_id: + return False + + codepoint = str(payload.emoji) + + if codepoint == '\N{WHITE HEAVY CHECK MARK}': + confirm = True + return True + elif codepoint == '\N{CROSS MARK}': + confirm = False + return True + + return False + + for emoji in ('\N{WHITE HEAVY CHECK MARK}', '\N{CROSS MARK}'): + await msg.add_reaction(emoji) + + try: + await self.bot.wait_for('raw_reaction_add', check=check, timeout=timeout) + except asyncio.TimeoutError: + confirm = None + + try: + + if delete_after: + await msg.delete() + finally: + return confirm + + def tick(self, opt, label=None): + lookup = { + True: '<:greenTick:330090705336664065>', + False: '<:redTick:330090723011592193>', + None: '<:greyTick:563231201280917524>', + } + emoji = lookup.get(opt, '<:redTick:330090723011592193>') + if label is not None: + return f'{emoji}: {label}' + return emoji + + + async def show_help(self, command=None): + """Shows the help command for the specified command if given. + If no command is given, then it'll show help for the current + command. + """ + cmd = self.bot.get_command('help') + command = command or self.command.qualified_name + await self.invoke(cmd, command=command) + + async def safe_send(self, content, *, escape_mentions=True, **kwargs): + """Same as send except with some safe guards. + 1) If the message is too long then it sends a file with the results instead. + 2) If ``escape_mentions`` is ``True`` then it escapes mentions. + """ + if escape_mentions: + content = discord.utils.escape_mentions(content) + + if len(content) > 2000: + fp = io.BytesIO(content.encode()) + kwargs.pop('file', None) + return await self.send(file=discord.File(fp, filename='message_too_long.txt'), **kwargs) + else: + return await self.send(content) \ No newline at end of file diff --git a/example.env b/example.env index 2e9944f..08b0f56 100644 --- a/example.env +++ b/example.env @@ -1,7 +1,7 @@ # .env # private discord key DISCORD_TOKEN= -# These will be random strings of numbers, they'll get printed when the bot is first run. +# channel ids FLASH_GAMES_CHANNEL=0 OTHER_GAMES_CHANNEL=0 ANIMATIONS_CHANNEL=0 @@ -13,6 +13,10 @@ EXCEPTION_CHANNEL=0 BOT_ALERTS_CHANNEL=0 PENDING_FIXES_CHANNEL=0 NOTIFY_ME_CHANNEL=0 +# user ids GOD_USER=0 BOT_GUY=0 +# role ids NOTIFICATION_SQUAD_ID=0 +TIMEOUT_ID=0 + diff --git a/moderation_test.py b/moderation_test.py new file mode 100644 index 0000000..c3181e2 --- /dev/null +++ b/moderation_test.py @@ -0,0 +1,37 @@ +import discord.ext.commands +import pytest +import discord.ext.test as dpytest +from discord.ext import commands +from pretty_help import PrettyHelp + + +@pytest.fixture +def bot(event_loop): + intents = discord.Intents.default() + intents.members = True + bot = commands.Bot(command_prefix="-", + help_command=PrettyHelp(color=discord.Color.red()), + case_insensitive=False, + intents=intents, loop=event_loop) + bot.load_extension('cogs.batch_validate') + bot.load_extension('cogs.troubleshooting') + bot.load_extension('cogs.curation') + bot.load_extension('cogs.info') + bot.load_extension('cogs.utilities') + bot.load_extension('cogs.moderation') + bot.load_extension('cogs.admin') + dpytest.configure(bot) + return bot + + +@pytest.mark.asyncio +async def test_timeout(bot): + guild = bot.guilds[0] + member1 = guild.members[0] + await dpytest.message(f"-timeout {member1.id}") + + +@pytest.mark.asyncio +async def test_foo(bot): + await dpytest.message("!hello") + assert dpytest.verify().message().content("Hello World!") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 76b3f55..7b3d1c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ requests beautifulsoup4 cachetools discord-pretty-help +pygicord fastapi uvicorn diff --git a/util.py b/util.py index 404655c..d249e6e 100644 --- a/util.py +++ b/util.py @@ -1,6 +1,9 @@ +import datetime +import re import zipfile import py7zr +from discord.ext import commands from logger import getLogger @@ -28,6 +31,26 @@ def get_archive_filenames(path: str) -> list[str]: return filenames +time_regex = re.compile(r"(\d{1,5}(?:[.,]?\d{1,5})?)([smhdw])") +time_dict = {"h": 3600, "s": 1, "m": 60, "d": 86400, "w": 604800} + + +class TimeDeltaConverter(commands.Converter): + async def convert(self, ctx, argument): + matches = time_regex.findall(argument.lower()) + seconds = 0 + for v, k in matches: + try: + seconds += time_dict[k] * float(v) + except KeyError: + raise commands.BadArgument("{} is an invalid time-key! h/m/s/d/w are valid!".format(k)) + except ValueError: + raise commands.BadArgument("{} is not a number!".format(v)) + if seconds <= 0: + raise commands.BadArgument("Time must be greater than 0.") + return datetime.timedelta(seconds=seconds) + + class ArchiveTooLargeException(Exception): pass