Add worker count and type database connections.

parent 3ab4f893b7
commit 58815d3ae5
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+from dataclasses import dataclass
 from typing import List
 
 import aiohttp
@@ -11,35 +12,38 @@ from coach_scraper.database import backup_database, load_languages
 from coach_scraper.lichess import Pipeline as LichessPipeline
 from coach_scraper.types import Site
 
-# The number of parallel extraction jobs that are run at a time.
-WORKER_COUNT = 10
+
+@dataclass
+class Context:
+    conn: psycopg2._psycopg.connection
+    detector: LanguageDetector
+    worker_count: int
+    user_agent: str
 
 
 async def _process(
-    site: Site, conn, detector: LanguageDetector, session: aiohttp.ClientSession
+    site: Site,
+    context: Context,
+    session: aiohttp.ClientSession,
 ):
     if site == Site.CHESSCOM:
-        await ChesscomPipeline(worker_count=WORKER_COUNT).process(
-            conn, detector, session
+        await ChesscomPipeline(worker_count=context.worker_count).process(
+            context.conn, context.detector, session
         )
     elif site == Site.LICHESS:
-        await LichessPipeline(worker_count=WORKER_COUNT).process(
-            conn, detector, session
+        await LichessPipeline(worker_count=context.worker_count).process(
+            context.conn, context.detector, session
         )
     else:
         assert False, f"Encountered unknown site: {site}."
 
 
-async def _entrypoint(
-    conn, detector: LanguageDetector, user_agent: str, sites: List[Site]
-):
+async def _entrypoint(context: Context, sites: List[Site]):
     """Top-level entrypoint that dispatches a pipeline per requested site."""
     async with aiohttp.ClientSession(
-        headers={"User-Agent": f"BoardWise coach-scraper ({user_agent})"}
+        headers={"User-Agent": f"BoardWise coach-scraper ({context.user_agent})"}
     ) as session:
-        await asyncio.gather(
-            *[_process(site, conn, detector, session) for site in sites]
-        )
+        await asyncio.gather(*[_process(site, context, session) for site in sites])
 
 
 def main():
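The hunk above replaces the module-level WORKER_COUNT constant with a Context dataclass that threads the shared connection, language detector, worker count, and user agent through _process and _entrypoint. The standalone sketch below is only an illustration of that pattern, not the project's code: the field set is trimmed down, and the process/entrypoint names, site strings, and user-agent value are made up for the demo. It shows how the coroutine signatures stay fixed while settings ride along in the context.

import asyncio
from dataclasses import dataclass
from typing import List


# Simplified stand-in for the commit's Context; the real one also carries a
# psycopg2 connection and a lingua LanguageDetector.
@dataclass
class Context:
    worker_count: int
    user_agent: str


async def process(site: str, context: Context) -> str:
    # Each pipeline reads only the fields it needs from the shared context.
    await asyncio.sleep(0)
    return f"{site}: {context.worker_count} workers, UA={context.user_agent}"


async def entrypoint(context: Context, sites: List[str]) -> List[str]:
    # Mirrors _entrypoint: one concurrent task per requested site.
    return await asyncio.gather(*[process(site, context) for site in sites])


print(asyncio.run(entrypoint(Context(worker_count=5, user_agent="example"), ["chesscom", "lichess"])))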
@@ -67,6 +71,9 @@ def main():
         ],
     )
 
+    # Other.
+    parser.add_argument("--workers", type=int, default=5)
+
     args = parser.parse_args()
 
     detector = LanguageDetectorBuilder.from_all_languages().build()
@@ -84,9 +91,12 @@ def main():
         load_languages(conn)
         asyncio.run(
             _entrypoint(
-                conn=conn,
-                detector=detector,
-                user_agent=args.user_agent,
+                Context(
+                    conn=conn,
+                    detector=detector,
+                    user_agent=args.user_agent,
+                    worker_count=args.workers,
+                ),
                 sites=list(map(Site, set(args.site))),
             )
         )
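The two hunks above wire the new setting end to end: --workers is parsed in main() and handed to the pipelines via Context(worker_count=args.workers), replacing the hard-coded WORKER_COUNT = 10 with a default of 5. A minimal sketch of the flag in isolation (the bare ArgumentParser and the argv lists are assumptions for the demo, not the project's setup):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, default=5)

print(parser.parse_args([]).workers)                   # 5 -- the new default
print(parser.parse_args(["--workers", "10"]).workers)  # 10 -- the previously hard-coded value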
coach_scraper/database.py
@@ -3,6 +3,7 @@ import sys
 from datetime import datetime
 from typing import List, Literal
 
+import psycopg2
 from typing_extensions import TypedDict
 
 from coach_scraper.locale import Locale, locale_to_str, native_to_locale
@@ -52,7 +53,7 @@ class Row(TypedDict, total=False):
     bullet: int
 
 
-def load_languages(conn):
+def load_languages(conn: psycopg2._psycopg.connection):
     """Load all known languages into the languages table."""
     cursor = None
     try:
@@ -77,7 +78,7 @@ def load_languages(conn)
         cursor.close()
 
 
-def backup_database(conn):
+def backup_database(conn: psycopg2._psycopg.connection):
     """Creates a backup of the export table.
 
     Simply copies the table at time of invocation into another table with a
@@ -113,7 +114,7 @@ def backup_database(conn)
         cursor.close()
 
 
-def upsert_row(conn, row: Row):
+def upsert_row(conn: psycopg2._psycopg.connection, row: Row):
     """Upsert the specified `Row` into the database table."""
     cursor = None
    try:
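The database.py changes are annotation-only: load_languages, backup_database, and upsert_row now declare the connection type they expect, and import psycopg2 is added so the annotation resolves; runtime behavior is unchanged. psycopg2._psycopg.connection is the same class that psycopg2.extensions re-exports and that psycopg2.connect() returns by default, so the annotation names the object these functions already receive. A small sketch under that understanding (the check helper and the commented-out credentials are illustrative, not part of the commit):

import psycopg2
import psycopg2.extensions


# psycopg2.extensions re-exports the C-level connection class defined in
# psycopg2._psycopg, so both names refer to the same type.
assert psycopg2.extensions.connection is psycopg2._psycopg.connection


def check(conn: psycopg2._psycopg.connection) -> None:
    # With the parameter annotated, a type checker can flag callers that pass
    # anything other than a live psycopg2 connection.
    assert isinstance(conn, psycopg2.extensions.connection)


# Requires a reachable server; credentials here are placeholders.
# check(psycopg2.connect(dbname="postgres", user="postgres"))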