Refactors Services and API Endpoints
This commit is contained in:
@@ -3,7 +3,12 @@ from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
|
||||
def get_virus_checker_api_key() -> str:
    """Return the virus checker API key configured in the environment.

    Loads variables from a ``.env`` file via ``load_dotenv`` and then reads
    ``VIRUS_CHECKER_API_KEY`` from the process environment.

    Returns:
        The API key as a plain string.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    load_dotenv()

    api_key = os.environ.get("VIRUS_CHECKER_API_KEY")

    # Fail here so a missing key shows up as a configuration error instead of
    # an authentication failure on the first request that uses it.
    if not api_key:
        raise RuntimeError("Virus Checker API Key not found")

    return api_key
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from storage_service.model.health_check.health_check_response import (
|
||||
HealthCheckResponse,
|
||||
)
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi_utils.cbv import cbv
|
||||
|
||||
# Router for the health-check endpoint; the tag groups its routes in the
# generated OpenAPI documentation.
health_router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@cbv(health_router)
class HealthCheckerController:
    """Class-based view exposing the health-check endpoint."""

    @health_router.get("/health", status_code=200)
    def health(self) -> HealthCheckResponse:
        """Report that the service is running.

        Returns:
            A ``HealthCheckResponse`` whose ``status`` is ``"healthy"``.
        """
        return HealthCheckResponse(status="healthy")
|
||||
|
||||
@@ -4,18 +4,20 @@ from storage_service.depends.depend_queue import dependency_queue
|
||||
from storage_service.depends.depend_s3_service import (
|
||||
dependency_storage_service,
|
||||
)
|
||||
from storage_service.model.storage.new_file_request import NewFileURLRequest
|
||||
from storage_service.model.storage.process_file_request import (
|
||||
ProcessFileRequest,
|
||||
)
|
||||
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||
from storage_service.service.storage.storage_service import StorageService
|
||||
from storage_service.utils.enums.file_type import FileType
|
||||
from storage_service.utils.file_name_hash import file_name_hash
|
||||
from storage_service.worker.storage_file_worker import storage_file_worker
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi_utils.cbv import cbv
|
||||
from rq import Queue
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
# Router for the file storage endpoints; the tag groups its routes in the
# generated OpenAPI documentation.
s3_router = APIRouter(tags=["storage"])
|
||||
|
||||
|
||||
@cbv(s3_router)
|
||||
@@ -25,31 +27,30 @@ class StorageController:
|
||||
dependency_storage_service, use_cache=True
|
||||
)
|
||||
|
||||
@s3_router.post("/file", status_code=200)
def new_file_url(self, new_file_request: NewFileURLRequest) -> SignedUrlResponse:
    """Create a temporary upload link for a new file.

    Args:
        new_file_request: Request body carrying ``file_key``,
            ``file_postfix`` and ``file_type``.

    Returns:
        The ``SignedUrlResponse`` produced by the storage service's
        ``get_temp_upload_link``.
    """
    # The stored object name is derived from the key and postfix, so the
    # read/delete endpoints can recompute it from the same two values.
    hashed_file_name = file_name_hash(
        new_file_request.file_key, new_file_request.file_postfix
    )

    return self.storage_service.get_temp_upload_link(
        hashed_file_name, new_file_request.file_type
    )
|
||||
|
||||
@s3_router.get("/file", status_code=200)
def file_url(self, file_key: str, file_postfix: str) -> SignedUrlResponse:
    """Create a temporary read link for a stored file.

    Args:
        file_key: Key that was used when the file was uploaded.
        file_postfix: Postfix that was used when the file was uploaded.

    Returns:
        The ``SignedUrlResponse`` produced by the storage service's
        ``get_temp_read_link``.
    """
    hashed_file_name = file_name_hash(file_key, file_postfix)

    return self.storage_service.get_temp_read_link(hashed_file_name)
|
||||
|
||||
@s3_router.delete("/file", status_code=204)
def delete_file(self, file_key: str, file_postfix: str):
    """Delete a stored file.

    Args:
        file_key: Key that was used when the file was uploaded.
        file_postfix: Postfix that was used when the file was uploaded.

    Returns:
        Whatever the storage service's ``delete_file`` returns; the route is
        declared with status code 204.
    """
    hashed_file_name = file_name_hash(file_key, file_postfix)

    return self.storage_service.delete_file(hashed_file_name)
|
||||
|
||||
@s3_router.post("/file/process", status_code=200)
def process_file(self, process_file_request: ProcessFileRequest):
    """Queue background processing for an uploaded file.

    Args:
        process_file_request: Request body carrying ``file_key`` and
            ``file_postfix``.

    The work is not done in the request: ``storage_file_worker`` is enqueued
    on ``self.queue`` with the key and postfix as its arguments.
    """
    self.queue.enqueue(
        storage_file_worker,
        process_file_request.file_key,
        process_file_request.file_postfix,
    )
|
||||
|
||||
@@ -3,6 +3,8 @@ from storage_service.service.storage.amazon_s3_service import AmazonS3Service
|
||||
from storage_service.service.storage.storage_service import StorageService
|
||||
from storage_service.utils.enums.storage_type import StorageType
|
||||
|
||||
import boto3
|
||||
import botocore.client
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import os
|
||||
@@ -14,6 +16,27 @@ def dependency_storage_service() -> StorageService:
|
||||
load_dotenv()
|
||||
|
||||
if StorageType(os.environ["STORAGE_TYPE"]) == StorageType.S3_STORAGE:
|
||||
return AmazonS3Service(**get_config_s3())
|
||||
s3_config = get_config_s3()
|
||||
|
||||
if "aws_access_key_id" not in s3_config:
|
||||
raise RuntimeError("Invalid S3 Config: Missing aws_access_key_id")
|
||||
|
||||
if "aws_secret_access_key" not in s3_config:
|
||||
raise RuntimeError("Invalid S3 Config: Missing aws_secret_access_key")
|
||||
|
||||
if "region_name" not in s3_config:
|
||||
raise RuntimeError("Invalid S3 Config: Missing region_name")
|
||||
|
||||
s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=s3_config["region_name"],
|
||||
aws_access_key_id=s3_config["aws_access_key_id"],
|
||||
aws_secret_access_key=s3_config["aws_secret_access_key"],
|
||||
)
|
||||
|
||||
return AmazonS3Service(
|
||||
s3_client,
|
||||
s3_config["bucket_name"],
|
||||
)
|
||||
|
||||
raise RuntimeError("Invalid Storage Type")
|
||||
|
||||
@@ -10,6 +10,7 @@ from storage_service.service.virus_checker.virus_total_service import (
|
||||
from storage_service.utils.enums.virus_checker_type import VirusCheckerType
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from virustotal_python import Virustotal
|
||||
|
||||
import os
|
||||
from functools import cache
|
||||
@@ -19,13 +20,12 @@ from functools import cache
|
||||
def dependency_virus_checker_service() -> VirusCheckerService:
    """Build the virus checker service selected by ``VIRUS_CHECKER_TYPE``.

    Loads the ``.env`` file, reads ``VIRUS_CHECKER_TYPE`` from the
    environment and constructs the matching service.

    Returns:
        A ``VirusTotalService`` wrapping a ``Virustotal`` client when the
        configured type is ``VirusCheckerType.TOTAL_VIRUS``.

    Raises:
        RuntimeError: If ``VIRUS_CHECKER_TYPE`` is unset, is not a valid
            ``VirusCheckerType`` value, or names a type that has no service
            wired up here. Errors raised by ``get_virus_checker_api_key``
            propagate unchanged.
    """
    load_dotenv()

    # ``.get`` rather than indexing: an unset variable becomes ``None``, which
    # the enum lookup rejects with ValueError, so "unset" and "unknown value"
    # are both reported as the same RuntimeError instead of a bare KeyError.
    try:
        checker_type = VirusCheckerType(os.environ.get("VIRUS_CHECKER_TYPE"))
    except ValueError as err:
        raise RuntimeError("Invalid Virus Checker Type") from err

    if checker_type == VirusCheckerType.TOTAL_VIRUS:
        return VirusTotalService(Virustotal(get_virus_checker_api_key()))

    # A valid enum member without a service above must not fall through and
    # silently return None.
    raise RuntimeError("Invalid Virus Checker Type")
|
||||
|
||||
0
storage_service/model/__init__.py
Normal file
0
storage_service/model/__init__.py
Normal file
0
storage_service/model/health_check/__init__.py
Normal file
0
storage_service/model/health_check/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
    """Response body of the health-check endpoint."""

    # Free-form status text; the health endpoint sends "healthy".
    status: str
|
||||
0
storage_service/model/storage/__init__.py
Normal file
0
storage_service/model/storage/__init__.py
Normal file
9
storage_service/model/storage/new_file_request.py
Normal file
9
storage_service/model/storage/new_file_request.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from storage_service.utils.enums.file_type import FileType
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class NewFileURLRequest(BaseModel):
    """Request body for creating a temporary upload link."""

    # Hashed together with ``file_postfix`` to form the stored file name.
    file_key: str
    # Appended (after an underscore) to the hashed key in the stored name.
    file_postfix: str
    # Kind of file being uploaded; forwarded to the storage service.
    file_type: FileType
|
||||
6
storage_service/model/storage/process_file_request.py
Normal file
6
storage_service/model/storage/process_file_request.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProcessFileRequest(BaseModel):
    """Request body for queueing background processing of a file."""

    # Key that was used when the file was uploaded.
    file_key: str
    # Postfix that was used when the file was uploaded.
    file_postfix: str
|
||||
6
storage_service/model/storage/signed_url_response.py
Normal file
6
storage_service/model/storage/signed_url_response.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SignedUrlResponse(BaseModel):
    """A signed URL together with the lifetime it was issued with."""

    # The signed URL handed back to the client.
    signed_url: str
    # Lifetime the storage service reports for the URL. The S3 service passes
    # the same value to boto3 as ``ExpiresIn`` when signing read URLs, which
    # boto3 documents as seconds.
    expires_in: int
|
||||
@@ -3,46 +3,58 @@ from __future__ import annotations
|
||||
from storage_service.depends.depend_virus_checker_service import (
|
||||
dependency_virus_checker_service,
|
||||
)
|
||||
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||
from storage_service.service.storage.storage_service import StorageService
|
||||
from storage_service.service.virus_checker.virus_checker_service import (
|
||||
VirusCheckerService,
|
||||
)
|
||||
from storage_service.utils.enums.file_type import FileType
|
||||
from storage_service.utils.file_handler import FILE_HANDLER
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AmazonS3Service(StorageService):
|
||||
virus_checker_service = dependency_virus_checker_service()
|
||||
virus_checker_service: VirusCheckerService
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
s3_client: BaseClient
|
||||
bucket_name: str
|
||||
|
||||
self.__validate_config(**kwargs)
|
||||
expires_in: int = 3600
|
||||
|
||||
self.bucket_name = kwargs.get("bucket_name")
|
||||
self.region_name = kwargs.get("region_name")
|
||||
def __init__(
    self,
    s3_client: BaseClient,
    bucket_name: str,
    virus_checker_service: VirusCheckerService | None = None,
    **kwargs,
):
    """Create the service around an already-configured S3 client.

    Args:
        s3_client: boto3 S3 client used for every storage call.
        bucket_name: Bucket all objects are read from and written to.
        virus_checker_service: Checker to use; when omitted (or ``None``)
            the one from ``dependency_virus_checker_service()`` is used.
        **kwargs: Optional ``expires_in`` overriding the class-level
            default lifetime for signed URLs.

    Raises:
        RuntimeError: If ``s3_client`` or ``bucket_name`` is ``None``.
    """
    if s3_client is None:
        raise RuntimeError("Invalid S3 Config: Missing s3_client")

    if bucket_name is None:
        raise RuntimeError("Invalid S3 Config: Missing bucket_name")

    # Resolve the default at call time. A call expression in the signature
    # would run once at import, before the environment is necessarily
    # configured, and would fail the import of this module on bad config.
    if virus_checker_service is None:
        virus_checker_service = dependency_virus_checker_service()

    self.s3_client = s3_client
    self.bucket_name = bucket_name
    self.virus_checker_service = virus_checker_service

    # Only shadow the class-level ``expires_in`` when the caller supplied one.
    if "expires_in" in kwargs:
        self.expires_in = kwargs["expires_in"]
|
||||
|
||||
def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse:
    """Return a signed URL the client can upload ``file_name`` to.

    Args:
        file_name: Object key to upload to.
        file_type: Kind of file being uploaded; forwarded to
            ``_get_presigned_write_url``.

    Returns:
        A ``SignedUrlResponse`` with the write URL and ``self.expires_in``.
    """
    upload_url = self._get_presigned_write_url(file_name, file_type)

    return SignedUrlResponse(signed_url=upload_url, expires_in=self.expires_in)
|
||||
|
||||
def get_temp_read_link(self, file_name) -> SignedUrlResponse:
    """Return a signed URL the client can download ``file_name`` from.

    Args:
        file_name: Object key to read.

    Returns:
        A ``SignedUrlResponse`` with the read URL and ``self.expires_in``.

    Raises:
        FileNotFoundError: If no read URL could be produced for the key.
    """
    read_url = self._get_presigned_read_url(file_name)

    # ``_get_presigned_read_url`` is annotated ``str | None``. Passing None
    # into the ``str`` field of the response model would only surface as an
    # opaque validation error, so report the missing file explicitly.
    if read_url is None:
        raise FileNotFoundError(f"No stored file found for key: {file_name}")

    return SignedUrlResponse(signed_url=read_url, expires_in=self.expires_in)
|
||||
|
||||
def delete_file(self, file_name: str) -> None:
    """Delete the object stored under ``file_name``.

    Public entry point; the work is delegated to ``_delete_file``.
    """
    self._delete_file(file_name)
|
||||
@@ -57,11 +69,8 @@ class AmazonS3Service(StorageService):
|
||||
|
||||
self._upload_file(file_name, handler(file_bytes))
|
||||
|
||||
def _get_object_url(self, file_name: str) -> str:
|
||||
return f"https://{self.bucket_name}.s3.{self.region_name}.amazonaws.com/{file_name}"
|
||||
|
||||
def _get_presigned_write_url(self, file_name, file_type: FileType) -> str:
|
||||
return self.s3.generate_presigned_url(
|
||||
return self.s3_client.generate_presigned_url(
|
||||
"put_object",
|
||||
Params={
|
||||
"Bucket": self.bucket_name,
|
||||
@@ -72,12 +81,12 @@ class AmazonS3Service(StorageService):
|
||||
)
|
||||
|
||||
def _get_presigned_read_url(self, file_name) -> str | None:
|
||||
result = self.s3.list_objects(Bucket=self.bucket_name, Prefix=file_name)
|
||||
result = self.s3_client.list_objects(Bucket=self.bucket_name, Prefix=file_name)
|
||||
|
||||
if "Contents" in result and file_name in map(
|
||||
lambda x: x["Key"], result["Contents"]
|
||||
):
|
||||
return self.s3.generate_presigned_url(
|
||||
return self.s3_client.generate_presigned_url(
|
||||
"get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": file_name},
|
||||
ExpiresIn=self.expires_in,
|
||||
@@ -86,28 +95,15 @@ class AmazonS3Service(StorageService):
|
||||
|
||||
def _get_file_obj(self, file_name: str) -> io.BytesIO:
    """Download an object and return its content as an in-memory buffer.

    Args:
        file_name: Object key to fetch from ``self.bucket_name``.

    Returns:
        A ``BytesIO`` holding the full object body (read in one call).
    """
    response = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_name)

    return io.BytesIO(response["Body"].read())
|
||||
|
||||
def _upload_file(self, file_name: str, file_bytes: io.BytesIO) -> None:
    """Upload an in-memory buffer to ``self.bucket_name``.

    Args:
        file_name: Object key to write.
        file_bytes: Buffer whose content is uploaded.
    """
    self.s3_client.upload_fileobj(
        file_bytes, Bucket=self.bucket_name, Key=file_name
    )
|
||||
|
||||
def _delete_file(self, file_name: str) -> None:
    """Delete the object stored under ``file_name`` in ``self.bucket_name``."""
    self.s3_client.delete_object(Bucket=self.bucket_name, Key=file_name)
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||
from storage_service.utils.enums.file_type import FileType
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StorageService(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
@abstractmethod
def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse:
    """Return a signed URL for uploading ``file_name`` of the given type.

    Implemented by each concrete storage backend.
    """
    pass
|
||||
|
||||
@abstractmethod
def get_temp_read_link(self, file_name) -> SignedUrlResponse:
    """Return a signed URL for reading ``file_name``.

    Implemented by each concrete storage backend.
    """
    pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -8,31 +8,33 @@ from io import BytesIO
|
||||
|
||||
|
||||
class VirusTotalService(VirusCheckerService):
    """Virus checker that submits files through a ``Virustotal`` client."""

    # Client used for every request; injected through the constructor.
    virus_checker: Virustotal

    def __init__(self, virus_checker: Virustotal):
        """Store the already-configured ``Virustotal`` client."""
        self.virus_checker = virus_checker

    def check_virus(self, file_data: BytesIO) -> bool:
        """Upload a file for analysis and report whether it looks clean.

        Args:
            file_data: Content of the file to check.

        Returns:
            True when the analysis stats flag nothing as malicious or
            suspicious, False otherwise.
        """
        # NOTE(review): the analysis is fetched immediately after the upload.
        # If the service can still report it as queued at that point, the
        # counters may all be zero and the file would pass - confirm whether
        # the analysis status needs to be polled first.
        analysis_id = self._upload_file(file_data)
        file_stats = self._get_analysis(analysis_id)

        return self._is_valid_file(file_stats)

    def _upload_file(self, file_data: BytesIO) -> str:
        """POST the file to the ``files`` endpoint and return the response id."""
        files = {"file": ("image_file", file_data)}

        resp = self.virus_checker.request("files", files=files, method="POST")

        return resp.data["id"]

    def _get_analysis(self, file_id: str) -> dict:
        """Fetch ``analyses/{file_id}`` and return its ``stats`` mapping."""
        resp = self.virus_checker.request(f"analyses/{file_id}")

        return resp.json()["data"]["attributes"]["stats"]

    @staticmethod
    def _is_valid_file(file_stats: dict) -> bool:
        """Return True unless any engine flagged the file.

        Only the ``malicious`` and ``suspicious`` counters decide the result;
        a missing counter is treated as zero. Counters such as ``harmless``
        and ``undetected`` are ignored on purpose: requiring them to be zero
        would reject any file whose analysis reports non-zero counts there.
        """
        flagged_counters = ("malicious", "suspicious")

        return not any(file_stats.get(name, 0) > 0 for name in flagged_counters)
|
||||
|
||||
@@ -2,8 +2,8 @@ import base64
|
||||
from hashlib import md5
|
||||
|
||||
|
||||
def file_name_hash(file_key: str, file_postfix: str) -> str:
    """Build the stored file name for a key/postfix pair.

    The key is UTF-8 encoded, hashed with MD5, and the raw digest is encoded
    with standard Base64; the postfix is appended after an underscore.

    Args:
        file_key: Value identifying the file's owner or group.
        file_postfix: Suffix distinguishing files that share a key.

    Returns:
        ``"<base64(md5(file_key))>_<file_postfix>"``.
    """
    # NOTE(review): the standard Base64 alphabet includes "/" and "+", so the
    # result can contain characters that act as path separators in object
    # keys. Switching alphabets would change every existing file name, so the
    # encoding is left as is.
    digest = md5(file_key.encode("utf-8")).digest()
    hashed_file_key = base64.b64encode(digest).decode()

    return f"{hashed_file_key}_{file_postfix}"
|
||||
|
||||
Reference in New Issue
Block a user