diff --git a/EDTuser.py b/base.py similarity index 61% rename from EDTuser.py rename to base.py index fdca167..06b35d6 100644 --- a/EDTuser.py +++ b/base.py @@ -2,27 +2,30 @@ import datetime import requests from EDTcalendar import Calendar from feedparser import parse +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String, Boolean, Date KFET_URL = "http://kfet.bdeinfo.org/orders" +Base = declarative_base() def get_now(): return datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None) -class User: - def __init__(self, user_id: int, language: str): - self.id = user_id - self.language = language - self.resources = None - self.nt = False - self.nt_time = 20 - self.nt_cooldown = 20 - self.nt_last = get_now() - self.kfet = None - self.await_cmd = str() - self.tomuss_rss = str() - self.tomuss_last = str() +class User(Base): + __tablename__ = "user" + id = Column(Integer, primary_key=True, unique=True) + language = Column(String, default="") + resources = Column(Integer) + nt = Column(Boolean, default=False) + nt_time = Column(Integer, default=20) + nt_cooldown = Column(Integer, default=20) + nt_last = Column(Date, default=get_now) + kfet = Column(Integer, default=0) + await_cmd = Column(String, default="") + tomuss_rss = Column(String) + tomuss_last = Column(String) def calendar(self, time: str = "", pass_week: bool = False): return Calendar(time, self.resources, pass_week=pass_week) @@ -55,4 +58,12 @@ class User: entry = [e for e in parse(self.tomuss_rss).entries] if not self.tomuss_last: return entry - return entry[self.tomuss_last:] + tomuss_last = 0 + for i,e in enumerate(entry): + if str(e) == self.tomuss_last: + tomuss_last = i+1 + break + return entry[tomuss_last:] + + def __repr__(self): + return f"" diff --git a/bot.py b/bot.py index 0bd297a..bea69e4 100644 --- a/bot.py +++ b/bot.py @@ -2,7 +2,8 @@ import asyncio import datetime import hashlib import logging -import shelve +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker import re import requests from asyncio import sleep @@ -16,7 +17,7 @@ from aiogram.utils import markdown from aiogram.utils.callback_data import CallbackData from aiogram.utils.exceptions import MessageIsTooLong from EDTcalendar import Calendar -from EDTuser import User, KFET_URL +from base import User, KFET_URL, Base from lang import lang from ics.parse import ParseError from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema @@ -24,11 +25,13 @@ from pyzbar.pyzbar import decode from PIL import Image from feedparser import parse - +tables = False if not isdir("logs"): mkdir("logs") if not isdir("calendars"): mkdir("calendars") +if not isfile("edt.db"): + tables = True logger = logging.getLogger("TelegramEDT") log_date = datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None).date() @@ -48,6 +51,11 @@ TIMES = ["", "day", "next", "week", "next week"] bot = Bot(token=API_TOKEN) posts_cb = CallbackData("post", "id", "action") dp = Dispatcher(bot) +engine = create_engine("sqlite:///edt.db") +Session = sessionmaker(bind=engine) +session = Session() +if tables: + Base.metadata.create_all(engine) dbL = RLock() @@ -57,8 +65,7 @@ def get_now(): def have_await_cmd(msg: types.Message): with dbL: - with shelve.open("edt", writeback=True) as db: - return db[str(msg.from_user.id)].await_cmd + return session.query(User).filter_by(id=msg.from_user.id).first().await_cmd def edt_key(): @@ -70,50 +77,48 @@ def edt_key(): def calendar(time: str, user_id: int): with dbL: - with shelve.open("edt", writeback=True) as db: - user = db[str(user_id)] - if not user.resources: - return lang(user, "edt_err_set") - elif time not in TIMES: - return lang(user, "edt_err_choice") - return str(user.calendar(time)) + user = session.query(User).filter_by(id=user_id).first() + if not user.resources: + return lang(user, "edt_err_set") + elif time not in TIMES: + return lang(user, "edt_err_choice") + return str(user.calendar(time)) async def notif(): while True: with dbL: - with shelve.open("edt", writeback=True) as db: - for u in db: - nt = None - kf = None - tm = None - try: - nt = db[u].get_notif() - kf = db[u].get_kfet() - tm = db[u].get_tomuss() - except: - pass - - if nt: - await bot.send_message(int(u), lang(db[u], "notif_event")+str(nt), parse_mode=ParseMode.MARKDOWN) - if kf: - if kf == 1: - kf = lang(db[u], "kfet") - elif kf == 2: - kf = lang(db[u], "kfet_prb") - else: - kf = lang(db[u], "kfet_err") - await bot.send_message(int(u), kf, parse_mode=ParseMode.MARKDOWN) - if tm: - for i in tm: - msg = markdown.text( - markdown.bold(i.title), - markdown.code(i.summary.replace("
", "\n").replace("", "").replace("", "")), - sep="\n" - ) - await bot.send_message(int(u), msg, parse_mode=ParseMode.MARKDOWN) - db[u].tomuss_last = i + for u in session.query(User).all(): + nt = None + kf = None + tm = None + try: + nt = u.get_notif() + kf = u.get_kfet() + tm = u.get_tomuss() + except: + pass + if nt: + await bot.send_message(u.id, lang(u, "notif_event")+str(nt), parse_mode=ParseMode.MARKDOWN) + if kf: + if kf == 1: + kf = lang(u, "kfet") + elif kf == 2: + kf = lang(u, "kfet_prb") + else: + kf = lang(u, "kfet_err") + await bot.send_message(u.id, kf, parse_mode=ParseMode.MARKDOWN) + if tm: + for i in tm: + msg = markdown.text( + markdown.bold(i.title), + markdown.code(i.summary.replace("
", "\n").replace("", "").replace("", "")), + sep="\n" + ) + await bot.send_message(u.id, msg, parse_mode=ParseMode.MARKDOWN) + u.tomuss_last = str(i) + session.commit() await sleep(60) @@ -134,19 +139,19 @@ async def inline_edt(inline_query: InlineQuery): @dp.message_handler(commands="start") async def start(message: types.Message): - user_id = str(message.from_user.id) + user_id = message.from_user.id await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} start") with dbL: - with shelve.open("edt", writeback=True) as db: - if user_id not in db: - logger.info(f"{message.from_user.username} add to the db") - if message.from_user.locale and message.from_user.locale.language: - lg = message.from_user.locale.language - else: - lg = "" - db[user_id] = User(int(user_id), lg) - user = db[user_id] + if user_id not in session.query(User.id).all(): + logger.info(f"{message.from_user.username} add to the db") + if message.from_user.locale and message.from_user.locale.language: + lg = message.from_user.locale.language + else: + lg = "" + session.add(User(id=user_id, language=lg)) + session.commit() + user = session.query(User).filter_by(id=user_id).first() key = reply_keyboard.ReplyKeyboardMarkup() key.add(reply_keyboard.KeyboardButton("Edt")) key.add(reply_keyboard.KeyboardButton("Kfet")) @@ -162,8 +167,7 @@ async def help_cmd(message: types.Message): await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do help command") with dbL: - with shelve.open("edt", writeback=True) as db: - user = db[str(message.from_user.id)] + user = session.query(User).filter_by(id=message.from_user.id).first() await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN) @@ -183,47 +187,47 @@ async def edt_query(query: types.CallbackQuery, callback_data: dict): @dp.message_handler(lambda msg: msg.text.lower() == "kfet") async def kfet(message: types.Message): - user_id = str(message.from_user.id) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do kfet") with dbL: - with shelve.open("edt", writeback=True) as db: - if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6: - msg = lang(db[user_id], "kfet_close") - else: - msg = lang(db[user_id], "kfet_list") - cmds = requests.get(KFET_URL).json() - if cmds: - for c in cmds: - msg += markdown.code(c) + " " if cmds[c] == "ok" else "" + user = session.query(User).filter_by(id=message.from_user.id).first() + if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6: + msg = lang(user, "kfet_close") + else: + msg = lang(user, "kfet_list") + cmds = requests.get(KFET_URL).json() + if cmds: + for c in cmds: + msg += markdown.code(c) + " " if cmds[c] == "ok" else "" await message.reply(msg, parse_mode=ParseMode.MARKDOWN) @dp.message_handler(lambda msg: msg.text.lower() == "setkfet") async def kfet_set(message: types.Message): - user_id = str(message.from_user.id) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do setkfet") with dbL: - with shelve.open("edt", writeback=True) as db: - if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5: - msg = lang(db[user_id], "kfet_close") - else: - db[user_id].await_cmd = "setkfet" - msg = lang(db[user_id], "kfet_set_await") + user = session.query(User).filter_by(id=message.from_user.id).first() + if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5: + msg = lang(user, "kfet_close") + else: + user.await_cmd = "setkfet" + msg = lang(user, "kfet_set_await") + session.commit() await message.reply(msg, parse_mode=ParseMode.MARKDOWN) @dp.message_handler(lambda msg: msg.text.lower() == "setedt") async def edt_await(message: types.Message): - user_id = str(message.from_user.id) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do setedt") with dbL: - with shelve.open("edt", writeback=True) as db: - db[user_id].await_cmd = "setedt" - await message.reply(lang(db[user_id], "setedt_wait"), parse_mode=ParseMode.MARKDOWN) + user = session.query(User).filter_by(id=message.from_user.id).first() + user.await_cmd = "setedt" + session.commit() + + await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN) @dp.message_handler(lambda msg: msg.text.lower() == "settomuss") @@ -232,9 +236,11 @@ async def edt_await(message: types.Message): await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do settomuss") with dbL: - with shelve.open("edt", writeback=True) as db: - db[user_id].await_cmd = "settomuss" - await message.reply(lang(db[user_id], "settomuss_wait"), parse_mode=ParseMode.MARKDOWN) + user = session.query(User).filter_by(id=message.from_user.id).first() + user.await_cmd = "settomuss" + session.commit() + + await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN) @dp.message_handler(commands="getedt") @@ -243,116 +249,115 @@ async def edt_geturl(message: types.Message): await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do getedt command") with dbL: - with shelve.open("edt", writeback=True) as db: - if db[user_id].resources: - await message.reply(db[user_id].resources) - else: - await message.reply(lang(db[user_id], "getedt_err")) + user = session.query(User).filter_by(id=message.from_user.id).first() + if user.resources: + await message.reply(user.resources) + else: + await message.reply(lang(user, "getedt_err")) @dp.message_handler(lambda msg: msg.text.lower() == "notif") async def notif_cmd(message: types.Message): - user_id = str(message.from_user.id) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do notif") key = InlineKeyboardMarkup() for i, n in enumerate(["Toggle", "Time", "Cooldown"]): key.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower()))) with dbL: - with shelve.open("edt", writeback=True) as db: - msg = lang(db[user_id], "notif_info").format(db[user_id].nt, db[user_id].nt_time, db[user_id].nt_cooldown) + user = session.query(User).filter_by(id=message.from_user.id).first() + msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) @dp.callback_query_handler(posts_cb.filter(action=["toggle", "time", "cooldown"])) async def notif_query(query: types.CallbackQuery, callback_data: dict): - user_id = str(query.from_user.id) await query.message.chat.do(types.ChatActions.TYPING) logger.info(f"{query.message.from_user.username} do notif query") with dbL: - with shelve.open("edt", writeback=True) as db: - if callback_data["action"] == "toggle": - if db[user_id].nt: - res = False - else: - res = True + user = session.query(User).filter_by(id=query.from_user.id).first() + if callback_data["action"] == "toggle": + if user.nt: + res = False + else: + res = True - db[user_id].nt = res - msg = lang(db[user_id], "notif_set").format(res) + user.nt = res + msg = lang(user, "notif_set").format(res) - elif callback_data["action"] in ["time", "cooldown"]: - db[user_id].await_cmd = callback_data["action"] - msg = lang(db[user_id], "notif_await") + elif callback_data["action"] in ["time", "cooldown"]: + user.await_cmd = callback_data["action"] + msg = lang(user, "notif_await") + session.commit() await query.message.reply(msg, parse_mode=ParseMode.MARKDOWN) @dp.message_handler(lambda msg: have_await_cmd(msg), content_types=[ContentType.TEXT, ContentType.PHOTO]) async def await_cmd(message: types.message): - user_id = str(message.from_user.id) await message.chat.do(types.ChatActions.TYPING) msg = None with dbL: - with shelve.open("edt", writeback=True) as db: - logger.info(f"{message.from_user.username} do awaited commande: {db[user_id].await_cmd}") - if db[user_id].await_cmd == "setedt": - url = str() - if message.photo: - file_path = await bot.get_file(message.photo[0].file_id) - file_url = f"https://api.telegram.org/file/bot{API_TOKEN}/{file_path['file_path']}" - qr = decode(Image.open(requests.get(file_url, stream=True).raw)) - if qr: - url = str(qr[0].data) - elif message.text: - msg_url = re.findall( - "http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+", message.text) - if msg_url: - url = msg_url[0] + user = session.query(User).filter_by(id=message.from_user.id).first() + logger.info(f"{message.from_user.username} do awaited commande: {user.await_cmd}") + if user.await_cmd == "setedt": + url = str() + if message.photo: + file_path = await bot.get_file(message.photo[0].file_id) + file_url = f"https://api.telegram.org/file/bot{API_TOKEN}/{file_path['file_path']}" + qr = decode(Image.open(requests.get(file_url, stream=True).raw)) + if qr: + url = str(qr[0].data) + elif message.text: + msg_url = re.findall( + "http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+", message.text) + if msg_url: + url = msg_url[0] - if url: - resources = url[url.find("resources") + 10:][:4] - elif message.text: - resources = message.text + if url: + resources = url[url.find("resources") + 10:][:4] + elif message.text: + resources = message.text - try: - Calendar("", int(resources)) - except (ParseError, ConnectionError, InvalidSchema, MissingSchema, ValueError, UnboundLocalError): - msg = lang(db[user_id], "setedt_err_res") + try: + Calendar("", int(resources)) + except (ParseError, ConnectionError, InvalidSchema, MissingSchema, ValueError, UnboundLocalError): + msg = lang(user, "setedt_err_res") + else: + user.resources = int(resources) + msg = lang(user, "setedt") + + elif user.await_cmd == "setkfet": + try: + int(message.text) + except ValueError: + msg = lang(user, "err_num") + else: + user.kfet = int(message.text) + msg = lang(user, "kfet_set") + + elif user.await_cmd == "settomuss": + if not len(parse(message.text).entries): + msg = lang(user, "settomuss_error") + else: + user.tomuss_rss = message.text + msg = lang(user, "settomuss") + + elif user.await_cmd in ["time", "cooldown"]: + try: + value = int(message.text) + except ValueError: + msg = lang(user, "err_num") + else: + if user.await_cmd == "time": + user.nt_time = value else: - db[user_id].resources = int(resources) - msg = lang(db[user_id], "setedt") + user.nt_cooldown = value - elif db[user_id].await_cmd == "setkfet": - try: - int(message.text) - except ValueError: - msg = lang(db[user_id], "err_num") - else: - db[user_id].kfet = int(message.text) - msg = lang(db[user_id], "kfet_set") + msg = lang(user, "notif_time_cooldown").format(user.await_cmd[6:], value) - elif db[user_id].await_cmd == "settomuss": - if not len(parse(message.text).entries): - msg = lang(db[user_id], "settomuss_error") - else: - db[user_id].tomuss_rss = message.text - msg = lang(db[user_id], "settomuss") - - elif db[user_id].await_cmd in ["time", "cooldown"]: - try: - value = int(message.text) - except ValueError: - msg = lang(db[user_id], "err_num") - else: - if db[user_id].await_cmd == "time": - db[user_id].nt_time = value - else: - db[user_id].nt_cooldown = value - - msg = lang(db[user_id], "notif_time_cooldown").format(db[user_id].await_cmd[6:], value) - - if db[user_id].await_cmd: - db[user_id].await_cmd = str() + if user.await_cmd: + user.await_cmd = str() + session.commit() if msg: await message.reply(msg, parse_mode=ParseMode.MARKDOWN) @@ -396,12 +401,14 @@ async def get_db(message: types.Message): logger.info(f"{message.from_user.username} do getdb command") if message.from_user.id == ADMIN_ID: with dbL: - with shelve.open("edt") as db: - msg = markdown.text( - markdown.italic("db:"), - markdown.code(dict(db)), - sep="\n" - ) + users = dict() + for u in session.query(User).all(): + users[u] = u.__dict__ + msg = markdown.text( + markdown.italic("db:"), + markdown.code(users), + sep="\n" + ) await message.reply(msg, parse_mode=ParseMode.MARKDOWN) diff --git a/lang.py b/lang.py index d05f46e..6e2a05d 100644 --- a/lang.py +++ b/lang.py @@ -1,5 +1,5 @@ import json -from EDTuser import User +from base import User LANG = ["en"] diff --git a/requirements.txt b/requirements.txt index 9c57d26..38f7955 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +sqlalchemy feedparser aiogram==2.3 aiohttp==3.6.0