Adds New Delete File Endpoint and Refactors Endpoints, Adds VirusChecker

This commit is contained in:
2023-09-10 02:21:41 -03:00
parent ab65be9710
commit c826000954
13 changed files with 149 additions and 10 deletions

View File

@@ -0,0 +1,11 @@
from dotenv import load_dotenv
import os
def get_virus_checker_api_key():
load_dotenv()
return {
"api_key": os.environ.get("VIRUS_CHECKER_API_KEY")
}

View File

@@ -4,12 +4,12 @@ from storage_service.depends.depend_queue import dependency_queue
from storage_service.depends.depend_s3_service import (
dependency_storage_service,
)
from storage_service.service.storage_service import StorageService
from storage_service.service.storage.storage_service import StorageService
from storage_service.utils.enums.file_type import FileType
from storage_service.utils.file_name_hash import file_name_hash
from storage_service.worker.storage_file_worker import storage_file_worker
from fastapi import Body, Depends, Form
from fastapi import Body, Depends
from fastapi_utils.cbv import cbv
from fastapi_utils.inferring_router import InferringRouter
from rq import Queue
@@ -26,7 +26,7 @@ class StorageController:
dependency_storage_service, use_cache=True
)
@s3_router.post("/new_file_url/", status_code=200)
@s3_router.post("/file/", status_code=200)
def new_file_url(
self,
username: Annotated[str, Body(embed=True)],
@@ -37,16 +37,24 @@ class StorageController:
file_name_hash(username, file_postfix), file_type
)
@s3_router.get("/file_url/", status_code=200)
@s3_router.get("/file/", status_code=200)
def file_url(self, username: str, file_postfix: str) -> dict[str, str | None]:
return self.storage_service.get_temp_read_link(
file_name_hash(username, file_postfix)
)
@s3_router.post("/process_file/", status_code=200)
@s3_router.delete("/file/", status_code=204)
def delete_file(self, username: str, file_postfix: str):
return self.storage_service.delete_file(
file_name_hash(username, file_postfix)
)
@s3_router.post("/file/process", status_code=200)
def process_file(
self,
username: Annotated[str, Body(embed=True)],
file_postfix: Annotated[str, Body(embed=True)],
):
self.queue.enqueue(storage_file_worker, username, file_postfix)

View File

@@ -1,6 +1,6 @@
from storage_service.config.config_s3 import get_config_s3
from storage_service.service.amazon_s3_service import AmazonS3Service
from storage_service.service.storage_service import StorageService
from storage_service.service.storage.amazon_s3_service import AmazonS3Service
from storage_service.service.storage.storage_service import StorageService
from storage_service.utils.enums.storage_type import StorageType
from dotenv import load_dotenv

View File

@@ -0,0 +1,26 @@
import os
from functools import cache
from storage_service.config.config_virus_checker import get_virus_checker_api_key
from storage_service.service.virus_checker.virus_total_service import VirusTotalService
from storage_service.service.virus_checker.virus_checker_service import VirusCheckerService
from dotenv import load_dotenv
from storage_service.utils.enums.virus_checker_type import VirusCheckerType
@cache
def dependency_virus_checker_service() -> VirusCheckerService:
load_dotenv()
virus_checker_config = get_virus_checker_api_key()
if not virus_checker_config["api_key"]:
raise RuntimeError("Virus Checker API Key not found")
virus_checker_type_var = os.environ.get("VIRUS_CHECKER_TYPE")
if VirusCheckerType(virus_checker_type_var) == VirusCheckerType.TOTAL_VIRUS:
return VirusTotalService(**get_virus_checker_api_key())
raise RuntimeError("Invalid Virus Checker Type")

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from storage_service.service.storage_service import StorageService
from storage_service.depends.depend_virus_checker_service import dependency_virus_checker_service
from storage_service.service.storage.storage_service import StorageService
from storage_service.utils.enums.file_type import FileType
from storage_service.utils.file_handler import FILE_HANDLER
@@ -11,6 +12,9 @@ from typing import Any
class AmazonS3Service(StorageService):
virus_checker_service = dependency_virus_checker_service()
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -39,8 +43,15 @@ class AmazonS3Service(StorageService):
def get_temp_read_link(self, file_name) -> dict[str, str | None]:
return {"presigned_url": self._get_presigned_read_url(file_name)}
def delete_file(self, file_name: str) -> None:
self._delete_file(file_name)
def process_file(self, file_name: str, file_type: FileType = FileType.PNG) -> None:
file_bytes = self._get_file_obj(file_name)
if not self.virus_checker_service.check_virus(file_bytes):
self._delete_file(file_name)
handler = FILE_HANDLER[file_type]["handler"]
self._upload_file(file_name, handler(file_bytes))
@@ -78,6 +89,9 @@ class AmazonS3Service(StorageService):
def _upload_file(self, file_name: str, file_bytes: io.BytesIO) -> None:
self.s3.upload_fileobj(file_bytes, Bucket=self.bucket_name, Key=file_name)
def _delete_file(self, file_name: str) -> None:
self.s3.delete_object(Bucket=self.bucket_name, Key=file_name)
@staticmethod
def __validate_config(**kwargs):
if not kwargs.get("bucket_name"):

View File

@@ -20,6 +20,10 @@ class StorageService(ABC):
def get_temp_read_link(self, file_name) -> dict[str, str | None]:
pass
@abstractmethod
def delete_file(self, file_name: str) -> None:
pass
@abstractmethod
def process_file(self, file_name: str, file_type: FileType) -> None:
pass

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from io import BytesIO
class VirusCheckerService(ABC):
@abstractmethod
def check_virus(self, file_data: BytesIO) -> bool:
pass

View File

@@ -0,0 +1,35 @@
from io import BytesIO
from virustotal_python import Virustotal
from storage_service.service.virus_checker.virus_checker_service import VirusCheckerService
class VirusTotalService(VirusCheckerService):
def __init__(self, api_key: str):
self.api_key = api_key
def check_virus(self, file_data: BytesIO) -> bool:
files = {"file": ("image_file", file_data)}
with Virustotal(self.api_key) as vtotal:
resp = vtotal.request("files", files=files, method="POST")
file_attributes = self._get_analysis(resp.json()["data"]["id"])
return self._is_valid_file(file_attributes["data"]["attributes"]["stats"])
def _get_analysis(self, file_id: str) -> dict:
with Virustotal(self.api_key) as vtotal:
resp = vtotal.request(f"analyses/{file_id}")
return resp.json()
@staticmethod
def _is_valid_file(file_stats: dict) -> bool:
if 'malicious' in file_stats and file_stats['malicious'] > 0:
return False
if 'suspicious' in file_stats and file_stats['suspicious'] > 0:
return False
return True

View File

@@ -0,0 +1,5 @@
from enum import Enum
class VirusCheckerType(Enum):
TOTAL_VIRUS = "total_virus"