Add worker count and type database connections.

main
Joshua Potter 2023-12-09 16:53:17 -07:00
parent 3ab4f893b7
commit 58815d3ae5
2 changed files with 31 additions and 20 deletions

View File

@ -1,5 +1,6 @@
import argparse import argparse
import asyncio import asyncio
from dataclasses import dataclass
from typing import List from typing import List
import aiohttp import aiohttp
@ -11,35 +12,38 @@ from coach_scraper.database import backup_database, load_languages
from coach_scraper.lichess import Pipeline as LichessPipeline from coach_scraper.lichess import Pipeline as LichessPipeline
from coach_scraper.types import Site from coach_scraper.types import Site
# The number of parallel extraction jobs that are run at a time.
WORKER_COUNT = 10 @dataclass
class Context:
conn: psycopg2._psycopg.connection
detector: LanguageDetector
worker_count: int
user_agent: str
async def _process( async def _process(
site: Site, conn, detector: LanguageDetector, session: aiohttp.ClientSession site: Site,
context: Context,
session: aiohttp.ClientSession,
): ):
if site == Site.CHESSCOM: if site == Site.CHESSCOM:
await ChesscomPipeline(worker_count=WORKER_COUNT).process( await ChesscomPipeline(worker_count=context.worker_count).process(
conn, detector, session context.conn, context.detector, session
) )
elif site == Site.LICHESS: elif site == Site.LICHESS:
await LichessPipeline(worker_count=WORKER_COUNT).process( await LichessPipeline(worker_count=context.worker_count).process(
conn, detector, session context.conn, context.detector, session
) )
else: else:
assert False, f"Encountered unknown site: {site}." assert False, f"Encountered unknown site: {site}."
async def _entrypoint( async def _entrypoint(context: Context, sites: List[Site]):
conn, detector: LanguageDetector, user_agent: str, sites: List[Site]
):
"""Top-level entrypoint that dispatches a pipeline per requested site.""" """Top-level entrypoint that dispatches a pipeline per requested site."""
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
headers={"User-Agent": f"BoardWise coach-scraper ({user_agent})"} headers={"User-Agent": f"BoardWise coach-scraper ({context.user_agent})"}
) as session: ) as session:
await asyncio.gather( await asyncio.gather(*[_process(site, context, session) for site in sites])
*[_process(site, conn, detector, session) for site in sites]
)
def main(): def main():
@ -67,6 +71,9 @@ def main():
], ],
) )
# Other.
parser.add_argument("--workers", type=int, default=5)
args = parser.parse_args() args = parser.parse_args()
detector = LanguageDetectorBuilder.from_all_languages().build() detector = LanguageDetectorBuilder.from_all_languages().build()
@ -84,9 +91,12 @@ def main():
load_languages(conn) load_languages(conn)
asyncio.run( asyncio.run(
_entrypoint( _entrypoint(
conn=conn, Context(
detector=detector, conn=conn,
user_agent=args.user_agent, detector=detector,
user_agent=args.user_agent,
worker_count=args.workers,
),
sites=list(map(Site, set(args.site))), sites=list(map(Site, set(args.site))),
) )
) )

View File

@ -3,6 +3,7 @@ import sys
from datetime import datetime from datetime import datetime
from typing import List, Literal from typing import List, Literal
import psycopg2
from typing_extensions import TypedDict from typing_extensions import TypedDict
from coach_scraper.locale import Locale, locale_to_str, native_to_locale from coach_scraper.locale import Locale, locale_to_str, native_to_locale
@ -52,7 +53,7 @@ class Row(TypedDict, total=False):
bullet: int bullet: int
def load_languages(conn): def load_languages(conn: psycopg2._psycopg.connection):
"""Load all known languages into the languages table.""" """Load all known languages into the languages table."""
cursor = None cursor = None
try: try:
@ -77,7 +78,7 @@ def load_languages(conn):
cursor.close() cursor.close()
def backup_database(conn): def backup_database(conn: psycopg2._psycopg.connection):
"""Creates a backup of the export table. """Creates a backup of the export table.
Simply copies the table at time of invocation into another table with a Simply copies the table at time of invocation into another table with a
@ -113,7 +114,7 @@ def backup_database(conn):
cursor.close() cursor.close()
def upsert_row(conn, row: Row): def upsert_row(conn: psycopg2._psycopg.connection, row: Row):
"""Upsert the specified `Row` into the database table.""" """Upsert the specified `Row` into the database table."""
cursor = None cursor = None
try: try: