Merge pull request #16 from HideyoshiSolutions/new-api-implementation
New API and Tests Implementation
This commit is contained in:
13
.githooks/pre-commit-config.yaml
Normal file
13
.githooks/pre-commit-config.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 24.4.2
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
args: [--config=pyproject.toml]
|
||||||
|
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.13.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
files: "\\.(py)$"
|
||||||
|
args: [--settings-path=pyproject.toml]
|
||||||
5
.githooks/set-hooks.sh
Normal file
5
.githooks/set-hooks.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
|
||||||
|
pre-commit install --config .githooks/pre-commit-config.yaml
|
||||||
|
pre-commit autoupdate --config .githooks/pre-commit-config.yaml
|
||||||
27
.github/workflows/run-tests.yml
vendored
Normal file
27
.github/workflows/run-tests.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: ci
|
||||||
|
|
||||||
|
on:
|
||||||
|
push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
run-tests:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install poetry
|
||||||
|
poetry install
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
poetry run python -m unittest
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
.env*
|
.env*
|
||||||
|
.coverage*
|
||||||
|
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
|
|||||||
65
Dockerfile
65
Dockerfile
@@ -1,64 +1,11 @@
|
|||||||
# `python-base` sets up all our shared environment variables
|
FROM python:3.12
|
||||||
FROM python:3.12-slim as python-base
|
|
||||||
|
|
||||||
# python
|
WORKDIR /app
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
|
||||||
# prevents python creating .pyc files
|
|
||||||
PYTHONDONTWRITEBYTECODE=1 \
|
|
||||||
\
|
|
||||||
# pip
|
|
||||||
PIP_NO_CACHE_DIR=off \
|
|
||||||
PIP_DISABLE_PIP_VERSION_CHECK=on \
|
|
||||||
PIP_DEFAULT_TIMEOUT=100 \
|
|
||||||
\
|
|
||||||
# poetry
|
|
||||||
# https://python-poetry.org/docs/configuration/#using-environment-variables
|
|
||||||
POETRY_VERSION=1.5.1 \
|
|
||||||
# make poetry install to this location
|
|
||||||
POETRY_HOME="/opt/poetry" \
|
|
||||||
# make poetry create the virtual environment in the project's root
|
|
||||||
# it gets named `.venv`
|
|
||||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
|
||||||
# do not ask any interactive question
|
|
||||||
POETRY_NO_INTERACTION=1 \
|
|
||||||
\
|
|
||||||
# paths
|
|
||||||
# this is where our requirements + virtual environment will live
|
|
||||||
PYSETUP_PATH="/opt/pysetup" \
|
|
||||||
VENV_PATH="/opt/pysetup/.venv"
|
|
||||||
|
|
||||||
|
RUN pip install poetry
|
||||||
|
|
||||||
# prepend poetry and venv to path
|
COPY ./ /app/
|
||||||
ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
|
|
||||||
|
|
||||||
|
RUN poetry install
|
||||||
|
|
||||||
# `builder-base` stage is used to build deps + create our virtual environment
|
ENTRYPOINT ["poetry", "run", "python", "-m", "storage_service"]
|
||||||
FROM python-base as builder-base
|
|
||||||
RUN apt-get update \
|
|
||||||
&& apt-get install --no-install-recommends -y \
|
|
||||||
# deps for installing poetry
|
|
||||||
curl \
|
|
||||||
# deps for building python deps
|
|
||||||
build-essential
|
|
||||||
|
|
||||||
# install poetry - respects $POETRY_VERSION & $POETRY_HOME
|
|
||||||
RUN curl -sSL https://install.python-poetry.org | python3 -
|
|
||||||
|
|
||||||
# copy project requirement files here to ensure they will be cached.
|
|
||||||
WORKDIR $PYSETUP_PATH
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
# install runtime deps - uses $POETRY_VIRTUALENVS_IN_PROJECT internally
|
|
||||||
RUN poetry install --no-dev
|
|
||||||
# `builder-base` stage is used to build deps + create our virtual environment
|
|
||||||
|
|
||||||
|
|
||||||
FROM python-base as production
|
|
||||||
|
|
||||||
COPY --from=builder-base $PYSETUP_PATH $PYSETUP_PATH
|
|
||||||
WORKDIR $PYSETUP_PATH
|
|
||||||
|
|
||||||
EXPOSE 5000-9000
|
|
||||||
|
|
||||||
# Run your app
|
|
||||||
CMD [ "./run-queue.sh" ]
|
|
||||||
|
|||||||
1215
poetry.lock
generated
1215
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,21 +9,35 @@ packages = [{include = "storage_service"}]
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.12"
|
python = "^3.12"
|
||||||
pillow = "^10.0.0"
|
pillow = "^10.0.0"
|
||||||
redis = "^5.0.3"
|
redis = "^5.0.4"
|
||||||
requests = "^2.31.0"
|
requests = "^2.32.1"
|
||||||
rq = "^1.15.1"
|
rq = "^1.16.1"
|
||||||
python-dotenv = "^1.0.0"
|
python-dotenv = "^1.0.0"
|
||||||
fastapi = "^0.110.1"
|
fastapi = "^0.111.0"
|
||||||
uvicorn = "^0.29.0"
|
uvicorn = "^0.29.0"
|
||||||
boto3 = "^1.28.21"
|
boto3 = "^1.34.109"
|
||||||
python-multipart = "^0.0.9"
|
python-multipart = "^0.0.9"
|
||||||
virustotal-python = "^1.0.2"
|
virustotal-python = "^1.0.2"
|
||||||
fastapi-utils = "^0.6.0"
|
fastapi-utils = "^0.6.0"
|
||||||
|
typing-inspect = "^0.9.0"
|
||||||
|
poethepoet = "^0.26.1"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
isort = "^5.12.0"
|
isort = "^5.12.0"
|
||||||
black = "^23.7.0"
|
black = "^23.7.0"
|
||||||
|
coverage = "^7.5.1"
|
||||||
|
pre-commit = "^3.7.1"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poe.tasks]
|
||||||
|
'run' = "python -m storage_service"
|
||||||
|
'run:queue' = "python -m storage_service --queue"
|
||||||
|
'run:dev' = "python -m storage_service --dev"
|
||||||
|
'create-hooks' = "bash .githooks/set-hooks.sh"
|
||||||
|
'test' = "coverage run -m unittest -v"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
11
run-queue.sh
11
run-queue.sh
@@ -1,11 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
|
|
||||||
if [[ $1 == "--queue" || $1 == "-q" ]]; then
|
|
||||||
rq worker --with-scheduler
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
python -m storage_service
|
|
||||||
fi
|
|
||||||
|
|
||||||
exec "$@"
|
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from storage_service.config.config_allowed_origins import get_allowed_origins
|
||||||
|
from storage_service.controller import health_router, storage_router
|
||||||
|
from storage_service.utils.exception_handler import (
|
||||||
|
http_exception_handler,
|
||||||
|
validation_exception_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||||
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=get_allowed_origins(),
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(storage_router)
|
||||||
|
app.include_router(health_router)
|
||||||
|
|||||||
@@ -1,14 +1,42 @@
|
|||||||
from storage_service.config.config_server import get_config_server
|
from storage_service.config.config_server import get_config_server
|
||||||
from storage_service.controller import app
|
from storage_service.depends.depend_queue import dependency_queue_worker
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
def main():
|
|
||||||
config = get_config_server()
|
|
||||||
|
|
||||||
uvicorn.run(app, host=config["host"], port=config["port"])
|
def main(is_queue=False, is_dev=False):
|
||||||
|
if is_queue:
|
||||||
|
dependency_queue_worker().work(with_scheduler=True)
|
||||||
|
else:
|
||||||
|
config = {
|
||||||
|
**get_config_server(),
|
||||||
|
"reload": is_dev,
|
||||||
|
}
|
||||||
|
|
||||||
|
uvicorn.run("storage_service.__init__:app", **config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser(description="Storage Service")
|
||||||
|
parser.add_argument(
|
||||||
|
"-q",
|
||||||
|
"--queue",
|
||||||
|
dest="queue",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Runs the worker to process the queue",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--dev",
|
||||||
|
dest="dev_mode",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="Run the server in development mode.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args.queue, args.dev_mode)
|
||||||
|
|||||||
@@ -3,7 +3,12 @@ from dotenv import load_dotenv
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def get_virus_checker_api_key():
|
def get_virus_checker_api_key() -> str:
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
return {"api_key": os.environ.get("VIRUS_CHECKER_API_KEY")}
|
api_key = os.environ.get("VIRUS_CHECKER_API_KEY")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("Virus Checker API Key not found")
|
||||||
|
|
||||||
|
return api_key
|
||||||
|
|||||||
@@ -1,20 +1,2 @@
|
|||||||
from storage_service.config.config_allowed_origins import get_allowed_origins
|
from .health_checker_controller import router as health_router
|
||||||
from storage_service.controller.health_checker_controller import health_router
|
from .storage_controller import router as storage_router
|
||||||
from storage_service.controller.storage_controller import s3_router
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=get_allowed_origins(),
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(s3_router)
|
|
||||||
app.include_router(health_router)
|
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
|
from storage_service.model.health_check.health_check_response import (
|
||||||
|
HealthCheckResponse,
|
||||||
|
)
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi_utils.cbv import cbv
|
from fastapi_utils.cbv import cbv
|
||||||
|
|
||||||
health_router = APIRouter()
|
router = APIRouter(tags=["health"])
|
||||||
|
|
||||||
|
|
||||||
@cbv(health_router)
|
@cbv(router)
|
||||||
class HealthCheckerController:
|
class HealthCheckerController:
|
||||||
@health_router.get("/health", status_code=200)
|
@router.get("/health", status_code=200)
|
||||||
def health(self) -> dict[str, str]:
|
def health(self) -> HealthCheckResponse:
|
||||||
return {"status": "healthy"}
|
return HealthCheckResponse(status="healthy")
|
||||||
|
|||||||
@@ -4,52 +4,61 @@ from storage_service.depends.depend_queue import dependency_queue
|
|||||||
from storage_service.depends.depend_s3_service import (
|
from storage_service.depends.depend_s3_service import (
|
||||||
dependency_storage_service,
|
dependency_storage_service,
|
||||||
)
|
)
|
||||||
|
from storage_service.model.storage.new_file_request import NewFileURLRequest
|
||||||
|
from storage_service.model.storage.process_file_request import (
|
||||||
|
ProcessFileRequest,
|
||||||
|
)
|
||||||
|
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||||
from storage_service.service.storage.storage_service import StorageService
|
from storage_service.service.storage.storage_service import StorageService
|
||||||
from storage_service.utils.enums.file_type import FileType
|
from storage_service.utils.exceptions.file_not_found_exception import (
|
||||||
from storage_service.utils.file_name_hash import file_name_hash
|
FileNotFoundException,
|
||||||
|
)
|
||||||
|
from storage_service.utils.file.file_hash_generator import generate_file_hash
|
||||||
from storage_service.worker.storage_file_worker import storage_file_worker
|
from storage_service.worker.storage_file_worker import storage_file_worker
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi_utils.cbv import cbv
|
from fastapi_utils.cbv import cbv
|
||||||
from rq import Queue
|
from rq import Queue
|
||||||
|
|
||||||
from typing import Annotated
|
router = APIRouter(tags=["storage"])
|
||||||
|
|
||||||
s3_router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@cbv(s3_router)
|
@cbv(router)
|
||||||
class StorageController:
|
class StorageController:
|
||||||
queue: Queue = Depends(dependency_queue, use_cache=True)
|
queue: Queue = Depends(dependency_queue, use_cache=True)
|
||||||
storage_service: StorageService = Depends(
|
storage_service: StorageService = Depends(
|
||||||
dependency_storage_service, use_cache=True
|
dependency_storage_service, use_cache=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@s3_router.post("/file/", status_code=200)
|
@router.post("/file", status_code=200)
|
||||||
def new_file_url(
|
def new_file_url(self, new_file_request: NewFileURLRequest) -> SignedUrlResponse:
|
||||||
self,
|
hashed_file_name = generate_file_hash(
|
||||||
username: Annotated[str, Body(embed=True)],
|
new_file_request.file_key, new_file_request.file_postfix
|
||||||
file_postfix: Annotated[str, Body(embed=True)],
|
)
|
||||||
file_type: Annotated[FileType, Body(embed=True)],
|
|
||||||
) -> dict[str, str]:
|
|
||||||
return self.storage_service.get_temp_upload_link(
|
return self.storage_service.get_temp_upload_link(
|
||||||
file_name_hash(username, file_postfix), file_type
|
hashed_file_name, new_file_request.file_type
|
||||||
)
|
)
|
||||||
|
|
||||||
@s3_router.get("/file/", status_code=200)
|
@router.get("/file", status_code=200)
|
||||||
def file_url(self, username: str, file_postfix: str) -> dict[str, str | None]:
|
def file_url(self, file_key: str, file_postfix: str) -> SignedUrlResponse:
|
||||||
return self.storage_service.get_temp_read_link(
|
try:
|
||||||
file_name_hash(username, file_postfix)
|
return self.storage_service.get_temp_read_link(
|
||||||
|
generate_file_hash(file_key, file_postfix)
|
||||||
|
)
|
||||||
|
except Exception as _:
|
||||||
|
raise FileNotFoundException("File not found")
|
||||||
|
|
||||||
|
@router.delete("/file", status_code=204)
|
||||||
|
def delete_file(self, file_key: str, file_postfix: str):
|
||||||
|
return self.storage_service.delete_file(
|
||||||
|
generate_file_hash(file_key, file_postfix)
|
||||||
)
|
)
|
||||||
|
|
||||||
@s3_router.delete("/file/", status_code=204)
|
@router.post("/file/process", status_code=200)
|
||||||
def delete_file(self, username: str, file_postfix: str):
|
def process_file(self, process_file_request: ProcessFileRequest):
|
||||||
return self.storage_service.delete_file(file_name_hash(username, file_postfix))
|
self.queue.enqueue(
|
||||||
|
storage_file_worker,
|
||||||
@s3_router.post("/file/process", status_code=200)
|
process_file_request.file_key,
|
||||||
def process_file(
|
process_file_request.file_postfix,
|
||||||
self,
|
)
|
||||||
username: Annotated[str, Body(embed=True)],
|
|
||||||
file_postfix: Annotated[str, Body(embed=True)],
|
|
||||||
):
|
|
||||||
self.queue.enqueue(storage_file_worker, username, file_postfix)
|
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
from storage_service.config.config_redis import get_config_redis
|
from storage_service.config.config_redis import get_config_redis
|
||||||
|
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from rq import Queue
|
from rq import Queue, Worker
|
||||||
|
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
|
||||||
def dependency_queue():
|
@cache
|
||||||
return Queue(connection=Redis(**get_config_redis()))
|
def dependency_queue() -> Queue:
|
||||||
|
return Queue(name="default", connection=Redis(**get_config_redis()))
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def dependency_queue_worker() -> Worker:
|
||||||
|
return Worker(["default"], connection=Redis(**get_config_redis()))
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from storage_service.service.storage.amazon_s3_service import AmazonS3Service
|
|||||||
from storage_service.service.storage.storage_service import StorageService
|
from storage_service.service.storage.storage_service import StorageService
|
||||||
from storage_service.utils.enums.storage_type import StorageType
|
from storage_service.utils.enums.storage_type import StorageType
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import botocore.client
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -14,6 +16,27 @@ def dependency_storage_service() -> StorageService:
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
if StorageType(os.environ["STORAGE_TYPE"]) == StorageType.S3_STORAGE:
|
if StorageType(os.environ["STORAGE_TYPE"]) == StorageType.S3_STORAGE:
|
||||||
return AmazonS3Service(**get_config_s3())
|
s3_config = get_config_s3()
|
||||||
|
|
||||||
|
if "aws_access_key_id" not in s3_config:
|
||||||
|
raise RuntimeError("Invalid S3 Config: Missing aws_access_key_id")
|
||||||
|
|
||||||
|
if "aws_secret_access_key" not in s3_config:
|
||||||
|
raise RuntimeError("Invalid S3 Config: Missing aws_secret_access_key")
|
||||||
|
|
||||||
|
if "region_name" not in s3_config:
|
||||||
|
raise RuntimeError("Invalid S3 Config: Missing region_name")
|
||||||
|
|
||||||
|
s3_client = boto3.client(
|
||||||
|
"s3",
|
||||||
|
region_name=s3_config["region_name"],
|
||||||
|
aws_access_key_id=s3_config["aws_access_key_id"],
|
||||||
|
aws_secret_access_key=s3_config["aws_secret_access_key"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return AmazonS3Service(
|
||||||
|
s3_client,
|
||||||
|
s3_config["bucket_name"],
|
||||||
|
)
|
||||||
|
|
||||||
raise RuntimeError("Invalid Storage Type")
|
raise RuntimeError("Invalid Storage Type")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from storage_service.service.virus_checker.virus_total_service import (
|
|||||||
from storage_service.utils.enums.virus_checker_type import VirusCheckerType
|
from storage_service.utils.enums.virus_checker_type import VirusCheckerType
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from virustotal_python import Virustotal
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import cache
|
from functools import cache
|
||||||
@@ -19,13 +20,12 @@ from functools import cache
|
|||||||
def dependency_virus_checker_service() -> VirusCheckerService:
|
def dependency_virus_checker_service() -> VirusCheckerService:
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
virus_checker_config = get_virus_checker_api_key()
|
try:
|
||||||
|
type = VirusCheckerType(os.environ["VIRUS_CHECKER_TYPE"])
|
||||||
|
except ValueError:
|
||||||
|
raise RuntimeError("Invalid Virus Checker Type")
|
||||||
|
|
||||||
if not virus_checker_config["api_key"]:
|
match type:
|
||||||
raise RuntimeError("Virus Checker API Key not found")
|
case VirusCheckerType.TOTAL_VIRUS:
|
||||||
|
virus_checker = Virustotal(get_virus_checker_api_key())
|
||||||
virus_checker_type_var = os.environ.get("VIRUS_CHECKER_TYPE")
|
return VirusTotalService(virus_checker)
|
||||||
if VirusCheckerType(virus_checker_type_var) == VirusCheckerType.TOTAL_VIRUS:
|
|
||||||
return VirusTotalService(**get_virus_checker_api_key())
|
|
||||||
|
|
||||||
raise RuntimeError("Invalid Virus Checker Type")
|
|
||||||
|
|||||||
0
storage_service/model/health_check/__init__.py
Normal file
0
storage_service/model/health_check/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckResponse(BaseModel):
|
||||||
|
status: str
|
||||||
0
storage_service/model/storage/__init__.py
Normal file
0
storage_service/model/storage/__init__.py
Normal file
9
storage_service/model/storage/new_file_request.py
Normal file
9
storage_service/model/storage/new_file_request.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from storage_service.utils.enums.file_type import FileType
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class NewFileURLRequest(BaseModel):
|
||||||
|
file_key: str
|
||||||
|
file_postfix: str
|
||||||
|
file_type: FileType
|
||||||
6
storage_service/model/storage/process_file_request.py
Normal file
6
storage_service/model/storage/process_file_request.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessFileRequest(BaseModel):
|
||||||
|
file_key: str
|
||||||
|
file_postfix: str
|
||||||
6
storage_service/model/storage/signed_url_response.py
Normal file
6
storage_service/model/storage/signed_url_response.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SignedUrlResponse(BaseModel):
|
||||||
|
signed_url: str
|
||||||
|
expires_in: int
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
from .amazon_s3_service import AmazonS3Service
|
||||||
|
from .storage_service import StorageService
|
||||||
|
|||||||
@@ -3,111 +3,128 @@ from __future__ import annotations
|
|||||||
from storage_service.depends.depend_virus_checker_service import (
|
from storage_service.depends.depend_virus_checker_service import (
|
||||||
dependency_virus_checker_service,
|
dependency_virus_checker_service,
|
||||||
)
|
)
|
||||||
|
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||||
from storage_service.service.storage.storage_service import StorageService
|
from storage_service.service.storage.storage_service import StorageService
|
||||||
|
from storage_service.service.virus_checker.virus_checker_service import (
|
||||||
|
VirusCheckerService,
|
||||||
|
)
|
||||||
from storage_service.utils.enums.file_type import FileType
|
from storage_service.utils.enums.file_type import FileType
|
||||||
from storage_service.utils.file_handler import FILE_HANDLER
|
|
||||||
|
|
||||||
import boto3
|
from botocore.client import BaseClient
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Any
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AmazonS3Service(StorageService):
|
class AmazonS3Service(StorageService):
|
||||||
virus_checker_service = dependency_virus_checker_service()
|
virus_checker_service: VirusCheckerService
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
s3_client: BaseClient
|
||||||
super().__init__(**kwargs)
|
bucket_name: str
|
||||||
|
|
||||||
self.__validate_config(**kwargs)
|
expires_in: int = 3600
|
||||||
|
|
||||||
self.bucket_name = kwargs.get("bucket_name")
|
def __init__(
|
||||||
self.region_name = kwargs.get("region_name")
|
self,
|
||||||
|
s3_client: BaseClient,
|
||||||
|
bucket_name: str,
|
||||||
|
virus_checker_service=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.virus_checker_service = virus_checker_service
|
||||||
|
|
||||||
self.expires_in = kwargs.get("expires_in")
|
if s3_client is None:
|
||||||
|
raise RuntimeError("Invalid S3 Config: Missing s3_client")
|
||||||
|
self.s3_client = s3_client
|
||||||
|
|
||||||
self.s3 = boto3.client(
|
if bucket_name is None:
|
||||||
"s3",
|
raise RuntimeError("Invalid S3 Config: Missing bucket_name")
|
||||||
aws_access_key_id=kwargs.get("aws_access_key_id"),
|
self.bucket_name = bucket_name
|
||||||
aws_secret_access_key=kwargs.get("aws_secret_access_key"),
|
|
||||||
region_name=kwargs.get("region_name"),
|
if virus_checker_service is None:
|
||||||
|
self.virus_checker_service = dependency_virus_checker_service()
|
||||||
|
|
||||||
|
if "expires_in" in kwargs:
|
||||||
|
self.expires_in = kwargs["expires_in"]
|
||||||
|
|
||||||
|
def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse:
|
||||||
|
return SignedUrlResponse(
|
||||||
|
signed_url=self._get_presigned_write_url(file_name, file_type),
|
||||||
|
expires_in=self.expires_in,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_temp_upload_link(
|
def get_temp_read_link(self, file_name) -> SignedUrlResponse:
|
||||||
self, file_name, file_type: FileType
|
return SignedUrlResponse(
|
||||||
) -> dict[str, str | Any]:
|
signed_url=self._get_presigned_read_url(file_name),
|
||||||
return {
|
expires_in=self.expires_in,
|
||||||
"presigned_url": self._get_presigned_write_url(file_name, file_type),
|
)
|
||||||
"file_key": self._get_object_url(file_name),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_temp_read_link(self, file_name) -> dict[str, str | None]:
|
|
||||||
return {"presigned_url": self._get_presigned_read_url(file_name)}
|
|
||||||
|
|
||||||
def delete_file(self, file_name: str) -> None:
|
def delete_file(self, file_name: str) -> None:
|
||||||
self._delete_file(file_name)
|
self._delete_file(file_name)
|
||||||
|
|
||||||
def process_file(self, file_name: str, file_type: FileType = FileType.PNG) -> None:
|
def process_file(self, file_name: str, file_type: FileType = FileType.PNG) -> dict:
|
||||||
file_bytes = self._get_file_obj(file_name)
|
try:
|
||||||
|
file_bytes = self._get_file_obj(file_name)
|
||||||
|
except Exception as _:
|
||||||
|
raise FileNotFoundError("File not found")
|
||||||
|
|
||||||
if not self.virus_checker_service.check_virus(file_bytes):
|
if not self.virus_checker_service.check_virus(file_bytes):
|
||||||
self._delete_file(file_name)
|
raise ValueError("Virus Detected")
|
||||||
|
|
||||||
handler = FILE_HANDLER[file_type]["handler"]
|
try:
|
||||||
|
old_size = file_bytes.getbuffer().nbytes
|
||||||
|
|
||||||
self._upload_file(file_name, handler(file_bytes))
|
file_bytes = file_type.get_validator()(file_bytes)
|
||||||
|
|
||||||
def _get_object_url(self, file_name: str) -> str:
|
new_size = file_bytes.getbuffer().nbytes
|
||||||
return f"https://{self.bucket_name}.s3.{self.region_name}.amazonaws.com/{file_name}"
|
except Exception as _:
|
||||||
|
raise RuntimeError("Error Processing")
|
||||||
|
|
||||||
|
self._upload_file(file_name, file_bytes)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"previous_size": old_size,
|
||||||
|
"current_size": new_size,
|
||||||
|
}
|
||||||
|
|
||||||
def _get_presigned_write_url(self, file_name, file_type: FileType) -> str:
|
def _get_presigned_write_url(self, file_name, file_type: FileType) -> str:
|
||||||
return self.s3.generate_presigned_url(
|
return self.s3_client.generate_presigned_url(
|
||||||
"put_object",
|
"put_object",
|
||||||
Params={
|
Params={
|
||||||
"Bucket": self.bucket_name,
|
"Bucket": self.bucket_name,
|
||||||
"Key": file_name,
|
"Key": file_name,
|
||||||
"ContentType": FILE_HANDLER[file_type]["content_type"],
|
"ContentType": file_type.get_content_type(),
|
||||||
},
|
},
|
||||||
ExpiresIn=self.expires_in,
|
ExpiresIn=self.expires_in,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_presigned_read_url(self, file_name) -> str | None:
|
def _get_presigned_read_url(self, file_name) -> str | None:
|
||||||
result = self.s3.list_objects(Bucket=self.bucket_name, Prefix=file_name)
|
result = self.s3_client.list_objects(Bucket=self.bucket_name, Prefix=file_name)
|
||||||
|
|
||||||
if "Contents" in result and file_name in map(
|
if "Contents" in result and file_name in map(
|
||||||
lambda x: x["Key"], result["Contents"]
|
lambda x: x["Key"], result["Contents"]
|
||||||
):
|
):
|
||||||
return self.s3.generate_presigned_url(
|
return self.s3_client.generate_presigned_url(
|
||||||
"get_object",
|
"get_object",
|
||||||
Params={"Bucket": self.bucket_name, "Key": file_name},
|
Params={"Bucket": self.bucket_name, "Key": file_name},
|
||||||
ExpiresIn=self.expires_in,
|
ExpiresIn=self.expires_in,
|
||||||
)
|
)
|
||||||
return None
|
|
||||||
|
raise FileNotFoundError("File not found")
|
||||||
|
|
||||||
def _get_file_obj(self, file_name: str) -> io.BytesIO:
|
def _get_file_obj(self, file_name: str) -> io.BytesIO:
|
||||||
return io.BytesIO(
|
return io.BytesIO(
|
||||||
self.s3.get_object(Bucket=self.bucket_name, Key=file_name)["Body"].read()
|
self.s3_client.get_object(Bucket=self.bucket_name, Key=file_name)[
|
||||||
|
"Body"
|
||||||
|
].read()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _upload_file(self, file_name: str, file_bytes: io.BytesIO) -> None:
|
def _upload_file(self, file_name: str, file_bytes: io.BytesIO) -> None:
|
||||||
self.s3.upload_fileobj(file_bytes, Bucket=self.bucket_name, Key=file_name)
|
self.s3_client.upload_fileobj(
|
||||||
|
file_bytes, Bucket=self.bucket_name, Key=file_name
|
||||||
|
)
|
||||||
|
|
||||||
def _delete_file(self, file_name: str) -> None:
|
def _delete_file(self, file_name: str) -> None:
|
||||||
self.s3.delete_object(Bucket=self.bucket_name, Key=file_name)
|
self.s3_client.delete_object(Bucket=self.bucket_name, Key=file_name)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __validate_config(**kwargs):
|
|
||||||
if not kwargs.get("bucket_name"):
|
|
||||||
raise RuntimeError("bucket_name is required")
|
|
||||||
|
|
||||||
if not kwargs.get("aws_access_key_id"):
|
|
||||||
raise RuntimeError("aws_access_key_id is required")
|
|
||||||
|
|
||||||
if not kwargs.get("aws_secret_access_key"):
|
|
||||||
raise RuntimeError("aws_secret_access_key is required")
|
|
||||||
|
|
||||||
if not kwargs.get("region_name"):
|
|
||||||
raise RuntimeError("region_name is required")
|
|
||||||
|
|
||||||
if not kwargs.get("bucket_name"):
|
|
||||||
raise RuntimeError("bucket_name is required")
|
|
||||||
|
|||||||
@@ -1,23 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from storage_service.model.storage.signed_url_response import SignedUrlResponse
|
||||||
from storage_service.utils.enums.file_type import FileType
|
from storage_service.utils.enums.file_type import FileType
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class StorageService(ABC):
|
class StorageService(ABC):
|
||||||
def __init__(self, **kwargs):
|
@abstractmethod
|
||||||
|
def get_temp_upload_link(self, file_name, file_type: FileType) -> SignedUrlResponse:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_temp_upload_link(
|
def get_temp_read_link(self, file_name) -> SignedUrlResponse:
|
||||||
self, file_name, file_type: FileType
|
|
||||||
) -> dict[str, str | Any]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_temp_read_link(self, file_name) -> dict[str, str | None]:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -25,5 +20,5 @@ class StorageService(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process_file(self, file_name: str, file_type: FileType) -> None:
|
def process_file(self, file_name: str, file_type: FileType) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -8,31 +8,33 @@ from io import BytesIO
|
|||||||
|
|
||||||
|
|
||||||
class VirusTotalService(VirusCheckerService):
|
class VirusTotalService(VirusCheckerService):
|
||||||
def __init__(self, api_key: str):
|
virus_checker: Virustotal
|
||||||
self.api_key = api_key
|
|
||||||
|
def __init__(self, virus_checker: Virustotal):
|
||||||
|
self.virus_checker = virus_checker
|
||||||
|
|
||||||
def check_virus(self, file_data: BytesIO) -> bool:
|
def check_virus(self, file_data: BytesIO) -> bool:
|
||||||
|
file_id = self._upload_file(file_data)
|
||||||
|
file_attributes = self._get_analysis(file_id)
|
||||||
|
|
||||||
|
return self._is_valid_file(file_attributes)
|
||||||
|
|
||||||
|
def _upload_file(self, file_data: BytesIO) -> str:
|
||||||
files = {"file": ("image_file", file_data)}
|
files = {"file": ("image_file", file_data)}
|
||||||
|
|
||||||
with Virustotal(self.api_key) as vtotal:
|
resp = self.virus_checker.request("files", files=files, method="POST")
|
||||||
resp = vtotal.request("files", files=files, method="POST")
|
|
||||||
|
|
||||||
file_attributes = self._get_analysis(resp.json()["data"]["id"])
|
return resp.data["id"]
|
||||||
|
|
||||||
return self._is_valid_file(file_attributes["data"]["attributes"]["stats"])
|
|
||||||
|
|
||||||
def _get_analysis(self, file_id: str) -> dict:
|
def _get_analysis(self, file_id: str) -> dict:
|
||||||
with Virustotal(self.api_key) as vtotal:
|
resp = self.virus_checker.request(f"analyses/{file_id}")
|
||||||
resp = vtotal.request(f"analyses/{file_id}")
|
|
||||||
|
|
||||||
return resp.json()
|
return resp.json()["data"]["attributes"]["stats"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_valid_file(file_stats: dict) -> bool:
|
def _is_valid_file(file_stats: dict) -> bool:
|
||||||
if "malicious" in file_stats and file_stats["malicious"] > 0:
|
match file_stats:
|
||||||
return False
|
case {"malicious": 0, "suspicious": 0, "harmless": 0}:
|
||||||
|
return True
|
||||||
if "suspicious" in file_stats and file_stats["suspicious"] > 0:
|
case _:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
|
||||||
|
|||||||
@@ -1,6 +1,26 @@
|
|||||||
|
from storage_service.utils.file.validators import image_validator
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
class FileType(Enum):
|
class FileType(Enum):
|
||||||
PNG = "png"
|
PNG = "png"
|
||||||
JPEG = "jpeg"
|
JPEG = "jpeg"
|
||||||
|
|
||||||
|
def get_content_type(self) -> str:
|
||||||
|
match self:
|
||||||
|
case FileType.PNG:
|
||||||
|
return "image/png"
|
||||||
|
case FileType.JPEG:
|
||||||
|
return "image/jpeg"
|
||||||
|
case _:
|
||||||
|
raise ValueError("File Type Not Implemented")
|
||||||
|
|
||||||
|
def get_validator(self) -> Callable[[BytesIO], BytesIO]:
|
||||||
|
match self:
|
||||||
|
case FileType.PNG | FileType.JPEG:
|
||||||
|
return image_validator
|
||||||
|
case _:
|
||||||
|
raise ValueError("File Type Not Implemented")
|
||||||
|
|||||||
2
storage_service/utils/exception_handler/__init__.py
Normal file
2
storage_service/utils/exception_handler/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .http_exception_handler import http_exception_handler
|
||||||
|
from .validation_exception_handler import validation_exception_handler
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from starlette.exceptions import HTTPException
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": exc.detail,
|
||||||
|
"status_code": exc.status_code,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from starlette import status
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"details": {
|
||||||
|
"body": exc.body,
|
||||||
|
"errors": exc.errors(),
|
||||||
|
},
|
||||||
|
"status_code": status_code,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
0
storage_service/utils/exceptions/__init__.py
Normal file
0
storage_service/utils/exceptions/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
|
class FileNotFoundException(HTTPException):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(status.HTTP_404_NOT_FOUND, detail=message)
|
||||||
0
storage_service/utils/file/__init__.py
Normal file
0
storage_service/utils/file/__init__.py
Normal file
9
storage_service/utils/file/file_hash_generator.py
Normal file
9
storage_service/utils/file/file_hash_generator.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import base64
|
||||||
|
from hashlib import md5
|
||||||
|
|
||||||
|
|
||||||
|
def generate_file_hash(file_key: str, file_postfix: str) -> str:
|
||||||
|
hashed_file_key = md5(file_key.encode("utf-8")).digest()
|
||||||
|
hashed_file_key = base64.b64encode(hashed_file_key).decode()
|
||||||
|
|
||||||
|
return f"{hashed_file_key}_{file_postfix}"
|
||||||
1
storage_service/utils/file/validators/__init__.py
Normal file
1
storage_service/utils/file/validators/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .image_handler import image_validator
|
||||||
@@ -3,10 +3,10 @@ from PIL import Image
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
|
|
||||||
def image_handler(file_bytes: io.BytesIO) -> io.BytesIO:
|
def image_validator(file_bytes: io.BytesIO) -> io.BytesIO:
|
||||||
img = Image.open(file_bytes)
|
img = Image.open(file_bytes)
|
||||||
|
|
||||||
img.thumbnail((320, 320))
|
img.thumbnail((180, 180))
|
||||||
|
|
||||||
data = list(img.getdata())
|
data = list(img.getdata())
|
||||||
image_without_exif = Image.new(img.mode, img.size)
|
image_without_exif = Image.new(img.mode, img.size)
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from storage_service.utils.enums.file_type import FileType
|
|
||||||
from storage_service.utils.file_handler.handlers.image_handler import (
|
|
||||||
image_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
FILE_HANDLER = {
|
|
||||||
FileType.PNG: {"content_type": "image/png", "handler": image_handler},
|
|
||||||
FileType.JPEG: {"content_type": "image/jpeg", "handler": image_handler},
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
import base64
|
|
||||||
from hashlib import md5
|
|
||||||
|
|
||||||
|
|
||||||
def file_name_hash(username: str, file_postfix: str) -> str:
|
|
||||||
hashed_username = md5(username.encode("utf-8")).digest()
|
|
||||||
hashed_username = base64.b64encode(hashed_username).decode()
|
|
||||||
|
|
||||||
return f"{hashed_username}_{file_postfix}"
|
|
||||||
@@ -1,9 +1,29 @@
|
|||||||
from storage_service.depends.depend_s3_service import (
|
from storage_service.depends.depend_s3_service import (
|
||||||
dependency_storage_service,
|
dependency_storage_service,
|
||||||
)
|
)
|
||||||
from storage_service.utils.enums.file_type import FileType
|
from storage_service.utils.file.file_hash_generator import generate_file_hash
|
||||||
from storage_service.utils.file_name_hash import file_name_hash
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def storage_file_worker(username: str, file_postfix: str) -> None:
|
def storage_file_worker(username: str, file_postfix: str) -> None:
|
||||||
dependency_storage_service().process_file(file_name_hash(username, file_postfix))
|
storage_service = dependency_storage_service()
|
||||||
|
|
||||||
|
file_name = generate_file_hash(username, file_postfix)
|
||||||
|
try:
|
||||||
|
stats = storage_service.process_file(file_name)
|
||||||
|
|
||||||
|
previous_size_kb = stats["previous_size"] / 1_000
|
||||||
|
current_size_kb = stats["current_size"] / 1_000
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"File processed: {file_name} - "
|
||||||
|
f"Previous Size: {previous_size_kb}kb - "
|
||||||
|
f"New Size: {current_size_kb}kb"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing file: {e}." f" Deleting file: {file_name}.")
|
||||||
|
|
||||||
|
storage_service.delete_file(file_name)
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/storage_service/__init__.py
Normal file
0
tests/storage_service/__init__.py
Normal file
140
tests/storage_service/test_amazon_s3_service.py
Normal file
140
tests/storage_service/test_amazon_s3_service.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from storage_service.service.storage import AmazonS3Service
|
||||||
|
from storage_service.utils.enums.file_type import FileType
|
||||||
|
|
||||||
|
|
||||||
|
class TestAmazonS3Service(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.s3_client_mock = Mock()
|
||||||
|
self.virus_checker_service_mock = Mock()
|
||||||
|
|
||||||
|
def test_get_temp_upload_link(self):
|
||||||
|
self.s3_client_mock.generate_presigned_url.return_value = "https://test.com"
|
||||||
|
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
response = storage_service.get_temp_upload_link("test_file", FileType.JPEG)
|
||||||
|
|
||||||
|
self.assertEqual(response.signed_url, "https://test.com")
|
||||||
|
self.assertEqual(response.expires_in, 3600)
|
||||||
|
|
||||||
|
self.s3_client_mock.generate_presigned_url.assert_called_once_with(
|
||||||
|
"put_object",
|
||||||
|
Params={
|
||||||
|
"Bucket": "test_bucket",
|
||||||
|
"Key": "test_file",
|
||||||
|
"ContentType": "image/jpeg",
|
||||||
|
},
|
||||||
|
ExpiresIn=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_temp_read_link(self):
|
||||||
|
self.s3_client_mock.generate_presigned_url.return_value = "https://test.com"
|
||||||
|
self.s3_client_mock.list_objects.return_value = {
|
||||||
|
"Contents": [
|
||||||
|
{
|
||||||
|
"Key": "test_file"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
response = storage_service.get_temp_read_link("test_file")
|
||||||
|
|
||||||
|
self.assertEqual(response.signed_url, "https://test.com")
|
||||||
|
self.assertEqual(response.expires_in, 3600)
|
||||||
|
|
||||||
|
self.s3_client_mock.generate_presigned_url.assert_called_once_with(
|
||||||
|
"get_object",
|
||||||
|
Params={
|
||||||
|
"Bucket": "test_bucket",
|
||||||
|
"Key": "test_file"
|
||||||
|
},
|
||||||
|
ExpiresIn=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_delete_file(self):
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_service.delete_file("test_file")
|
||||||
|
|
||||||
|
self.s3_client_mock.delete_object.assert_called_once_with(
|
||||||
|
Bucket="test_bucket",
|
||||||
|
Key="test_file"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_process_file_if_file_invalid(self):
|
||||||
|
mock_body = Mock()
|
||||||
|
mock_body.read.return_value = b"test_file"
|
||||||
|
self.s3_client_mock.get_object.return_value = {
|
||||||
|
"Body": mock_body
|
||||||
|
}
|
||||||
|
self.virus_checker_service_mock.check_virus.return_value = True
|
||||||
|
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
storage_service.process_file("test_file", FileType.JPEG)
|
||||||
|
|
||||||
|
def test_process_file_if_file_is_virus(self):
|
||||||
|
mock_body = Mock()
|
||||||
|
mock_body.read.return_value = b"test_file"
|
||||||
|
self.s3_client_mock.get_object.return_value = {
|
||||||
|
"Body": mock_body
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_file_type = Mock()
|
||||||
|
mock_file_type.get_validator.return_value = lambda x: x
|
||||||
|
mock_file_type.get_content_type.return_value = "image/fake"
|
||||||
|
|
||||||
|
self.virus_checker_service_mock.check_virus.return_value = False
|
||||||
|
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
storage_service.process_file("test_file", mock_file_type)
|
||||||
|
|
||||||
|
def test_process_file(self):
|
||||||
|
mock_body = Mock()
|
||||||
|
mock_body.read.return_value = b"test_file"
|
||||||
|
self.s3_client_mock.get_object.return_value = {
|
||||||
|
"Body": mock_body
|
||||||
|
}
|
||||||
|
self.virus_checker_service_mock.check_virus.return_value = True
|
||||||
|
|
||||||
|
mock_file_type = Mock()
|
||||||
|
mock_file_type.get_validator.return_value = lambda x: x
|
||||||
|
mock_file_type.get_content_type.return_value = "image/fake"
|
||||||
|
|
||||||
|
storage_service = AmazonS3Service(
|
||||||
|
s3_client=self.s3_client_mock,
|
||||||
|
bucket_name="test_bucket",
|
||||||
|
virus_checker_service=self.virus_checker_service_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_service.process_file("test_file", mock_file_type)
|
||||||
|
|
||||||
|
self.s3_client_mock.upload_fileobj.assert_called()
|
||||||
0
tests/virus_checker_service/__init__.py
Normal file
0
tests/virus_checker_service/__init__.py
Normal file
33
tests/virus_checker_service/test_virus_total_service.py
Normal file
33
tests/virus_checker_service/test_virus_total_service.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from storage_service.service.virus_checker.virus_total_service import VirusTotalService
|
||||||
|
|
||||||
|
|
||||||
|
class TestVirusTotalService(TestCase):
|
||||||
|
def test_check_virus_invalid(self):
|
||||||
|
mock_virus_checker = Mock()
|
||||||
|
mock_virus_checker.request.side_effect = [
|
||||||
|
Mock(data={"id": "file_id"}),
|
||||||
|
Mock(json=Mock(return_value={"data": {"attributes": {"stats": {"malicious": 1, "suspicious": 1, "harmless": 1}}}})),
|
||||||
|
]
|
||||||
|
|
||||||
|
virus_total_service = VirusTotalService(mock_virus_checker)
|
||||||
|
|
||||||
|
result = virus_total_service.check_virus(BytesIO(b"file_data"))
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_check_virus_valid(self):
|
||||||
|
mock_virus_checker = Mock()
|
||||||
|
mock_virus_checker.request.side_effect = [
|
||||||
|
Mock(data={"id": "file_id"}),
|
||||||
|
Mock(json=Mock(return_value={"data": {"attributes": {"stats": {"malicious": 0, "suspicious": 0, "harmless": 0}}}})),
|
||||||
|
]
|
||||||
|
|
||||||
|
virus_total_service = VirusTotalService(mock_virus_checker)
|
||||||
|
|
||||||
|
result = virus_total_service.check_virus(BytesIO(b"file_data"))
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
Reference in New Issue
Block a user