diff --git a/storage_service/config/config_virus_checker.py b/storage_service/config/config_virus_checker.py index 45f5f45..7f1feed 100644 --- a/storage_service/config/config_virus_checker.py +++ b/storage_service/config/config_virus_checker.py @@ -3,7 +3,12 @@ from dotenv import load_dotenv import os -def get_virus_checker_api_key(): +def get_virus_checker_api_key() -> str: load_dotenv() - return {"api_key": os.environ.get("VIRUS_CHECKER_API_KEY")} + api_key = os.environ.get("VIRUS_CHECKER_API_KEY") + + if not api_key: + raise RuntimeError("Virus Checker API Key not found") + + return api_key diff --git a/storage_service/controller/health_checker_controller.py b/storage_service/controller/health_checker_controller.py index 390225a..e5067f4 100644 --- a/storage_service/controller/health_checker_controller.py +++ b/storage_service/controller/health_checker_controller.py @@ -1,11 +1,15 @@ +from storage_service.model.health_check.health_check_response import ( + HealthCheckResponse, +) + from fastapi import APIRouter from fastapi_utils.cbv import cbv -health_router = APIRouter() +health_router = APIRouter(tags=["health"]) @cbv(health_router) class HealthCheckerController: @health_router.get("/health", status_code=200) - def health(self) -> dict[str, str]: - return {"status": "healthy"} + def health(self) -> HealthCheckResponse: + return HealthCheckResponse(status="healthy") diff --git a/storage_service/controller/storage_controller.py b/storage_service/controller/storage_controller.py index 8e35604..ff759db 100644 --- a/storage_service/controller/storage_controller.py +++ b/storage_service/controller/storage_controller.py @@ -4,18 +4,20 @@ from storage_service.depends.depend_queue import dependency_queue from storage_service.depends.depend_s3_service import ( dependency_storage_service, ) +from storage_service.model.storage.new_file_request import NewFileURLRequest +from storage_service.model.storage.process_file_request import ( + ProcessFileRequest, +) +from storage_service.model.storage.signed_url_response import SignedUrlResponse from storage_service.service.storage.storage_service import StorageService -from storage_service.utils.enums.file_type import FileType from storage_service.utils.file_name_hash import file_name_hash from storage_service.worker.storage_file_worker import storage_file_worker -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Depends from fastapi_utils.cbv import cbv from rq import Queue -from typing import Annotated - -s3_router = APIRouter() +s3_router = APIRouter(tags=["storage"]) @cbv(s3_router) @@ -25,31 +27,30 @@ class StorageController: dependency_storage_service, use_cache=True ) - @s3_router.post("/file/", status_code=200) - def new_file_url( - self, - username: Annotated[str, Body(embed=True)], - file_postfix: Annotated[str, Body(embed=True)], - file_type: Annotated[FileType, Body(embed=True)], - ) -> dict[str, str]: + @s3_router.post("/file", status_code=200) + def new_file_url(self, new_file_request: NewFileURLRequest) -> SignedUrlResponse: + hashed_file_name = file_name_hash( + new_file_request.file_key, new_file_request.file_postfix + ) + return self.storage_service.get_temp_upload_link( - file_name_hash(username, file_postfix), file_type + hashed_file_name, new_file_request.file_type ) - @s3_router.get("/file/", status_code=200) - def file_url(self, username: str, file_postfix: str) -> dict[str, str | None]: + @s3_router.get("/file", status_code=200) + def file_url(self, file_key: str, file_postfix: str) -> SignedUrlResponse: return self.storage_service.get_temp_read_link( - file_name_hash(username, file_postfix) + file_name_hash(file_key, file_postfix) ) - @s3_router.delete("/file/", status_code=204) - def delete_file(self, username: str, file_postfix: str): - return self.storage_service.delete_file(file_name_hash(username, file_postfix)) + @s3_router.delete("/file", status_code=204) + def delete_file(self, file_key: str, file_postfix: str): + return self.storage_service.delete_file(file_name_hash(file_key, file_postfix)) @s3_router.post("/file/process", status_code=200) - def process_file( - self, - username: Annotated[str, Body(embed=True)], - file_postfix: Annotated[str, Body(embed=True)], - ): - self.queue.enqueue(storage_file_worker, username, file_postfix) + def process_file(self, process_file_request: ProcessFileRequest): + self.queue.enqueue( + storage_file_worker, + process_file_request.file_key, + process_file_request.file_postfix, + ) diff --git a/storage_service/depends/depend_s3_service.py b/storage_service/depends/depend_s3_service.py index ae1834e..d45da2f 100644 --- a/storage_service/depends/depend_s3_service.py +++ b/storage_service/depends/depend_s3_service.py @@ -3,6 +3,8 @@ from storage_service.service.storage.amazon_s3_service import AmazonS3Service from storage_service.service.storage.storage_service import StorageService from storage_service.utils.enums.storage_type import StorageType +import boto3 +import botocore.client from dotenv import load_dotenv import os @@ -14,6 +16,27 @@ def dependency_storage_service() -> StorageService: load_dotenv() if StorageType(os.environ["STORAGE_TYPE"]) == StorageType.S3_STORAGE: - return AmazonS3Service(**get_config_s3()) + s3_config = get_config_s3() + + if "aws_access_key_id" not in s3_config: + raise RuntimeError("Invalid S3 Config: Missing aws_access_key_id") + + if "aws_secret_access_key" not in s3_config: + raise RuntimeError("Invalid S3 Config: Missing aws_secret_access_key") + + if "region_name" not in s3_config: + raise RuntimeError("Invalid S3 Config: Missing region_name") + + s3_client = boto3.client( + "s3", + region_name=s3_config["region_name"], + aws_access_key_id=s3_config["aws_access_key_id"], + aws_secret_access_key=s3_config["aws_secret_access_key"], + ) + + return AmazonS3Service( + s3_client, + s3_config["bucket_name"], + ) raise RuntimeError("Invalid Storage Type") diff --git a/storage_service/depends/depend_virus_checker_service.py b/storage_service/depends/depend_virus_checker_service.py index 26a9594..57fa4cc 100644 --- a/storage_service/depends/depend_virus_checker_service.py +++ b/storage_service/depends/depend_virus_checker_service.py @@ -10,6 +10,7 @@ from storage_service.service.virus_checker.virus_total_service import ( from storage_service.utils.enums.virus_checker_type import VirusCheckerType from dotenv import load_dotenv +from virustotal_python import Virustotal import os from functools import cache @@ -19,13 +20,12 @@ from functools import cache def dependency_virus_checker_service() -> VirusCheckerService: load_dotenv() - virus_checker_config = get_virus_checker_api_key() + try: + type = VirusCheckerType(os.environ["VIRUS_CHECKER_TYPE"]) + except ValueError: + raise RuntimeError("Invalid Virus Checker Type") - if not virus_checker_config["api_key"]: - raise RuntimeError("Virus Checker API Key not found") - - virus_checker_type_var = os.environ.get("VIRUS_CHECKER_TYPE") - if VirusCheckerType(virus_checker_type_var) == VirusCheckerType.TOTAL_VIRUS: - return VirusTotalService(**get_virus_checker_api_key()) - - raise RuntimeError("Invalid Virus Checker Type") + match type: + case VirusCheckerType.TOTAL_VIRUS: + virus_checker = Virustotal(get_virus_checker_api_key()) + return VirusTotalService(virus_checker) diff --git a/storage_service/model/__init__.py b/storage_service/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/storage_service/model/health_check/__init__.py b/storage_service/model/health_check/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/storage_service/model/health_check/health_check_response.py b/storage_service/model/health_check/health_check_response.py new file mode 100644 index 0000000..07bc324 --- /dev/null +++ b/storage_service/model/health_check/health_check_response.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class HealthCheckResponse(BaseModel): + status: str diff --git a/storage_service/model/storage/__init__.py b/storage_service/model/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/storage_service/model/storage/new_file_request.py b/storage_service/model/storage/new_file_request.py new file mode 100644 index 0000000..74c852f --- /dev/null +++ b/storage_service/model/storage/new_file_request.py @@ -0,0 +1,9 @@ +from storage_service.utils.enums.file_type import FileType + +from pydantic import BaseModel + + +class NewFileURLRequest(BaseModel): + file_key: str + file_postfix: str + file_type: FileType diff --git a/storage_service/model/storage/process_file_request.py b/storage_service/model/storage/process_file_request.py new file mode 100644 index 0000000..de9f4ba --- /dev/null +++ b/storage_service/model/storage/process_file_request.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ProcessFileRequest(BaseModel): + file_key: str + file_postfix: str diff --git a/storage_service/model/storage/signed_url_response.py b/storage_service/model/storage/signed_url_response.py new file mode 100644 index 0000000..4e00d5b --- /dev/null +++ b/storage_service/model/storage/signed_url_response.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class SignedUrlResponse(BaseModel): + signed_url: str + expires_in: int diff --git a/storage_service/service/storage/amazon_s3_service.py b/storage_service/service/storage/amazon_s3_service.py index dea0402..edc6841 100644 --- a/storage_service/service/storage/amazon_s3_service.py +++ b/storage_service/service/storage/amazon_s3_service.py @@ -3,46 +3,58 @@ from __future__ import annotations from storage_service.depends.depend_virus_checker_service import ( dependency_virus_checker_service, ) +from storage_service.model.storage.signed_url_response import SignedUrlResponse from storage_service.service.storage.storage_service import StorageService +from storage_service.service.virus_checker.virus_checker_service import ( + VirusCheckerService, +) from storage_service.utils.enums.file_type import FileType from storage_service.utils.file_handler import FILE_HANDLER -import boto3 +from botocore.client import BaseClient import io -from typing import Any class AmazonS3Service(StorageService): - virus_checker_service = dependency_virus_checker_service() + virus_checker_service: VirusCheckerService - def __init__(self, **kwargs): - super().__init__(**kwargs) + s3_client: BaseClient + bucket_name: str - self.__validate_config(**kwargs) + expires_in: int = 3600 - self.bucket_name = kwargs.get("bucket_name") - self.region_name = kwargs.get("region_name") + def __init__( + self, + s3_client: BaseClient, + bucket_name: str, + virus_checker_service=dependency_virus_checker_service(), + **kwargs, + ): + self.virus_checker_service = virus_checker_service - self.expires_in = kwargs.get("expires_in") + if s3_client is None: + raise RuntimeError("Invalid S3 Config: Missing s3_client") + self.s3_client = s3_client - self.s3 = boto3.client( - "s3", - aws_access_key_id=kwargs.get("aws_access_key_id"), - aws_secret_access_key=kwargs.get("aws_secret_access_key"), - region_name=kwargs.get("region_name"), + if bucket_name is None: + raise RuntimeError("Invalid S3 Config: Missing bucket_name") + self.bucket_name = bucket_name + + if "expires_in" in kwargs: + self.expires_in = kwargs["expires_in"] + + def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse: + return SignedUrlResponse( + signed_url=self._get_presigned_write_url(file_name, file_type), + expires_in=self.expires_in, ) - def get_temp_upload_link( - self, file_name, file_type: FileType - ) -> dict[str, str | Any]: - return { - "presigned_url": self._get_presigned_write_url(file_name, file_type), - "file_key": self._get_object_url(file_name), - } - - def get_temp_read_link(self, file_name) -> dict[str, str | None]: - return {"presigned_url": self._get_presigned_read_url(file_name)} + def get_temp_read_link(self, file_name) -> SignedUrlResponse: + return SignedUrlResponse( + signed_url=self._get_presigned_read_url(file_name), + expires_in=self.expires_in, + ) def delete_file(self, file_name: str) -> None: self._delete_file(file_name) @@ -57,11 +69,8 @@ class AmazonS3Service(StorageService): self._upload_file(file_name, handler(file_bytes)) - def _get_object_url(self, file_name: str) -> str: - return f"https://{self.bucket_name}.s3.{self.region_name}.amazonaws.com/{file_name}" - def _get_presigned_write_url(self, file_name, file_type: FileType) -> str: - return self.s3.generate_presigned_url( + return self.s3_client.generate_presigned_url( "put_object", Params={ "Bucket": self.bucket_name, @@ -72,12 +81,12 @@ class AmazonS3Service(StorageService): ) def _get_presigned_read_url(self, file_name) -> str | None: - result = self.s3.list_objects(Bucket=self.bucket_name, Prefix=file_name) + result = self.s3_client.list_objects(Bucket=self.bucket_name, Prefix=file_name) if "Contents" in result and file_name in map( lambda x: x["Key"], result["Contents"] ): - return self.s3.generate_presigned_url( + return self.s3_client.generate_presigned_url( "get_object", Params={"Bucket": self.bucket_name, "Key": file_name}, ExpiresIn=self.expires_in, @@ -86,28 +95,15 @@ class AmazonS3Service(StorageService): def _get_file_obj(self, file_name: str) -> io.BytesIO: return io.BytesIO( - self.s3.get_object(Bucket=self.bucket_name, Key=file_name)["Body"].read() + self.s3_client.get_object(Bucket=self.bucket_name, Key=file_name)[ + "Body" + ].read() ) def _upload_file(self, file_name: str, file_bytes: io.BytesIO) -> None: - self.s3.upload_fileobj(file_bytes, Bucket=self.bucket_name, Key=file_name) + self.s3_client.upload_fileobj( + file_bytes, Bucket=self.bucket_name, Key=file_name + ) def _delete_file(self, file_name: str) -> None: - self.s3.delete_object(Bucket=self.bucket_name, Key=file_name) - - @staticmethod - def __validate_config(**kwargs): - if not kwargs.get("bucket_name"): - raise RuntimeError("bucket_name is required") - - if not kwargs.get("aws_access_key_id"): - raise RuntimeError("aws_access_key_id is required") - - if not kwargs.get("aws_secret_access_key"): - raise RuntimeError("aws_secret_access_key is required") - - if not kwargs.get("region_name"): - raise RuntimeError("region_name is required") - - if not kwargs.get("bucket_name"): - raise RuntimeError("bucket_name is required") + self.s3_client.delete_object(Bucket=self.bucket_name, Key=file_name) diff --git a/storage_service/service/storage/storage_service.py b/storage_service/service/storage/storage_service.py index bf81cee..40c84b1 100644 --- a/storage_service/service/storage/storage_service.py +++ b/storage_service/service/storage/storage_service.py @@ -1,23 +1,18 @@ from __future__ import annotations +from storage_service.model.storage.signed_url_response import SignedUrlResponse from storage_service.utils.enums.file_type import FileType from abc import ABC, abstractmethod -from typing import Any class StorageService(ABC): - def __init__(self, **kwargs): + @abstractmethod + def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse: pass @abstractmethod - def get_temp_upload_link( - self, file_name, file_type: FileType - ) -> dict[str, str | Any]: - pass - - @abstractmethod - def get_temp_read_link(self, file_name) -> dict[str, str | None]: + def get_temp_read_link(self, file_name) -> SignedUrlResponse: pass @abstractmethod diff --git a/storage_service/service/virus_checker/virus_total_service.py b/storage_service/service/virus_checker/virus_total_service.py index 47eed43..e96e551 100644 --- a/storage_service/service/virus_checker/virus_total_service.py +++ b/storage_service/service/virus_checker/virus_total_service.py @@ -8,31 +8,33 @@ from io import BytesIO class VirusTotalService(VirusCheckerService): - def __init__(self, api_key: str): - self.api_key = api_key + virus_checker: Virustotal + + def __init__(self, virus_checker: Virustotal): + self.virus_checker = virus_checker def check_virus(self, file_data: BytesIO) -> bool: + file_id = self._upload_file(file_data) + file_attributes = self._get_analysis(file_id) + + return self._is_valid_file(file_attributes) + + def _upload_file(self, file_data: BytesIO) -> str: files = {"file": ("image_file", file_data)} - with Virustotal(self.api_key) as vtotal: - resp = vtotal.request("files", files=files, method="POST") + resp = self.virus_checker.request("files", files=files, method="POST") - file_attributes = self._get_analysis(resp.json()["data"]["id"]) - - return self._is_valid_file(file_attributes["data"]["attributes"]["stats"]) + return resp.data["id"] def _get_analysis(self, file_id: str) -> dict: - with Virustotal(self.api_key) as vtotal: - resp = vtotal.request(f"analyses/{file_id}") + resp = self.virus_checker.request(f"analyses/{file_id}") - return resp.json() + return resp.json()["data"]["attributes"]["stats"] @staticmethod def _is_valid_file(file_stats: dict) -> bool: - if "malicious" in file_stats and file_stats["malicious"] > 0: - return False - - if "suspicious" in file_stats and file_stats["suspicious"] > 0: - return False - - return True + match file_stats: + case {"malicious": 0, "suspicious": 0, "undetected": 0, "harmless": 0}: + return True + case _: + return False diff --git a/storage_service/utils/file_name_hash.py b/storage_service/utils/file_name_hash.py index 9a1cae6..8c993e2 100644 --- a/storage_service/utils/file_name_hash.py +++ b/storage_service/utils/file_name_hash.py @@ -2,8 +2,8 @@ import base64 from hashlib import md5 -def file_name_hash(username: str, file_postfix: str) -> str: - hashed_username = md5(username.encode("utf-8")).digest() - hashed_username = base64.b64encode(hashed_username).decode() +def file_name_hash(file_key: str, file_postfix: str) -> str: + hashed_file_key = md5(file_key.encode("utf-8")).digest() + hashed_file_key = base64.b64encode(hashed_file_key).decode() - return f"{hashed_username}_{file_postfix}" + return f"{hashed_file_key}_{file_postfix}"