diff --git a/storage_service/controller/storage_controller.py b/storage_service/controller/storage_controller.py index 2c4793c..31579b4 100644 --- a/storage_service/controller/storage_controller.py +++ b/storage_service/controller/storage_controller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from storage_service.depends.depend_queue import dependency_queue from storage_service.depends.depend_s3_service import ( dependency_storage_service, @@ -36,7 +38,7 @@ class StorageController: ) @s3_router.get("/file_url/", status_code=200) - def file_url(self, username: str, file_postfix: str) -> dict[str, str]: + def file_url(self, username: str, file_postfix: str) -> dict[str, str | None]: return self.storage_service.get_temp_read_link( file_name_hash(username, file_postfix) ) diff --git a/storage_service/service/amazon_s3_service.py b/storage_service/service/amazon_s3_service.py index a23bbb2..9a4f8e3 100644 --- a/storage_service/service/amazon_s3_service.py +++ b/storage_service/service/amazon_s3_service.py @@ -36,7 +36,7 @@ class AmazonS3Service(StorageService): "file_key": self._get_object_url(file_name), } - def get_temp_read_link(self, file_name) -> dict[str, str | Any]: + def get_temp_read_link(self, file_name) -> dict[str, str | None]: return {"presigned_url": self._get_presigned_read_url(file_name)} def process_file(self, file_name: str, file_type: FileType = FileType.PNG) -> None: @@ -59,12 +59,16 @@ class AmazonS3Service(StorageService): ExpiresIn=self.expires_in, ) - def _get_presigned_read_url(self, file_name) -> str: - return self.s3.generate_presigned_url( - "get_object", - Params={"Bucket": self.bucket_name, "Key": file_name}, - ExpiresIn=self.expires_in, - ) + def _get_presigned_read_url(self, file_name) -> str | None: + result = self.s3.list_objects(Bucket=self.bucket_name, Prefix=file_name) + + if file_name in map(lambda x: x["Key"], result["Contents"]): + return self.s3.generate_presigned_url( + "get_object", + Params={"Bucket": self.bucket_name, "Key": file_name}, + ExpiresIn=self.expires_in, + ) + return None def _get_file_obj(self, file_name: str) -> io.BytesIO: return io.BytesIO( diff --git a/storage_service/service/storage_service.py b/storage_service/service/storage_service.py index e065dca..5967962 100644 --- a/storage_service/service/storage_service.py +++ b/storage_service/service/storage_service.py @@ -17,7 +17,7 @@ class StorageService(ABC): pass @abstractmethod - def get_temp_read_link(self, file_name) -> dict[str, str | Any]: + def get_temp_read_link(self, file_name) -> dict[str, str | None]: pass @abstractmethod