1
0
Fork 0

Switch to SQLAlchemy

This commit is contained in:
Ethanell 2019-12-25 17:57:59 +01:00
parent f54dd1c856
commit 5d49fec02a
4 changed files with 195 additions and 176 deletions

View file

@ -2,27 +2,30 @@ import datetime
import requests import requests
from EDTcalendar import Calendar from EDTcalendar import Calendar
from feedparser import parse from feedparser import parse
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, Boolean, Date
KFET_URL = "http://kfet.bdeinfo.org/orders" KFET_URL = "http://kfet.bdeinfo.org/orders"
Base = declarative_base()
def get_now(): def get_now():
return datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None) return datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None)
class User: class User(Base):
def __init__(self, user_id: int, language: str): __tablename__ = "user"
self.id = user_id id = Column(Integer, primary_key=True, unique=True)
self.language = language language = Column(String, default="")
self.resources = None resources = Column(Integer)
self.nt = False nt = Column(Boolean, default=False)
self.nt_time = 20 nt_time = Column(Integer, default=20)
self.nt_cooldown = 20 nt_cooldown = Column(Integer, default=20)
self.nt_last = get_now() nt_last = Column(Date, default=get_now)
self.kfet = None kfet = Column(Integer, default=0)
self.await_cmd = str() await_cmd = Column(String, default="")
self.tomuss_rss = str() tomuss_rss = Column(String)
self.tomuss_last = str() tomuss_last = Column(String)
def calendar(self, time: str = "", pass_week: bool = False): def calendar(self, time: str = "", pass_week: bool = False):
return Calendar(time, self.resources, pass_week=pass_week) return Calendar(time, self.resources, pass_week=pass_week)
@ -55,4 +58,12 @@ class User:
entry = [e for e in parse(self.tomuss_rss).entries] entry = [e for e in parse(self.tomuss_rss).entries]
if not self.tomuss_last: if not self.tomuss_last:
return entry return entry
return entry[self.tomuss_last:] tomuss_last = 0
for i,e in enumerate(entry):
if str(e) == self.tomuss_last:
tomuss_last = i+1
break
return entry[tomuss_last:]
def __repr__(self):
return f"<User: {self.id}>"

171
bot.py
View file

@ -2,7 +2,8 @@ import asyncio
import datetime import datetime
import hashlib import hashlib
import logging import logging
import shelve from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import re import re
import requests import requests
from asyncio import sleep from asyncio import sleep
@ -16,7 +17,7 @@ from aiogram.utils import markdown
from aiogram.utils.callback_data import CallbackData from aiogram.utils.callback_data import CallbackData
from aiogram.utils.exceptions import MessageIsTooLong from aiogram.utils.exceptions import MessageIsTooLong
from EDTcalendar import Calendar from EDTcalendar import Calendar
from EDTuser import User, KFET_URL from base import User, KFET_URL, Base
from lang import lang from lang import lang
from ics.parse import ParseError from ics.parse import ParseError
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@ -24,11 +25,13 @@ from pyzbar.pyzbar import decode
from PIL import Image from PIL import Image
from feedparser import parse from feedparser import parse
tables = False
if not isdir("logs"): if not isdir("logs"):
mkdir("logs") mkdir("logs")
if not isdir("calendars"): if not isdir("calendars"):
mkdir("calendars") mkdir("calendars")
if not isfile("edt.db"):
tables = True
logger = logging.getLogger("TelegramEDT") logger = logging.getLogger("TelegramEDT")
log_date = datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None).date() log_date = datetime.datetime.now(datetime.timezone.utc).astimezone(tz=None).date()
@ -48,6 +51,11 @@ TIMES = ["", "day", "next", "week", "next week"]
bot = Bot(token=API_TOKEN) bot = Bot(token=API_TOKEN)
posts_cb = CallbackData("post", "id", "action") posts_cb = CallbackData("post", "id", "action")
dp = Dispatcher(bot) dp = Dispatcher(bot)
engine = create_engine("sqlite:///edt.db")
Session = sessionmaker(bind=engine)
session = Session()
if tables:
Base.metadata.create_all(engine)
dbL = RLock() dbL = RLock()
@ -57,8 +65,7 @@ def get_now():
def have_await_cmd(msg: types.Message): def have_await_cmd(msg: types.Message):
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: return session.query(User).filter_by(id=msg.from_user.id).first().await_cmd
return db[str(msg.from_user.id)].await_cmd
def edt_key(): def edt_key():
@ -70,8 +77,7 @@ def edt_key():
def calendar(time: str, user_id: int): def calendar(time: str, user_id: int):
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=user_id).first()
user = db[str(user_id)]
if not user.resources: if not user.resources:
return lang(user, "edt_err_set") return lang(user, "edt_err_set")
elif time not in TIMES: elif time not in TIMES:
@ -82,28 +88,27 @@ def calendar(time: str, user_id: int):
async def notif(): async def notif():
while True: while True:
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: for u in session.query(User).all():
for u in db:
nt = None nt = None
kf = None kf = None
tm = None tm = None
try: try:
nt = db[u].get_notif() nt = u.get_notif()
kf = db[u].get_kfet() kf = u.get_kfet()
tm = db[u].get_tomuss() tm = u.get_tomuss()
except: except:
pass pass
if nt: if nt:
await bot.send_message(int(u), lang(db[u], "notif_event")+str(nt), parse_mode=ParseMode.MARKDOWN) await bot.send_message(u.id, lang(u, "notif_event")+str(nt), parse_mode=ParseMode.MARKDOWN)
if kf: if kf:
if kf == 1: if kf == 1:
kf = lang(db[u], "kfet") kf = lang(u, "kfet")
elif kf == 2: elif kf == 2:
kf = lang(db[u], "kfet_prb") kf = lang(u, "kfet_prb")
else: else:
kf = lang(db[u], "kfet_err") kf = lang(u, "kfet_err")
await bot.send_message(int(u), kf, parse_mode=ParseMode.MARKDOWN) await bot.send_message(u.id, kf, parse_mode=ParseMode.MARKDOWN)
if tm: if tm:
for i in tm: for i in tm:
msg = markdown.text( msg = markdown.text(
@ -111,9 +116,9 @@ async def notif():
markdown.code(i.summary.replace("<br>", "\n").replace("<b>", "").replace("</b>", "")), markdown.code(i.summary.replace("<br>", "\n").replace("<b>", "").replace("</b>", "")),
sep="\n" sep="\n"
) )
await bot.send_message(int(u), msg, parse_mode=ParseMode.MARKDOWN) await bot.send_message(u.id, msg, parse_mode=ParseMode.MARKDOWN)
db[u].tomuss_last = i u.tomuss_last = str(i)
session.commit()
await sleep(60) await sleep(60)
@ -134,19 +139,19 @@ async def inline_edt(inline_query: InlineQuery):
@dp.message_handler(commands="start") @dp.message_handler(commands="start")
async def start(message: types.Message): async def start(message: types.Message):
user_id = str(message.from_user.id) user_id = message.from_user.id
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} start") logger.info(f"{message.from_user.username} start")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: if user_id not in session.query(User.id).all():
if user_id not in db:
logger.info(f"{message.from_user.username} add to the db") logger.info(f"{message.from_user.username} add to the db")
if message.from_user.locale and message.from_user.locale.language: if message.from_user.locale and message.from_user.locale.language:
lg = message.from_user.locale.language lg = message.from_user.locale.language
else: else:
lg = "" lg = ""
db[user_id] = User(int(user_id), lg) session.add(User(id=user_id, language=lg))
user = db[user_id] session.commit()
user = session.query(User).filter_by(id=user_id).first()
key = reply_keyboard.ReplyKeyboardMarkup() key = reply_keyboard.ReplyKeyboardMarkup()
key.add(reply_keyboard.KeyboardButton("Edt")) key.add(reply_keyboard.KeyboardButton("Edt"))
key.add(reply_keyboard.KeyboardButton("Kfet")) key.add(reply_keyboard.KeyboardButton("Kfet"))
@ -162,8 +167,7 @@ async def help_cmd(message: types.Message):
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do help command") logger.info(f"{message.from_user.username} do help command")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
user = db[str(message.from_user.id)]
await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN) await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN)
@ -183,15 +187,14 @@ async def edt_query(query: types.CallbackQuery, callback_data: dict):
@dp.message_handler(lambda msg: msg.text.lower() == "kfet") @dp.message_handler(lambda msg: msg.text.lower() == "kfet")
async def kfet(message: types.Message): async def kfet(message: types.Message):
user_id = str(message.from_user.id)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do kfet") logger.info(f"{message.from_user.username} do kfet")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6: if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
msg = lang(db[user_id], "kfet_close") msg = lang(user, "kfet_close")
else: else:
msg = lang(db[user_id], "kfet_list") msg = lang(user, "kfet_list")
cmds = requests.get(KFET_URL).json() cmds = requests.get(KFET_URL).json()
if cmds: if cmds:
for c in cmds: for c in cmds:
@ -201,29 +204,30 @@ async def kfet(message: types.Message):
@dp.message_handler(lambda msg: msg.text.lower() == "setkfet") @dp.message_handler(lambda msg: msg.text.lower() == "setkfet")
async def kfet_set(message: types.Message): async def kfet_set(message: types.Message):
user_id = str(message.from_user.id)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setkfet") logger.info(f"{message.from_user.username} do setkfet")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5: if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
msg = lang(db[user_id], "kfet_close") msg = lang(user, "kfet_close")
else: else:
db[user_id].await_cmd = "setkfet" user.await_cmd = "setkfet"
msg = lang(db[user_id], "kfet_set_await") msg = lang(user, "kfet_set_await")
session.commit()
await message.reply(msg, parse_mode=ParseMode.MARKDOWN) await message.reply(msg, parse_mode=ParseMode.MARKDOWN)
@dp.message_handler(lambda msg: msg.text.lower() == "setedt") @dp.message_handler(lambda msg: msg.text.lower() == "setedt")
async def edt_await(message: types.Message): async def edt_await(message: types.Message):
user_id = str(message.from_user.id)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setedt") logger.info(f"{message.from_user.username} do setedt")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
db[user_id].await_cmd = "setedt" user.await_cmd = "setedt"
await message.reply(lang(db[user_id], "setedt_wait"), parse_mode=ParseMode.MARKDOWN) session.commit()
await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN)
@dp.message_handler(lambda msg: msg.text.lower() == "settomuss") @dp.message_handler(lambda msg: msg.text.lower() == "settomuss")
@ -232,9 +236,11 @@ async def edt_await(message: types.Message):
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do settomuss") logger.info(f"{message.from_user.username} do settomuss")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
db[user_id].await_cmd = "settomuss" user.await_cmd = "settomuss"
await message.reply(lang(db[user_id], "settomuss_wait"), parse_mode=ParseMode.MARKDOWN) session.commit()
await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN)
@dp.message_handler(commands="getedt") @dp.message_handler(commands="getedt")
@ -243,59 +249,57 @@ async def edt_geturl(message: types.Message):
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do getedt command") logger.info(f"{message.from_user.username} do getedt command")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
if db[user_id].resources: if user.resources:
await message.reply(db[user_id].resources) await message.reply(user.resources)
else: else:
await message.reply(lang(db[user_id], "getedt_err")) await message.reply(lang(user, "getedt_err"))
@dp.message_handler(lambda msg: msg.text.lower() == "notif") @dp.message_handler(lambda msg: msg.text.lower() == "notif")
async def notif_cmd(message: types.Message): async def notif_cmd(message: types.Message):
user_id = str(message.from_user.id)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do notif") logger.info(f"{message.from_user.username} do notif")
key = InlineKeyboardMarkup() key = InlineKeyboardMarkup()
for i, n in enumerate(["Toggle", "Time", "Cooldown"]): for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
key.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower()))) key.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
msg = lang(db[user_id], "notif_info").format(db[user_id].nt, db[user_id].nt_time, db[user_id].nt_cooldown) msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
@dp.callback_query_handler(posts_cb.filter(action=["toggle", "time", "cooldown"])) @dp.callback_query_handler(posts_cb.filter(action=["toggle", "time", "cooldown"]))
async def notif_query(query: types.CallbackQuery, callback_data: dict): async def notif_query(query: types.CallbackQuery, callback_data: dict):
user_id = str(query.from_user.id)
await query.message.chat.do(types.ChatActions.TYPING) await query.message.chat.do(types.ChatActions.TYPING)
logger.info(f"{query.message.from_user.username} do notif query") logger.info(f"{query.message.from_user.username} do notif query")
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=query.from_user.id).first()
if callback_data["action"] == "toggle": if callback_data["action"] == "toggle":
if db[user_id].nt: if user.nt:
res = False res = False
else: else:
res = True res = True
db[user_id].nt = res user.nt = res
msg = lang(db[user_id], "notif_set").format(res) msg = lang(user, "notif_set").format(res)
elif callback_data["action"] in ["time", "cooldown"]: elif callback_data["action"] in ["time", "cooldown"]:
db[user_id].await_cmd = callback_data["action"] user.await_cmd = callback_data["action"]
msg = lang(db[user_id], "notif_await") msg = lang(user, "notif_await")
session.commit()
await query.message.reply(msg, parse_mode=ParseMode.MARKDOWN) await query.message.reply(msg, parse_mode=ParseMode.MARKDOWN)
@dp.message_handler(lambda msg: have_await_cmd(msg), content_types=[ContentType.TEXT, ContentType.PHOTO]) @dp.message_handler(lambda msg: have_await_cmd(msg), content_types=[ContentType.TEXT, ContentType.PHOTO])
async def await_cmd(message: types.message): async def await_cmd(message: types.message):
user_id = str(message.from_user.id)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
msg = None msg = None
with dbL: with dbL:
with shelve.open("edt", writeback=True) as db: user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited commande: {db[user_id].await_cmd}") logger.info(f"{message.from_user.username} do awaited commande: {user.await_cmd}")
if db[user_id].await_cmd == "setedt": if user.await_cmd == "setedt":
url = str() url = str()
if message.photo: if message.photo:
file_path = await bot.get_file(message.photo[0].file_id) file_path = await bot.get_file(message.photo[0].file_id)
@ -317,42 +321,43 @@ async def await_cmd(message: types.message):
try: try:
Calendar("", int(resources)) Calendar("", int(resources))
except (ParseError, ConnectionError, InvalidSchema, MissingSchema, ValueError, UnboundLocalError): except (ParseError, ConnectionError, InvalidSchema, MissingSchema, ValueError, UnboundLocalError):
msg = lang(db[user_id], "setedt_err_res") msg = lang(user, "setedt_err_res")
else: else:
db[user_id].resources = int(resources) user.resources = int(resources)
msg = lang(db[user_id], "setedt") msg = lang(user, "setedt")
elif db[user_id].await_cmd == "setkfet": elif user.await_cmd == "setkfet":
try: try:
int(message.text) int(message.text)
except ValueError: except ValueError:
msg = lang(db[user_id], "err_num") msg = lang(user, "err_num")
else: else:
db[user_id].kfet = int(message.text) user.kfet = int(message.text)
msg = lang(db[user_id], "kfet_set") msg = lang(user, "kfet_set")
elif db[user_id].await_cmd == "settomuss": elif user.await_cmd == "settomuss":
if not len(parse(message.text).entries): if not len(parse(message.text).entries):
msg = lang(db[user_id], "settomuss_error") msg = lang(user, "settomuss_error")
else: else:
db[user_id].tomuss_rss = message.text user.tomuss_rss = message.text
msg = lang(db[user_id], "settomuss") msg = lang(user, "settomuss")
elif db[user_id].await_cmd in ["time", "cooldown"]: elif user.await_cmd in ["time", "cooldown"]:
try: try:
value = int(message.text) value = int(message.text)
except ValueError: except ValueError:
msg = lang(db[user_id], "err_num") msg = lang(user, "err_num")
else: else:
if db[user_id].await_cmd == "time": if user.await_cmd == "time":
db[user_id].nt_time = value user.nt_time = value
else: else:
db[user_id].nt_cooldown = value user.nt_cooldown = value
msg = lang(db[user_id], "notif_time_cooldown").format(db[user_id].await_cmd[6:], value) msg = lang(user, "notif_time_cooldown").format(user.await_cmd[6:], value)
if db[user_id].await_cmd: if user.await_cmd:
db[user_id].await_cmd = str() user.await_cmd = str()
session.commit()
if msg: if msg:
await message.reply(msg, parse_mode=ParseMode.MARKDOWN) await message.reply(msg, parse_mode=ParseMode.MARKDOWN)
@ -396,10 +401,12 @@ async def get_db(message: types.Message):
logger.info(f"{message.from_user.username} do getdb command") logger.info(f"{message.from_user.username} do getdb command")
if message.from_user.id == ADMIN_ID: if message.from_user.id == ADMIN_ID:
with dbL: with dbL:
with shelve.open("edt") as db: users = dict()
for u in session.query(User).all():
users[u] = u.__dict__
msg = markdown.text( msg = markdown.text(
markdown.italic("db:"), markdown.italic("db:"),
markdown.code(dict(db)), markdown.code(users),
sep="\n" sep="\n"
) )
await message.reply(msg, parse_mode=ParseMode.MARKDOWN) await message.reply(msg, parse_mode=ParseMode.MARKDOWN)

View file

@ -1,5 +1,5 @@
import json import json
from EDTuser import User from base import User
LANG = ["en"] LANG = ["en"]

View file

@ -1,3 +1,4 @@
sqlalchemy
feedparser feedparser
aiogram==2.3 aiogram==2.3
aiohttp==3.6.0 aiohttp==3.6.0