From 58815d3ae5a69cac12436a01e77019a5ac5d16a7 Mon Sep 17 00:00:00 2001
From: Joshua Potter
Date: Sat, 9 Dec 2023 16:53:17 -0700
Subject: [PATCH] Add worker count and type database connections.

---
 coach_scraper/__main__.py | 44 ++++++++++++++++++++++++---------------
 coach_scraper/database.py |  7 ++++---
 2 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/coach_scraper/__main__.py b/coach_scraper/__main__.py
index 4a760f1..c1f3e58 100644
--- a/coach_scraper/__main__.py
+++ b/coach_scraper/__main__.py
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+from dataclasses import dataclass
 from typing import List
 
 import aiohttp
@@ -11,35 +12,38 @@ from coach_scraper.database import backup_database, load_languages
 from coach_scraper.lichess import Pipeline as LichessPipeline
 from coach_scraper.types import Site
 
-# The number of parallel extraction jobs that are run at a time.
-WORKER_COUNT = 10
+
+@dataclass
+class Context:
+    conn: psycopg2._psycopg.connection
+    detector: LanguageDetector
+    worker_count: int
+    user_agent: str
 
 
 async def _process(
-    site: Site, conn, detector: LanguageDetector, session: aiohttp.ClientSession
+    site: Site,
+    context: Context,
+    session: aiohttp.ClientSession,
 ):
     if site == Site.CHESSCOM:
-        await ChesscomPipeline(worker_count=WORKER_COUNT).process(
-            conn, detector, session
+        await ChesscomPipeline(worker_count=context.worker_count).process(
+            context.conn, context.detector, session
         )
     elif site == Site.LICHESS:
-        await LichessPipeline(worker_count=WORKER_COUNT).process(
-            conn, detector, session
+        await LichessPipeline(worker_count=context.worker_count).process(
+            context.conn, context.detector, session
         )
     else:
         assert False, f"Encountered unknown site: {site}."
 
 
-async def _entrypoint(
-    conn, detector: LanguageDetector, user_agent: str, sites: List[Site]
-):
+async def _entrypoint(context: Context, sites: List[Site]):
     """Top-level entrypoint that dispatches a pipeline per requested site."""
     async with aiohttp.ClientSession(
-        headers={"User-Agent": f"BoardWise coach-scraper ({user_agent})"}
+        headers={"User-Agent": f"BoardWise coach-scraper ({context.user_agent})"}
     ) as session:
-        await asyncio.gather(
-            *[_process(site, conn, detector, session) for site in sites]
-        )
+        await asyncio.gather(*[_process(site, context, session) for site in sites])
 
 
 def main():
@@ -67,6 +71,9 @@ def main():
         ],
     )
 
+    # Other.
+    parser.add_argument("--workers", type=int, default=5)
+
     args = parser.parse_args()
 
     detector = LanguageDetectorBuilder.from_all_languages().build()
@@ -84,9 +91,12 @@ def main():
         load_languages(conn)
         asyncio.run(
             _entrypoint(
-                conn=conn,
-                detector=detector,
-                user_agent=args.user_agent,
+                Context(
+                    conn=conn,
+                    detector=detector,
+                    user_agent=args.user_agent,
+                    worker_count=args.workers,
+                ),
                 sites=list(map(Site, set(args.site))),
             )
         )
diff --git a/coach_scraper/database.py b/coach_scraper/database.py
index 6705789..abf6ca1 100644
--- a/coach_scraper/database.py
+++ b/coach_scraper/database.py
@@ -3,6 +3,7 @@ import sys
 from datetime import datetime
 from typing import List, Literal
 
+import psycopg2
 from typing_extensions import TypedDict
 
 from coach_scraper.locale import Locale, locale_to_str, native_to_locale
@@ -52,7 +53,7 @@ class Row(TypedDict, total=False):
     bullet: int
 
 
-def load_languages(conn):
+def load_languages(conn: psycopg2._psycopg.connection):
     """Load all known languages into the languages table."""
     cursor = None
     try:
@@ -77,7 +78,7 @@ def load_languages(conn):
             cursor.close()
 
 
-def backup_database(conn):
+def backup_database(conn: psycopg2._psycopg.connection):
     """Creates a backup of the export table.
 
     Simply copies the table at time of invocation into another table with a
@@ -113,7 +114,7 @@ def backup_database(conn):
             cursor.close()
 
 
-def upsert_row(conn, row: Row):
+def upsert_row(conn: psycopg2._psycopg.connection, row: Row):
     """Upsert the specified `Row` into the database table."""
     cursor = None
     try:
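
A minimal usage sketch, illustration only and not part of the commit: after this change the psycopg2 connection, language detector, worker count, and User-Agent contact string travel together in the Context dataclass, which keeps the _process and _entrypoint signatures stable as further per-run options such as --workers are added. The snippet below wires that up by hand, mirroring what main() does after argument parsing. The DSN and user-agent values are placeholders, and the lingua import is an assumption based on the LanguageDetector names used in the diff.

    # Sketch: build the Context manually and hand it to _entrypoint.
    # Placeholder values are marked; swap in real ones before running.
    import asyncio

    import psycopg2
    from lingua import LanguageDetectorBuilder  # assumed package for the detector

    from coach_scraper.__main__ import Context, _entrypoint
    from coach_scraper.types import Site

    conn = psycopg2.connect("dbname=coach_scraper")  # placeholder DSN
    try:
        context = Context(
            conn=conn,
            detector=LanguageDetectorBuilder.from_all_languages().build(),
            worker_count=5,  # matches the new --workers default
            user_agent="you@example.com",  # placeholder contact string
        )
        # One pipeline per requested site, run concurrently by _entrypoint.
        asyncio.run(_entrypoint(context, sites=[Site.CHESSCOM, Site.LICHESS]))
    finally:
        conn.close()

On the type annotations: psycopg2._psycopg is a private module; psycopg2.extensions.connection names the same connection class through a public path, if avoiding the leading-underscore import is preferred.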