Replace lock with safe thread sessions for db
This commit is contained in:
parent
aec7b7043a
commit
4aa52befcc
8 changed files with 59 additions and 37 deletions
12
TelegramEDT/EDTscoped_session.py
Normal file
12
TelegramEDT/EDTscoped_session.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from sqlalchemy.orm import scoped_session as ss
|
||||
|
||||
|
||||
class scoped_session:
|
||||
def __init__(self, session_factory, scopefunc=None):
|
||||
self.scoped_session = ss(session_factory, scopefunc)
|
||||
|
||||
def __enter__(self):
|
||||
return self.scoped_session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.scoped_session.remove()
|
|
@ -11,6 +11,7 @@ from TelegramEDT.EDTcalendar import Calendar
|
|||
from TelegramEDT.base import Base, User
|
||||
from TelegramEDT.lang import lang
|
||||
from TelegramEDT.logger import logger
|
||||
from TelegramEDT.EDTscoped_session import scoped_session
|
||||
|
||||
if not isfile("token.ini"):
|
||||
logger.critical("No token specified, impossible to start the bot !")
|
||||
|
@ -23,11 +24,10 @@ bot = Bot(token=API_TOKEN)
|
|||
posts_cb = CallbackData("post", "id", "action")
|
||||
dp = Dispatcher(bot)
|
||||
engine = create_engine("sqlite:///edt.db")
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
Session = scoped_session(session_factory)
|
||||
if not isfile("edt.db"):
|
||||
Base.metadata.create_all(engine)
|
||||
dbL = RLock()
|
||||
|
||||
key = reply_keyboard.ReplyKeyboardMarkup()
|
||||
for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
|
||||
|
@ -35,7 +35,7 @@ for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
|
|||
|
||||
|
||||
def check_id(user: types.User):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
if (user.id,) not in session.query(User.id).all():
|
||||
logger.info(f"{user.username} add to the db")
|
||||
if user.locale and user.locale.language:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from aiogram import types
|
||||
from aiogram.types import ParseMode
|
||||
|
||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
||||
from TelegramEDT import dp, key, logger, Session, check_id
|
||||
from TelegramEDT.base import User
|
||||
from TelegramEDT.lang import lang
|
||||
|
||||
|
@ -12,18 +12,22 @@ async def start(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} start")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
await message.reply(lang(user, "welcome"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
msg = lang(user, "welcome")
|
||||
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
|
||||
|
||||
async def help_cmd(message: types.Message):
|
||||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do help command")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
msg = lang(user, "help")
|
||||
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
|
||||
|
||||
def load():
|
||||
|
|
|
@ -11,7 +11,7 @@ from ics.parse import ParseError, string_to_container
|
|||
from pyzbar.pyzbar import decode
|
||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||
|
||||
from TelegramEDT import API_TOKEN, TIMES, bot, dbL, dp, key, logger, session, check_id, posts_cb
|
||||
from TelegramEDT import API_TOKEN, TIMES, bot, dp, key, logger, Session, check_id, posts_cb
|
||||
from TelegramEDT.EDTcalendar import Calendar
|
||||
from TelegramEDT.base import User
|
||||
from TelegramEDT.lang import lang
|
||||
|
@ -21,7 +21,7 @@ re_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a
|
|||
|
||||
|
||||
def calendar(time: str, user_id: int):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=user_id).first()
|
||||
if not user.resources:
|
||||
return lang(user, "edt_err_set")
|
||||
|
@ -38,7 +38,7 @@ def edt_key():
|
|||
|
||||
|
||||
def have_await_cmd(msg: types.Message):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||
return user and user.await_cmd == "setedt"
|
||||
|
||||
|
@ -76,18 +76,19 @@ async def edt_await(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do setedt")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
user.await_cmd = "setedt"
|
||||
session.commit()
|
||||
msg = lang(user, "setedt_wait")
|
||||
|
||||
await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
|
||||
|
||||
async def await_cmd(message: types.message):
|
||||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
logger.info(f"{message.from_user.username} do edt awaited command")
|
||||
url = str()
|
||||
|
@ -124,12 +125,14 @@ async def edt_geturl(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do getedt command")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
if user.resources:
|
||||
await message.reply(user.resources, reply_markup=key)
|
||||
msg = user.resources
|
||||
else:
|
||||
await message.reply(lang(user, "getedt_err"), reply_markup=key)
|
||||
msg = lang(user, "getedt_err")
|
||||
|
||||
await message.reply(msg, reply_markup=key)
|
||||
|
||||
|
||||
def load():
|
||||
|
|
|
@ -5,7 +5,7 @@ from aiogram import types
|
|||
from aiogram.types import ParseMode
|
||||
from aiogram.utils import markdown
|
||||
|
||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
||||
from TelegramEDT import dp, key, logger, Session, check_id
|
||||
from TelegramEDT.base import User, KFET_URL
|
||||
from TelegramEDT.lang import lang
|
||||
|
||||
|
@ -17,7 +17,7 @@ def get_now():
|
|||
|
||||
|
||||
def have_await_cmd(msg: types.Message):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||
return user and user.await_cmd == "setkfet"
|
||||
|
||||
|
@ -26,7 +26,7 @@ async def kfet(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do kfet")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
|
||||
msg = lang(user, "kfet_close")
|
||||
|
@ -47,7 +47,7 @@ async def kfet_set(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do setkfet")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
|
||||
msg = lang(user, "kfet_close")
|
||||
|
@ -62,7 +62,7 @@ async def kfet_set(message: types.Message):
|
|||
async def await_cmd(message: types.message):
|
||||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
logger.info(f"{message.from_user.username} do awaited command")
|
||||
try:
|
||||
|
|
|
@ -4,7 +4,7 @@ from aiogram import types
|
|||
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode
|
||||
from aiogram.utils import markdown
|
||||
|
||||
from TelegramEDT import bot, dbL, dp, logger, posts_cb, session, check_id, key
|
||||
from TelegramEDT import bot, dp, logger, posts_cb, Session, check_id, key
|
||||
from TelegramEDT.base import User
|
||||
from TelegramEDT.lang import lang
|
||||
|
||||
|
@ -12,14 +12,14 @@ logger = logger.getChild("notif")
|
|||
|
||||
|
||||
def have_await_cmd(msg: types.Message):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||
return user and user.await_cmd in ["time", "cooldown"]
|
||||
|
||||
|
||||
async def notif():
|
||||
while True:
|
||||
with dbL:
|
||||
with Session as session:
|
||||
for u in session.query(User).all():
|
||||
nt = None
|
||||
kf = None
|
||||
|
@ -62,7 +62,7 @@ async def notif_cmd(message: types.Message):
|
|||
keys = InlineKeyboardMarkup()
|
||||
for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
|
||||
keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys)
|
||||
|
@ -72,7 +72,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
|
|||
check_id(query.message.from_user)
|
||||
await query.message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{query.message.from_user.username} do notif query")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=query.from_user.id).first()
|
||||
if callback_data["action"] == "toggle":
|
||||
if user.nt:
|
||||
|
@ -94,7 +94,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
|
|||
async def await_cmd(message: types.message):
|
||||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
logger.info(f"{message.from_user.username} do awaited command")
|
||||
try:
|
||||
|
|
|
@ -2,7 +2,7 @@ from aiogram import types
|
|||
from aiogram.types import ParseMode
|
||||
from feedparser import parse
|
||||
|
||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
||||
from TelegramEDT import dp, key, logger, Session, check_id
|
||||
from TelegramEDT.base import User
|
||||
from TelegramEDT.lang import lang
|
||||
|
||||
|
@ -10,7 +10,7 @@ logger = logger.getChild("tomuss")
|
|||
|
||||
|
||||
def have_await_cmd(msg: types.Message):
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||
return user and user.await_cmd == "settomuss"
|
||||
|
||||
|
@ -19,18 +19,19 @@ async def settomuss(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
logger.info(f"{message.from_user.username} do settomuss")
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
user.await_cmd = "settomuss"
|
||||
session.commit()
|
||||
msg = lang(user, "settomuss_wait")
|
||||
|
||||
await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
|
||||
|
||||
async def await_cmd(message: types.message):
|
||||
check_id(message.from_user)
|
||||
await message.chat.do(types.ChatActions.TYPING)
|
||||
with dbL:
|
||||
with Session as session:
|
||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||
logger.info(f"{message.from_user.username} do awaited command")
|
||||
if not len(parse(message.text).entries):
|
||||
|
@ -40,6 +41,7 @@ async def await_cmd(message: types.message):
|
|||
msg = lang(user, "settomuss")
|
||||
user.await_cmd = str()
|
||||
session.commit()
|
||||
|
||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from aiogram.types import ParseMode
|
|||
from aiogram.utils import markdown
|
||||
from aiogram.utils.exceptions import MessageIsTooLong
|
||||
|
||||
from TelegramEDT import ADMIN_ID, bot, dbL, dp, key, logger, session, check_id
|
||||
from TelegramEDT import ADMIN_ID, bot, dp, key, logger, Session, check_id
|
||||
from TelegramEDT.base import User
|
||||
|
||||
logger = logger.getChild("tools")
|
||||
|
@ -49,7 +49,7 @@ async def get_db(message: types.Message):
|
|||
check_id(message.from_user)
|
||||
logger.info(f"{message.from_user.username} do getdb command")
|
||||
if message.from_user.id == ADMIN_ID:
|
||||
with dbL:
|
||||
with Session as session:
|
||||
users = dict()
|
||||
for u in session.query(User).all():
|
||||
users[u] = u.__dict__
|
||||
|
@ -75,7 +75,8 @@ async def eval_cmd(message: types.Message):
|
|||
|
||||
async def errors(*args, **partial_data):
|
||||
if "This Session's transaction has been rolled back due to a previous exception during flush" in args:
|
||||
session.rollback()
|
||||
with Session as session:
|
||||
session.rollback()
|
||||
msg = markdown.text(
|
||||
markdown.bold("⚠️ An error occurred:"),
|
||||
markdown.code(args),
|
||||
|
|
Reference in a new issue