Add a configurable worker count and type-annotate database connections.

main
Joshua Potter 2023-12-09 16:53:17 -07:00
parent 3ab4f893b7
commit 58815d3ae5
2 changed files with 31 additions and 20 deletions

View File

@ -1,5 +1,6 @@
import argparse
import asyncio
from dataclasses import dataclass
from typing import List
import aiohttp
@ -11,35 +12,38 @@ from coach_scraper.database import backup_database, load_languages
from coach_scraper.lichess import Pipeline as LichessPipeline
from coach_scraper.types import Site
# The number of parallel extraction jobs that are run at a time.
WORKER_COUNT = 10
@dataclass
class Context:
    """Bundle of run-wide state handed to `_entrypoint` and `_process`."""

    # Database connection forwarded to each pipeline's `process` call.
    # `psycopg2.extensions.connection` is the public name of the class that
    # the private `psycopg2._psycopg` module implements.
    conn: psycopg2.extensions.connection
    # Language detector forwarded to each pipeline's `process` call.
    detector: LanguageDetector
    # Passed as `worker_count` when constructing a site pipeline.
    worker_count: int
    # Interpolated into the `User-Agent` header of the shared HTTP session.
    user_agent: str
async def _process(
    site: Site,
    context: Context,
    session: aiohttp.ClientSession,
):
    """Run the pipeline that corresponds to `site`.

    Args:
        site: Which site's pipeline to run.
        context: Supplies the database connection, language detector and
            worker count that are forwarded to the pipeline.
        session: HTTP session handed to the pipeline's `process` call.

    Raises:
        AssertionError: If `site` is neither `Site.CHESSCOM` nor
            `Site.LICHESS`.
    """
    if site == Site.CHESSCOM:
        await ChesscomPipeline(worker_count=context.worker_count).process(
            context.conn, context.detector, session
        )
    elif site == Site.LICHESS:
        await LichessPipeline(worker_count=context.worker_count).process(
            context.conn, context.detector, session
        )
    else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"Encountered unknown site: {site}.")
async def _entrypoint(context: Context, sites: List[Site]):
    """Top-level entrypoint that dispatches a pipeline per requested site.

    Opens one `aiohttp.ClientSession` whose `User-Agent` header embeds
    `context.user_agent`, then awaits `_process` for every entry of `sites`
    via `asyncio.gather`, all sharing that session and `context`.

    Args:
        context: Run-wide state forwarded unchanged to each `_process` call.
        sites: The sites to process.
    """
    async with aiohttp.ClientSession(
        headers={"User-Agent": f"BoardWise coach-scraper ({context.user_agent})"}
    ) as session:
        await asyncio.gather(*[_process(site, context, session) for site in sites])
def main():
@ -67,6 +71,9 @@ def main():
],
)
# Other.
parser.add_argument("--workers", type=int, default=5)
args = parser.parse_args()
detector = LanguageDetectorBuilder.from_all_languages().build()
@ -84,9 +91,12 @@ def main():
load_languages(conn)
asyncio.run(
_entrypoint(
Context(
conn=conn,
detector=detector,
user_agent=args.user_agent,
worker_count=args.workers,
),
sites=list(map(Site, set(args.site))),
)
)

View File

@ -3,6 +3,7 @@ import sys
from datetime import datetime
from typing import List, Literal
import psycopg2
from typing_extensions import TypedDict
from coach_scraper.locale import Locale, locale_to_str, native_to_locale
@ -52,7 +53,7 @@ class Row(TypedDict, total=False):
bullet: int
def load_languages(conn):
def load_languages(conn: psycopg2._psycopg.connection):
"""Load all known languages into the languages table."""
cursor = None
try:
@ -77,7 +78,7 @@ def load_languages(conn):
cursor.close()
def backup_database(conn):
def backup_database(conn: psycopg2._psycopg.connection):
"""Creates a backup of the export table.
Simply copies the table at time of invocation into another table with a
@ -113,7 +114,7 @@ def backup_database(conn):
cursor.close()
def upsert_row(conn, row: Row):
def upsert_row(conn: psycopg2._psycopg.connection, row: Row):
"""Upsert the specified `Row` into the database table."""
cursor = None
try: