diff --git a/.coverage b/.coverage index 3f47374..122fe12 100644 Binary files a/.coverage and b/.coverage differ diff --git a/requirements.txt b/requirements.txt index dab5a18..4513817 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,7 @@ pytest==7.4.3 pytest-asyncio==0.21.1 httpx==0.25.2 sqlalchemy>=2.0.35 -aiosqlite>=0.19.0 \ No newline at end of file +aiosqlite>=0.19.0 +aiohttp>=3.9.0 +lxml>=5.0.0 +pillow>=10.0.0 \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py index 31420d9..3149ee5 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -72,6 +72,43 @@ class Settings(BaseSettings): default=3, validation_alias="RETRY_ATTEMPTS" ) + + # NFO / TMDB Settings + tmdb_api_key: Optional[str] = Field( + default=None, + validation_alias="TMDB_API_KEY", + description="TMDB API key for scraping TV show metadata" + ) + nfo_auto_create: bool = Field( + default=False, + validation_alias="NFO_AUTO_CREATE", + description="Automatically create NFO files when scanning series" + ) + nfo_update_on_scan: bool = Field( + default=False, + validation_alias="NFO_UPDATE_ON_SCAN", + description="Update existing NFO files when scanning series" + ) + nfo_download_poster: bool = Field( + default=True, + validation_alias="NFO_DOWNLOAD_POSTER", + description="Download poster.jpg when creating NFO" + ) + nfo_download_logo: bool = Field( + default=True, + validation_alias="NFO_DOWNLOAD_LOGO", + description="Download logo.png when creating NFO" + ) + nfo_download_fanart: bool = Field( + default=True, + validation_alias="NFO_DOWNLOAD_FANART", + description="Download fanart.jpg when creating NFO" + ) + nfo_image_size: str = Field( + default="original", + validation_alias="NFO_IMAGE_SIZE", + description="Image size to download (original, w500, etc.)" + ) @property def allowed_origins(self) -> list[str]: diff --git a/src/core/services/nfo_service.py b/src/core/services/nfo_service.py new file mode 100644 index 0000000..42501af --- 
/dev/null +++ b/src/core/services/nfo_service.py @@ -0,0 +1,392 @@ +"""NFO service for creating and managing tvshow.nfo files. + +This service orchestrates TMDB API calls, XML generation, and media downloads +to create complete NFO metadata for TV series. + +Example: + >>> nfo_service = NFOService(tmdb_api_key="key", anime_directory="/anime") + >>> await nfo_service.create_tvshow_nfo("Attack on Titan", "/anime/aot", 2013) +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +from src.core.entities.nfo_models import ( + ActorInfo, + ImageInfo, + RatingInfo, + TVShowNFO, + UniqueID, +) +from src.core.services.tmdb_client import TMDBAPIError, TMDBClient +from src.core.utils.image_downloader import ImageDownloader, ImageDownloadError +from src.core.utils.nfo_generator import generate_tvshow_nfo + +logger = logging.getLogger(__name__) + + +class NFOService: + """Service for creating and managing tvshow.nfo files. + + Attributes: + tmdb_client: TMDB API client + image_downloader: Image downloader utility + anime_directory: Base directory for anime series + """ + + def __init__( + self, + tmdb_api_key: str, + anime_directory: str, + image_size: str = "original", + auto_create: bool = True + ): + """Initialize NFO service. + + Args: + tmdb_api_key: TMDB API key + anime_directory: Base anime directory path + image_size: Image size to download (original, w500, etc.) + auto_create: Whether to auto-create NFOs + """ + self.tmdb_client = TMDBClient(api_key=tmdb_api_key) + self.image_downloader = ImageDownloader() + self.anime_directory = Path(anime_directory) + self.image_size = image_size + self.auto_create = auto_create + + async def check_nfo_exists(self, serie_folder: str) -> bool: + """Check if tvshow.nfo exists for a series. 
+ + Args: + serie_folder: Series folder name + + Returns: + True if tvshow.nfo exists + """ + nfo_path = self.anime_directory / serie_folder / "tvshow.nfo" + return nfo_path.exists() + + async def create_tvshow_nfo( + self, + serie_name: str, + serie_folder: str, + year: Optional[int] = None, + download_poster: bool = True, + download_logo: bool = True, + download_fanart: bool = True + ) -> Path: + """Create tvshow.nfo by scraping TMDB. + + Args: + serie_name: Name of the series to search + serie_folder: Series folder name + year: Release year (helps narrow search) + download_poster: Whether to download poster.jpg + download_logo: Whether to download logo.png + download_fanart: Whether to download fanart.jpg + + Returns: + Path to created NFO file + + Raises: + TMDBAPIError: If TMDB API fails + FileNotFoundError: If series folder doesn't exist + """ + logger.info(f"Creating NFO for {serie_name} (year: {year})") + + folder_path = self.anime_directory / serie_folder + if not folder_path.exists(): + raise FileNotFoundError(f"Series folder not found: {folder_path}") + + async with self.tmdb_client: + # Search for TV show + logger.debug(f"Searching TMDB for: {serie_name}") + search_results = await self.tmdb_client.search_tv_show(serie_name) + + if not search_results.get("results"): + raise TMDBAPIError(f"No results found for: {serie_name}") + + # Find best match (consider year if provided) + tv_show = self._find_best_match(search_results["results"], serie_name, year) + tv_id = tv_show["id"] + + logger.info(f"Found match: {tv_show['name']} (ID: {tv_id})") + + # Get detailed information + details = await self.tmdb_client.get_tv_show_details( + tv_id, + append_to_response="credits,external_ids,images" + ) + + # Convert TMDB data to TVShowNFO model + nfo_model = self._tmdb_to_nfo_model(details) + + # Generate XML + nfo_xml = generate_tvshow_nfo(nfo_model) + + # Save NFO file + nfo_path = folder_path / "tvshow.nfo" + nfo_path.write_text(nfo_xml, encoding="utf-8") + 
logger.info(f"Created NFO: {nfo_path}") + + # Download media files + await self._download_media_files( + details, + folder_path, + download_poster=download_poster, + download_logo=download_logo, + download_fanart=download_fanart + ) + + return nfo_path + + async def update_tvshow_nfo( + self, + serie_folder: str, + download_media: bool = True + ) -> Path: + """Update existing tvshow.nfo with fresh data from TMDB. + + Args: + serie_folder: Series folder name + download_media: Whether to re-download media files + + Returns: + Path to updated NFO file + + Raises: + FileNotFoundError: If NFO file doesn't exist + TMDBAPIError: If TMDB API fails + """ + nfo_path = self.anime_directory / serie_folder / "tvshow.nfo" + + if not nfo_path.exists(): + raise FileNotFoundError(f"NFO file not found: {nfo_path}") + + # Parse existing NFO to get TMDB ID + # For simplicity, we'll recreate from scratch + # In production, you'd parse the XML to extract the ID + + logger.info(f"Updating NFO for {serie_folder}") + # Implementation would extract serie name and call create_tvshow_nfo + # This is a simplified version + raise NotImplementedError("Update NFO not yet implemented") + + def _find_best_match( + self, + results: List[Dict[str, Any]], + query: str, + year: Optional[int] = None + ) -> Dict[str, Any]: + """Find best matching TV show from search results. 
+ + Args: + results: TMDB search results + query: Original search query + year: Expected release year + + Returns: + Best matching TV show data + """ + if not results: + raise TMDBAPIError("No search results to match") + + # If year is provided, try to find exact match + if year: + for result in results: + first_air_date = result.get("first_air_date", "") + if first_air_date.startswith(str(year)): + logger.debug(f"Found year match: {result['name']} ({first_air_date})") + return result + + # Return first result (usually best match) + return results[0] + + def _tmdb_to_nfo_model(self, tmdb_data: Dict[str, Any]) -> TVShowNFO: + """Convert TMDB API data to TVShowNFO model. + + Args: + tmdb_data: TMDB TV show details + + Returns: + TVShowNFO Pydantic model + """ + # Extract basic info + title = tmdb_data["name"] + original_title = tmdb_data.get("original_name", title) + year = None + if tmdb_data.get("first_air_date"): + year = int(tmdb_data["first_air_date"][:4]) + + # Extract ratings + ratings = [] + if tmdb_data.get("vote_average"): + ratings.append(RatingInfo( + name="themoviedb", + value=float(tmdb_data["vote_average"]), + votes=tmdb_data.get("vote_count", 0), + max_rating=10, + default=True + )) + + # Extract external IDs + external_ids = tmdb_data.get("external_ids", {}) + imdb_id = external_ids.get("imdb_id") + tvdb_id = external_ids.get("tvdb_id") + + # Extract images + thumb_images = [] + fanart_images = [] + + # Poster + if tmdb_data.get("poster_path"): + poster_url = self.tmdb_client.get_image_url( + tmdb_data["poster_path"], + self.image_size + ) + thumb_images.append(ImageInfo(url=poster_url, aspect="poster")) + + # Backdrop/Fanart + if tmdb_data.get("backdrop_path"): + fanart_url = self.tmdb_client.get_image_url( + tmdb_data["backdrop_path"], + self.image_size + ) + fanart_images.append(ImageInfo(url=fanart_url)) + + # Logo from images if available + images_data = tmdb_data.get("images", {}) + logos = images_data.get("logos", []) + if logos: + logo_url = 
self.tmdb_client.get_image_url( + logos[0]["file_path"], + self.image_size + ) + thumb_images.append(ImageInfo(url=logo_url, aspect="clearlogo")) + + # Extract cast + actors = [] + credits = tmdb_data.get("credits", {}) + for cast_member in credits.get("cast", [])[:10]: # Top 10 actors + actor_thumb = None + if cast_member.get("profile_path"): + actor_thumb = self.tmdb_client.get_image_url( + cast_member["profile_path"], + "h632" + ) + + actors.append(ActorInfo( + name=cast_member["name"], + role=cast_member.get("character"), + thumb=actor_thumb, + tmdbid=cast_member["id"] + )) + + # Create unique IDs + unique_ids = [] + if tmdb_data.get("id"): + unique_ids.append(UniqueID( + type="tmdb", + value=str(tmdb_data["id"]), + default=False + )) + if imdb_id: + unique_ids.append(UniqueID( + type="imdb", + value=imdb_id, + default=False + )) + if tvdb_id: + unique_ids.append(UniqueID( + type="tvdb", + value=str(tvdb_id), + default=True + )) + + # Create NFO model + return TVShowNFO( + title=title, + originaltitle=original_title, + year=year, + plot=tmdb_data.get("overview"), + runtime=tmdb_data.get("episode_run_time", [None])[0] if tmdb_data.get("episode_run_time") else None, + premiered=tmdb_data.get("first_air_date"), + status=tmdb_data.get("status"), + genre=[g["name"] for g in tmdb_data.get("genres", [])], + studio=[n["name"] for n in tmdb_data.get("networks", [])], + country=[c["name"] for c in tmdb_data.get("production_countries", [])], + ratings=ratings, + tmdbid=tmdb_data.get("id"), + imdbid=imdb_id, + tvdbid=tvdb_id, + uniqueid=unique_ids, + thumb=thumb_images, + fanart=fanart_images, + actors=actors + ) + + async def _download_media_files( + self, + tmdb_data: Dict[str, Any], + folder_path: Path, + download_poster: bool = True, + download_logo: bool = True, + download_fanart: bool = True + ) -> Dict[str, bool]: + """Download media files (poster, logo, fanart). 
+ + Args: + tmdb_data: TMDB TV show details + folder_path: Series folder path + download_poster: Download poster.jpg + download_logo: Download logo.png + download_fanart: Download fanart.jpg + + Returns: + Dictionary with download status for each file + """ + poster_url = None + logo_url = None + fanart_url = None + + # Get poster URL + if download_poster and tmdb_data.get("poster_path"): + poster_url = self.tmdb_client.get_image_url( + tmdb_data["poster_path"], + self.image_size + ) + + # Get fanart URL + if download_fanart and tmdb_data.get("backdrop_path"): + fanart_url = self.tmdb_client.get_image_url( + tmdb_data["backdrop_path"], + "original" # Always use original for fanart + ) + + # Get logo URL + if download_logo: + images_data = tmdb_data.get("images", {}) + logos = images_data.get("logos", []) + if logos: + logo_url = self.tmdb_client.get_image_url( + logos[0]["file_path"], + "original" # Logos should be original size + ) + + # Download all media concurrently + results = await self.image_downloader.download_all_media( + folder_path, + poster_url=poster_url, + logo_url=logo_url, + fanart_url=fanart_url, + skip_existing=True + ) + + logger.info(f"Media download results: {results}") + return results + + async def close(self): + """Clean up resources.""" + await self.tmdb_client.close() diff --git a/src/core/services/tmdb_client.py b/src/core/services/tmdb_client.py new file mode 100644 index 0000000..28c153b --- /dev/null +++ b/src/core/services/tmdb_client.py @@ -0,0 +1,283 @@ +"""TMDB API client for fetching TV show metadata. + +This module provides an async client for The Movie Database (TMDB) API, +adapted from the scraper project to fit the AniworldMain architecture. + +Example: + >>> async with TMDBClient(api_key="your_key") as client: + ... results = await client.search_tv_show("Attack on Titan") + ... show_id = results["results"][0]["id"] + ... 
details = await client.get_tv_show_details(show_id) +""" + +import asyncio +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +import aiohttp + +logger = logging.getLogger(__name__) + + +class TMDBAPIError(Exception): + """Exception raised for TMDB API errors.""" + pass + + +class TMDBClient: + """Async TMDB API client for TV show metadata. + + Attributes: + api_key: TMDB API key for authentication + base_url: Base URL for TMDB API + image_base_url: Base URL for TMDB images + max_connections: Maximum concurrent connections + session: aiohttp ClientSession for requests + """ + + DEFAULT_BASE_URL = "https://api.themoviedb.org/3" + DEFAULT_IMAGE_BASE_URL = "https://image.tmdb.org/t/p" + + def __init__( + self, + api_key: str, + base_url: str = DEFAULT_BASE_URL, + image_base_url: str = DEFAULT_IMAGE_BASE_URL, + max_connections: int = 10 + ): + """Initialize TMDB client. + + Args: + api_key: TMDB API key + base_url: TMDB API base URL + image_base_url: TMDB image base URL + max_connections: Maximum concurrent connections + """ + if not api_key: + raise ValueError("TMDB API key is required") + + self.api_key = api_key + self.base_url = base_url.rstrip('/') + self.image_base_url = image_base_url.rstrip('/') + self.max_connections = max_connections + self.session: Optional[aiohttp.ClientSession] = None + self._cache: Dict[str, Any] = {} + + async def __aenter__(self): + """Async context manager entry.""" + await self._ensure_session() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + + async def _ensure_session(self): + """Ensure aiohttp session is created.""" + if self.session is None or self.session.closed: + connector = aiohttp.TCPConnector(limit=self.max_connections) + self.session = aiohttp.ClientSession(connector=connector) + + async def _request( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + max_retries: int = 3 + ) -> 
Dict[str, Any]: + """Make an async request to TMDB API with retries. + + Args: + endpoint: API endpoint (e.g., 'search/tv') + params: Query parameters + max_retries: Maximum retry attempts + + Returns: + API response as dictionary + + Raises: + TMDBAPIError: If request fails after retries + """ + await self._ensure_session() + + url = f"{self.base_url}/{endpoint}" + params = params or {} + params["api_key"] = self.api_key + + # Cache key for deduplication + cache_key = f"{endpoint}:{str(sorted(params.items()))}" + if cache_key in self._cache: + logger.debug(f"Cache hit for {endpoint}") + return self._cache[cache_key] + + delay = 1 + last_error = None + + for attempt in range(max_retries): + try: + logger.debug(f"TMDB API request: {endpoint} (attempt {attempt + 1})") + async with self.session.get(url, params=params, timeout=aiohttp.ClientTimeout(total=30)) as resp: + if resp.status == 401: + raise TMDBAPIError("Invalid TMDB API key") + elif resp.status == 404: + raise TMDBAPIError(f"Resource not found: {endpoint}") + elif resp.status == 429: + # Rate limit - wait longer + retry_after = int(resp.headers.get('Retry-After', delay * 2)) + logger.warning(f"Rate limited, waiting {retry_after}s") + await asyncio.sleep(retry_after) + continue + + resp.raise_for_status() + data = await resp.json() + self._cache[cache_key] = data + return data + + except aiohttp.ClientError as e: + last_error = e + if attempt < max_retries - 1: + logger.warning(f"Request failed (attempt {attempt + 1}): {e}, retrying in {delay}s") + await asyncio.sleep(delay) + delay *= 2 + else: + logger.error(f"Request failed after {max_retries} attempts: {e}") + + raise TMDBAPIError(f"Request failed after {max_retries} attempts: {last_error}") + + async def search_tv_show( + self, + query: str, + language: str = "de-DE", + page: int = 1 + ) -> Dict[str, Any]: + """Search for TV shows by name. 
+ + Args: + query: Search query (show name) + language: Language for results (default: German) + page: Page number for pagination + + Returns: + Search results with list of shows + + Example: + >>> results = await client.search_tv_show("Attack on Titan") + >>> shows = results["results"] + """ + return await self._request( + "search/tv", + {"query": query, "language": language, "page": page} + ) + + async def get_tv_show_details( + self, + tv_id: int, + language: str = "de-DE", + append_to_response: Optional[str] = None + ) -> Dict[str, Any]: + """Get detailed information about a TV show. + + Args: + tv_id: TMDB TV show ID + language: Language for metadata + append_to_response: Additional data to include (e.g., "credits,images") + + Returns: + TV show details including metadata, cast, etc. + """ + params = {"language": language} + if append_to_response: + params["append_to_response"] = append_to_response + + return await self._request(f"tv/{tv_id}", params) + + async def get_tv_show_external_ids(self, tv_id: int) -> Dict[str, Any]: + """Get external IDs (IMDB, TVDB) for a TV show. + + Args: + tv_id: TMDB TV show ID + + Returns: + Dictionary with external IDs (imdb_id, tvdb_id, etc.) + """ + return await self._request(f"tv/{tv_id}/external_ids") + + async def get_tv_show_images( + self, + tv_id: int, + language: Optional[str] = None + ) -> Dict[str, Any]: + """Get images (posters, backdrops, logos) for a TV show. + + Args: + tv_id: TMDB TV show ID + language: Language filter for images (None = all languages) + + Returns: + Dictionary with poster, backdrop, and logo lists + """ + params = {} + if language: + params["language"] = language + + return await self._request(f"tv/{tv_id}/images", params) + + async def download_image( + self, + image_path: str, + local_path: Path, + size: str = "original" + ) -> None: + """Download an image from TMDB. 
+ + Args: + image_path: Image path from TMDB API (e.g., "/abc123.jpg") + local_path: Local file path to save image + size: Image size (w500, original, etc.) + + Raises: + TMDBAPIError: If download fails + """ + await self._ensure_session() + + url = f"{self.image_base_url}/{size}{image_path}" + + try: + logger.debug(f"Downloading image from {url}") + async with self.session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as resp: + resp.raise_for_status() + + # Ensure parent directory exists + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Write image data + with open(local_path, "wb") as f: + f.write(await resp.read()) + + logger.info(f"Downloaded image to {local_path}") + + except aiohttp.ClientError as e: + raise TMDBAPIError(f"Failed to download image: {e}") + + def get_image_url(self, image_path: str, size: str = "original") -> str: + """Get full URL for an image. + + Args: + image_path: Image path from TMDB API + size: Image size (w500, original, etc.) + + Returns: + Full image URL + """ + return f"{self.image_base_url}/{size}{image_path}" + + async def close(self): + """Close the aiohttp session and clean up resources.""" + if self.session and not self.session.closed: + await self.session.close() + logger.debug("TMDB client session closed") + + def clear_cache(self): + """Clear the request cache.""" + self._cache.clear() + logger.debug("TMDB client cache cleared") diff --git a/src/core/utils/image_downloader.py b/src/core/utils/image_downloader.py new file mode 100644 index 0000000..80f0c61 --- /dev/null +++ b/src/core/utils/image_downloader.py @@ -0,0 +1,295 @@ +"""Image downloader utility for NFO media files. + +This module provides functions to download poster, logo, and fanart images +from TMDB and validate them. 
+ +Example: + >>> downloader = ImageDownloader() + >>> await downloader.download_poster(poster_url, "/path/to/poster.jpg") +""" + +import asyncio +import logging +from pathlib import Path +from typing import Optional + +import aiohttp +from PIL import Image + +logger = logging.getLogger(__name__) + + +class ImageDownloadError(Exception): + """Exception raised for image download failures.""" + pass + + +class ImageDownloader: + """Utility for downloading and validating images. + + Attributes: + max_retries: Maximum retry attempts for downloads + timeout: Request timeout in seconds + min_file_size: Minimum valid file size in bytes + """ + + def __init__( + self, + max_retries: int = 3, + timeout: int = 60, + min_file_size: int = 1024 # 1 KB + ): + """Initialize image downloader. + + Args: + max_retries: Maximum retry attempts + timeout: Request timeout in seconds + min_file_size: Minimum valid file size in bytes + """ + self.max_retries = max_retries + self.timeout = timeout + self.min_file_size = min_file_size + + async def download_image( + self, + url: str, + local_path: Path, + skip_existing: bool = True, + validate: bool = True + ) -> bool: + """Download an image from URL to local path. 
+ + Args: + url: Image URL + local_path: Local file path to save image + skip_existing: Skip download if file already exists + validate: Validate image after download + + Returns: + True if download successful, False otherwise + + Raises: + ImageDownloadError: If download fails after retries + """ + # Check if file already exists + if skip_existing and local_path.exists(): + if local_path.stat().st_size >= self.min_file_size: + logger.debug(f"Image already exists: {local_path}") + return True + + # Ensure parent directory exists + local_path.parent.mkdir(parents=True, exist_ok=True) + + delay = 1 + last_error = None + + for attempt in range(self.max_retries): + try: + logger.debug(f"Downloading image from {url} (attempt {attempt + 1})") + + timeout = aiohttp.ClientTimeout(total=self.timeout) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as resp: + if resp.status == 404: + logger.warning(f"Image not found: {url}") + return False + + resp.raise_for_status() + + # Download image data + data = await resp.read() + + # Check file size + if len(data) < self.min_file_size: + raise ImageDownloadError( + f"Downloaded file too small: {len(data)} bytes" + ) + + # Write to file + with open(local_path, "wb") as f: + f.write(data) + + # Validate image if requested + if validate and not self.validate_image(local_path): + local_path.unlink(missing_ok=True) + raise ImageDownloadError("Image validation failed") + + logger.info(f"Downloaded image to {local_path}") + return True + + except (aiohttp.ClientError, IOError, ImageDownloadError) as e: + last_error = e + if attempt < self.max_retries - 1: + logger.warning( + f"Download failed (attempt {attempt + 1}): {e}, " + f"retrying in {delay}s" + ) + await asyncio.sleep(delay) + delay *= 2 + else: + logger.error( + f"Download failed after {self.max_retries} attempts: {e}" + ) + + raise ImageDownloadError( + f"Failed to download image after {self.max_retries} attempts: {last_error}" + ) + + 
async def download_poster( + self, + url: str, + series_folder: Path, + filename: str = "poster.jpg", + skip_existing: bool = True + ) -> bool: + """Download poster image. + + Args: + url: Poster URL + series_folder: Series folder path + filename: Output filename (default: poster.jpg) + skip_existing: Skip if file exists + + Returns: + True if successful + """ + local_path = series_folder / filename + try: + return await self.download_image(url, local_path, skip_existing) + except ImageDownloadError as e: + logger.warning(f"Failed to download poster: {e}") + return False + + async def download_logo( + self, + url: str, + series_folder: Path, + filename: str = "logo.png", + skip_existing: bool = True + ) -> bool: + """Download logo image. + + Args: + url: Logo URL + series_folder: Series folder path + filename: Output filename (default: logo.png) + skip_existing: Skip if file exists + + Returns: + True if successful + """ + local_path = series_folder / filename + try: + return await self.download_image(url, local_path, skip_existing) + except ImageDownloadError as e: + logger.warning(f"Failed to download logo: {e}") + return False + + async def download_fanart( + self, + url: str, + series_folder: Path, + filename: str = "fanart.jpg", + skip_existing: bool = True + ) -> bool: + """Download fanart/backdrop image. + + Args: + url: Fanart URL + series_folder: Series folder path + filename: Output filename (default: fanart.jpg) + skip_existing: Skip if file exists + + Returns: + True if successful + """ + local_path = series_folder / filename + try: + return await self.download_image(url, local_path, skip_existing) + except ImageDownloadError as e: + logger.warning(f"Failed to download fanart: {e}") + return False + + def validate_image(self, image_path: Path) -> bool: + """Validate that file is a valid image. 
+ + Args: + image_path: Path to image file + + Returns: + True if valid image, False otherwise + """ + try: + with Image.open(image_path) as img: + # Verify it's a valid image + img.verify() + + # Check file size + if image_path.stat().st_size < self.min_file_size: + logger.warning(f"Image file too small: {image_path}") + return False + + return True + + except Exception as e: + logger.warning(f"Image validation failed for {image_path}: {e}") + return False + + async def download_all_media( + self, + series_folder: Path, + poster_url: Optional[str] = None, + logo_url: Optional[str] = None, + fanart_url: Optional[str] = None, + skip_existing: bool = True + ) -> dict[str, bool]: + """Download all media files (poster, logo, fanart). + + Args: + series_folder: Series folder path + poster_url: Poster URL (optional) + logo_url: Logo URL (optional) + fanart_url: Fanart URL (optional) + skip_existing: Skip existing files + + Returns: + Dictionary with download status for each file type + """ + results = { + "poster": False, + "logo": False, + "fanart": False + } + + tasks = [] + + if poster_url: + tasks.append(("poster", self.download_poster( + poster_url, series_folder, skip_existing=skip_existing + ))) + + if logo_url: + tasks.append(("logo", self.download_logo( + logo_url, series_folder, skip_existing=skip_existing + ))) + + if fanart_url: + tasks.append(("fanart", self.download_fanart( + fanart_url, series_folder, skip_existing=skip_existing + ))) + + # Download concurrently + if tasks: + task_results = await asyncio.gather( + *[task for _, task in tasks], + return_exceptions=True + ) + + for (media_type, _), result in zip(tasks, task_results): + if isinstance(result, Exception): + logger.error(f"Error downloading {media_type}: {result}") + results[media_type] = False + else: + results[media_type] = result + + return results diff --git a/src/core/utils/nfo_generator.py b/src/core/utils/nfo_generator.py new file mode 100644 index 0000000..2af8d42 --- /dev/null +++ 
b/src/core/utils/nfo_generator.py @@ -0,0 +1,192 @@ +"""NFO XML generator for Kodi/XBMC format. + +This module provides functions to generate tvshow.nfo XML files from +TVShowNFO Pydantic models, adapted from the scraper project. + +Example: + >>> from src.core.entities.nfo_models import TVShowNFO + >>> nfo = TVShowNFO(title="Test Show", year=2020, tmdbid=12345) + >>> xml_string = generate_tvshow_nfo(nfo) +""" + +import logging +from typing import Optional + +from lxml import etree + +from src.core.entities.nfo_models import TVShowNFO + +logger = logging.getLogger(__name__) + + +def generate_tvshow_nfo(tvshow: TVShowNFO, pretty_print: bool = True) -> str: + """Generate tvshow.nfo XML content from TVShowNFO model. + + Args: + tvshow: TVShowNFO Pydantic model with metadata + pretty_print: Whether to format XML with indentation + + Returns: + XML string in Kodi/XBMC tvshow.nfo format + + Example: + >>> nfo = TVShowNFO(title="Attack on Titan", year=2013) + >>> xml = generate_tvshow_nfo(nfo) + """ + root = etree.Element("tvshow") + + # Basic information + _add_element(root, "title", tvshow.title) + _add_element(root, "originaltitle", tvshow.originaltitle) + _add_element(root, "showtitle", tvshow.showtitle) + _add_element(root, "sorttitle", tvshow.sorttitle) + _add_element(root, "year", str(tvshow.year) if tvshow.year else None) + + # Plot and description + _add_element(root, "plot", tvshow.plot) + _add_element(root, "outline", tvshow.outline) + _add_element(root, "tagline", tvshow.tagline) + + # Technical details + _add_element(root, "runtime", str(tvshow.runtime) if tvshow.runtime else None) + _add_element(root, "mpaa", tvshow.mpaa) + _add_element(root, "certification", tvshow.certification) + + # Status and dates + _add_element(root, "premiered", tvshow.premiered) + _add_element(root, "status", tvshow.status) + _add_element(root, "dateadded", tvshow.dateadded) + + # Ratings + if tvshow.ratings: + ratings_elem = etree.SubElement(root, "ratings") + for rating in 
tvshow.ratings: + rating_elem = etree.SubElement(ratings_elem, "rating") + if rating.name: + rating_elem.set("name", rating.name) + if rating.max_rating: + rating_elem.set("max", str(rating.max_rating)) + if rating.default: + rating_elem.set("default", "true") + + _add_element(rating_elem, "value", str(rating.value)) + if rating.votes is not None: + _add_element(rating_elem, "votes", str(rating.votes)) + + _add_element(root, "userrating", str(tvshow.userrating) if tvshow.userrating is not None else None) + + # IDs + _add_element(root, "tmdbid", str(tvshow.tmdbid) if tvshow.tmdbid else None) + _add_element(root, "imdbid", tvshow.imdbid) + _add_element(root, "tvdbid", str(tvshow.tvdbid) if tvshow.tvdbid else None) + + # Legacy ID fields for compatibility + _add_element(root, "id", str(tvshow.tvdbid) if tvshow.tvdbid else None) + _add_element(root, "imdb_id", tvshow.imdbid) + + # Unique IDs + for uid in tvshow.uniqueid: + uid_elem = etree.SubElement(root, "uniqueid") + uid_elem.set("type", uid.type) + if uid.default: + uid_elem.set("default", "true") + uid_elem.text = uid.value + + # Multi-value fields + for genre in tvshow.genre: + _add_element(root, "genre", genre) + + for studio in tvshow.studio: + _add_element(root, "studio", studio) + + for country in tvshow.country: + _add_element(root, "country", country) + + for tag in tvshow.tag: + _add_element(root, "tag", tag) + + # Thumbnails (posters, logos) + for thumb in tvshow.thumb: + thumb_elem = etree.SubElement(root, "thumb") + if thumb.aspect: + thumb_elem.set("aspect", thumb.aspect) + if thumb.season is not None: + thumb_elem.set("season", str(thumb.season)) + if thumb.type: + thumb_elem.set("type", thumb.type) + thumb_elem.text = str(thumb.url) + + # Fanart + if tvshow.fanart: + fanart_elem = etree.SubElement(root, "fanart") + for fanart in tvshow.fanart: + fanart_thumb = etree.SubElement(fanart_elem, "thumb") + fanart_thumb.text = str(fanart.url) + + # Named seasons + for named_season in tvshow.namedseason: + 
season_elem = etree.SubElement(root, "namedseason") + season_elem.set("number", str(named_season.number)) + season_elem.text = named_season.name + + # Actors + for actor in tvshow.actors: + actor_elem = etree.SubElement(root, "actor") + _add_element(actor_elem, "name", actor.name) + _add_element(actor_elem, "role", actor.role) + _add_element(actor_elem, "thumb", str(actor.thumb) if actor.thumb else None) + _add_element(actor_elem, "profile", str(actor.profile) if actor.profile else None) + _add_element(actor_elem, "tmdbid", str(actor.tmdbid) if actor.tmdbid else None) + + # Additional fields + _add_element(root, "trailer", str(tvshow.trailer) if tvshow.trailer else None) + _add_element(root, "watched", "true" if tvshow.watched else "false") + if tvshow.playcount is not None: + _add_element(root, "playcount", str(tvshow.playcount)) + + # Generate XML string + xml_str = etree.tostring( + root, + pretty_print=pretty_print, + encoding="unicode", + xml_declaration=False + ) + + # Add XML declaration + xml_declaration = '\n' + return xml_declaration + xml_str + + +def _add_element(parent: etree.Element, tag: str, text: Optional[str]) -> Optional[etree.Element]: + """Add a child element to parent if text is not None or empty. + + Args: + parent: Parent XML element + tag: Tag name for child element + text: Text content (None or empty strings are skipped) + + Returns: + Created element or None if skipped + """ + if text is not None and text != "": + elem = etree.SubElement(parent, tag) + elem.text = text + return elem + return None + + +def validate_nfo_xml(xml_string: str) -> bool: + """Validate NFO XML structure. 
+ + Args: + xml_string: XML content to validate + + Returns: + True if valid XML, False otherwise + """ + try: + etree.fromstring(xml_string.encode('utf-8')) + return True + except etree.XMLSyntaxError as e: + logger.error(f"Invalid NFO XML: {e}") + return False diff --git a/tests/unit/test_image_downloader.py b/tests/unit/test_image_downloader.py new file mode 100644 index 0000000..005ab96 --- /dev/null +++ b/tests/unit/test_image_downloader.py @@ -0,0 +1,411 @@ +"""Unit tests for image downloader.""" + +import io +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from PIL import Image + +from src.core.utils.image_downloader import ( + ImageDownloader, + ImageDownloadError, +) + + +@pytest.fixture +def image_downloader(): + """Create image downloader instance.""" + return ImageDownloader() + + +@pytest.fixture +def valid_image_bytes(): + """Create valid test image bytes.""" + img = Image.new('RGB', (100, 100), color='red') + buf = io.BytesIO() + img.save(buf, format='JPEG') + return buf.getvalue() + + +@pytest.fixture +def mock_session(): + """Create mock aiohttp session.""" + mock = AsyncMock() + mock.get = AsyncMock() + return mock + + +class TestImageDownloaderInit: + """Test ImageDownloader initialization.""" + + def test_init_default_values(self): + """Test initialization with default values.""" + downloader = ImageDownloader() + + assert downloader.min_file_size == 1024 + assert downloader.max_retries == 3 + assert downloader.retry_delay == 1.0 + assert downloader.timeout == 30 + assert downloader.session is None + + def test_init_custom_values(self): + """Test initialization with custom values.""" + downloader = ImageDownloader( + min_file_size=5000, + max_retries=5, + retry_delay=2.0, + timeout=60 + ) + + assert downloader.min_file_size == 5000 + assert downloader.max_retries == 5 + assert downloader.retry_delay == 2.0 + assert downloader.timeout == 60 + + +class TestImageDownloaderContextManager: + """Test 
class TestImageDownloaderValidateImage:
    """Exercise ``_validate_image`` with acceptable and unacceptable payloads."""

    def test_validate_valid_image(self, image_downloader, valid_image_bytes):
        """A real JPEG payload is expected to pass without raising."""
        image_downloader._validate_image(valid_image_bytes)

    def test_validate_too_small(self, image_downloader):
        """A payload of only a few bytes is expected to be rejected as too small."""
        with pytest.raises(ImageDownloadError, match="too small"):
            image_downloader._validate_image(b"tiny")

    def test_validate_invalid_image_data(self, image_downloader):
        """Bytes that are long enough but not an image are expected to be unopenable."""
        not_an_image = b"x" * 2000

        with pytest.raises(ImageDownloadError, match="Cannot open"):
            image_downloader._validate_image(not_an_image)

    def test_validate_corrupted_image(self, image_downloader):
        """A JPEG-like header followed by garbage is expected to be rejected."""
        jpeg_like_header = b"\xFF\xD8\xFF\xE0"
        garbage = b"corrupted_data" * 100

        with pytest.raises(ImageDownloadError):
            image_downloader._validate_image(jpeg_like_header + garbage)
@pytest.mark.asyncio
async def test_download_image_skip_existing(
    self,
    image_downloader,
    tmp_path
):
    """With skip_existing=True an existing target stays untouched and no request is made."""
    target = tmp_path / "existing.jpg"
    original_bytes = b"existing"
    target.write_bytes(original_bytes)

    session = AsyncMock()
    image_downloader.session = session

    outcome = await image_downloader.download_image(
        "https://test.com/image.jpg",
        target,
        skip_existing=True,
    )

    assert outcome is True
    # The file content must be exactly what was there before the call.
    assert target.read_bytes() == original_bytes
    session.get.assert_not_called()
class TestImageDownloaderSpecificMethods:
    """Check that each type-specific helper forwards to ``download_image`` with its file name."""

    @staticmethod
    async def _forwarded_path(downloader, method_name, url, folder):
        """Call ``method_name`` with ``download_image`` patched out; return the path it was given."""
        with patch.object(
            downloader,
            'download_image',
            new_callable=AsyncMock
        ) as fake_download:
            await getattr(downloader, method_name)(url, folder)

        fake_download.assert_called_once()
        # Second positional argument of download_image is the output path.
        return fake_download.call_args[0][1]

    @pytest.mark.asyncio
    async def test_download_poster(self, image_downloader, valid_image_bytes, tmp_path):
        """download_poster targets ``poster.jpg`` inside the given folder."""
        target = await self._forwarded_path(
            image_downloader, "download_poster", "https://test.com/poster.jpg", tmp_path
        )
        assert target == tmp_path / "poster.jpg"

    @pytest.mark.asyncio
    async def test_download_logo(self, image_downloader, tmp_path):
        """download_logo targets ``logo.png`` inside the given folder."""
        target = await self._forwarded_path(
            image_downloader, "download_logo", "https://test.com/logo.png", tmp_path
        )
        assert target == tmp_path / "logo.png"

    @pytest.mark.asyncio
    async def test_download_fanart(self, image_downloader, tmp_path):
        """download_fanart targets ``fanart.jpg`` inside the given folder."""
        target = await self._forwarded_path(
            image_downloader, "download_fanart", "https://test.com/fanart.jpg", tmp_path
        )
        assert target == tmp_path / "fanart.jpg"
@pytest.mark.asyncio
async def test_download_all_partial(self, image_downloader, tmp_path):
    """Only the poster is fetched when the logo and fanart URLs are None."""
    poster_patch = patch.object(
        image_downloader, 'download_poster', new_callable=AsyncMock, return_value=True
    )
    logo_patch = patch.object(
        image_downloader, 'download_logo', new_callable=AsyncMock
    )
    fanart_patch = patch.object(
        image_downloader, 'download_fanart', new_callable=AsyncMock
    )

    with poster_patch, logo_patch as logo_mock, fanart_patch as fanart_mock:
        outcome = await image_downloader.download_all_media(
            tmp_path,
            poster_url="https://test.com/poster.jpg",
            logo_url=None,
            fanart_url=None,
        )

        assert outcome["poster"] is True
        # Missing URLs are reported as None and their helpers are never called.
        for skipped in ("logo", "fanart"):
            assert outcome[skipped] is None
        logo_mock.assert_not_called()
        fanart_mock.assert_not_called()
@pytest.mark.asyncio
async def test_retry_on_failure(self, image_downloader, valid_image_bytes, tmp_path):
    """Two failing responses followed by a good one end in a saved file after three requests."""
    failing = AsyncMock()
    failing.status = 500
    failing.raise_for_status = MagicMock(side_effect=Exception("Server Error"))

    succeeding = AsyncMock()
    succeeding.status = 200
    succeeding.read = AsyncMock(return_value=valid_image_bytes)

    session = AsyncMock()
    # The responses are handed out in order: fail, fail, succeed.
    session.get = AsyncMock(side_effect=[failing, failing, succeeding])

    image_downloader.session = session
    image_downloader.retry_delay = 0.1  # short delay keeps the test fast

    target = tmp_path / "test.jpg"
    await image_downloader.download_image("https://test.com/image.jpg", target)

    assert session.get.call_count == 3
    assert target.exists()
def test_generate_minimal_nfo(self):
    """Test generation with minimal required fields.

    The output must open with an XML declaration, parse with a ``tvshow``
    root, and contain the title and plot text.
    """
    nfo = TVShowNFO(
        title="Test Show",
        plot="A test show"
    )

    xml_string = generate_tvshow_nfo(nfo)

    # startswith('') holds for every string, so pin the declaration prefix.
    assert xml_string.startswith("<?xml")
    assert etree.fromstring(xml_string.encode("utf-8")).tag == "tvshow"
    assert "Test Show" in xml_string
    assert "A test show" in xml_string
def test_generate_nfo_with_ratings(self):
    """Test NFO with multiple ratings.

    Checks the generated ``rating`` elements structurally: one per
    RatingInfo, keyed by the ``name`` attribute, with ``default="true"``
    only on the default rating and the value/votes written as children.
    """
    nfo = TVShowNFO(
        title="Rated Show",
        plot="Test",
        ratings=[
            RatingInfo(
                name="themoviedb",
                value=8.5,
                votes=1000,
                max_rating=10,
                default=True
            ),
            RatingInfo(
                name="imdb",
                value=8.2,
                votes=5000,
                max_rating=10,
                default=False
            )
        ]
    )

    xml_string = generate_tvshow_nfo(nfo)

    # Parse instead of substring-matching: `'' in xml_string` can never fail.
    root = etree.fromstring(xml_string.encode("utf-8"))
    ratings = {elem.get("name"): elem for elem in root.iter("rating")}

    assert set(ratings) == {"themoviedb", "imdb"}

    tmdb = ratings["themoviedb"]
    assert tmdb.get("default") == "true"
    assert tmdb.get("max") is not None
    assert tmdb.findtext("value") == "8.5"
    # NOTE(review): assumes votes is stored as an int so it serialises as "1000" - confirm.
    assert tmdb.findtext("votes") == "1000"

    imdb = ratings["imdb"]
    # The generator only writes the attribute for the default rating.
    assert imdb.get("default") is None
    assert imdb.findtext("value") == "8.2"
def test_generate_nfo_escapes_special_chars(self):
    """Test that special XML characters are escaped.

    The raw output must contain the escaped forms, and parsing it back
    must yield the original, unescaped title and plot.
    """
    title = 'Show <tag> & special "chars"'
    plot = "Plot with & ampersand"
    nfo = TVShowNFO(title=title, plot=plot)

    xml_string = generate_tvshow_nfo(nfo)

    # Raw markup: '&' and '<' from the input must not appear unescaped.
    assert "&amp;" in xml_string
    assert "&lt;tag" in xml_string
    assert "<tag>" not in xml_string

    # Round trip: some element must carry each original string verbatim.
    root = etree.fromstring(xml_string.encode("utf-8"))
    texts = [elem.text for elem in root.iter()]
    assert title in texts
    assert plot in texts
def test_multiple_studios(self):
    """Test handling multiple studios.

    Each entry in ``studio`` must produce its own ``<studio>`` element,
    in the order given.
    """
    studios = ["Studio A", "Studio B", "Studio C"]
    nfo = TVShowNFO(
        title="Multi Studio",
        plot="Test",
        studio=studios
    )

    xml_string = generate_tvshow_nfo(nfo)

    # count("") returns len(xml_string) + 1, so count the opening tag instead.
    assert xml_string.count("<studio>") == 3

    root = etree.fromstring(xml_string.encode("utf-8"))
    assert [elem.text for elem in root.iter("studio")] == studios
class TestTMDBClientInit:
    """Verify the state of a freshly constructed TMDBClient."""

    def test_init_with_api_key(self):
        """The constructor stores the key and starts with the expected URLs, no session, empty cache."""
        client = TMDBClient(api_key="my_key")

        expected_state = {
            "api_key": "my_key",
            "base_url": "https://api.themoviedb.org/3",
            "image_base_url": "https://image.tmdb.org/t/p",
        }
        for attr, expected in expected_state.items():
            assert getattr(client, attr) == expected

        assert client.session is None
        assert client._cache == {}

    def test_init_sets_attributes(self):
        """Every attribute the other tests rely on exists after construction."""
        client = TMDBClient(api_key="test")

        for attr in ("api_key", "base_url", "image_base_url", "session", "_cache"):
            assert hasattr(client, attr)
@pytest.mark.asyncio
async def test_search_tv_show_success(self, tmdb_client, mock_response):
    """A search returns the payload produced by ``_make_request`` unchanged."""
    payload = {
        "results": [
            {"id": 1, "name": "Test Show"},
            {"id": 2, "name": "Another Show"},
        ]
    }
    mock_response.json = AsyncMock(return_value=payload)

    # _make_request is patched to hand back the same payload the fake response carries.
    with patch.object(
        tmdb_client, "_make_request", return_value=mock_response.json.return_value
    ):
        result = await tmdb_client.search_tv_show("Test Show")

    assert "results" in result
    shows = result["results"]
    assert len(shows) == 2
    assert shows[0]["name"] == "Test Show"
class TestTMDBClientGetTVShowDetails:
    """Exercise ``get_tv_show_details`` with ``_make_request`` patched out."""

    @pytest.mark.asyncio
    async def test_get_tv_show_details_success(self, tmdb_client):
        """The details payload is returned to the caller."""
        details = {
            "id": 123,
            "name": "Test Show",
            "overview": "A test show",
            "first_air_date": "2020-01-01",
        }

        with patch.object(tmdb_client, "_make_request", return_value=details):
            result = await tmdb_client.get_tv_show_details(123)

        assert result["id"] == 123
        assert result["name"] == "Test Show"

    @pytest.mark.asyncio
    async def test_get_tv_show_details_with_append(self, tmdb_client):
        """``append_to_response`` is forwarded to the request and its sections come back."""
        details = {
            "id": 123,
            "name": "Test Show",
            "credits": {"cast": []},
            "images": {"posters": []},
        }

        with patch.object(tmdb_client, "_make_request", return_value=details) as request_mock:
            result = await tmdb_client.get_tv_show_details(
                123, append_to_response="credits,images"
            )

        for section in ("credits", "images"):
            assert section in result

        # The appended sections must show up somewhere in the request arguments.
        assert "credits,images" in str(request_mock.call_args)
class TestTMDBClientImageURL:
    """Check the URLs built by ``get_image_url``."""

    def test_get_image_url_with_size(self, tmdb_client):
        """A sized request puts the size segment before the file path."""
        expected = "https://image.tmdb.org/t/p/w500/test.jpg"
        assert tmdb_client.get_image_url("/test.jpg", "w500") == expected

    def test_get_image_url_original(self, tmdb_client):
        """The ``original`` size is used as the size segment like any other."""
        expected = "https://image.tmdb.org/t/p/original/test.jpg"
        assert tmdb_client.get_image_url("/test.jpg", "original") == expected

    def test_get_image_url_strips_leading_slash(self, tmdb_client):
        """A path without a leading slash yields the same URL as one with it."""
        expected = "https://image.tmdb.org/t/p/w500/test.jpg"
        assert tmdb_client.get_image_url("test.jpg", "w500") == expected
@pytest.mark.asyncio
async def test_make_request_rate_limit(self, tmdb_client):
    """A 429 response is expected to surface as TMDBAPIError mentioning the rate limit."""
    response = AsyncMock()
    response.status = 429
    response.raise_for_status = MagicMock(
        side_effect=ClientResponseError(None, None, status=429)
    )

    session = AsyncMock()
    session.get = AsyncMock(return_value=response)
    tmdb_client.session = session

    with pytest.raises(TMDBAPIError, match="rate limit"):
        await tmdb_client._make_request("tv/search", {})
@pytest.mark.asyncio
async def test_download_image_failure(self, tmdb_client, tmp_path):
    """A 404 while downloading is expected to surface as TMDBAPIError."""
    response = AsyncMock()
    response.status = 404
    response.raise_for_status = MagicMock(
        side_effect=ClientResponseError(None, None, status=404)
    )

    session = AsyncMock()
    session.get = AsyncMock(return_value=response)
    tmdb_client.session = session

    target = tmp_path / "test.jpg"

    with pytest.raises(TMDBAPIError):
        await tmdb_client.download_image("https://test.com/missing.jpg", target)