from abc import ABCMeta
from enum import Enum
from functools import cached_property
from types import TracebackType
from typing import Any, ClassVar, Protocol, TypeVar
from typing_extensions import override
from httpx import URL, AsyncClient, Client, Request, Response
from .logger import getLogger
from .metrics import ApiUsage, parse_api_usage
from .exceptions import raise_for_status
from .auth import (
SalesforceAuth,
SalesforceLogin,
SalesforceToken,
TokenRefreshCallback,
)
from .apimodels import ApiVersion, UserInfo, OrgLimits
LOGGER = getLogger("client")
_T = TypeVar("_T")
_SCB = TypeVar("_SCB", bound="SalesforceClientBase")
[docs]
class OrgType(Enum):
PRODUCTION = "Production"
SCRATCH = "Scratch"
SANDBOX = "Sandbox"
DEVELOPER = "Developer"
[docs]
class ClientBaseProto(Protocol):
_base_url: URL
def _enforce_trailing_slash(self, url: URL) -> URL: ...
[docs]
def build_request(self, method: str, url: str) -> Request: ...
[docs]
class SalesforceClientBase(ClientBaseProto, metaclass=ABCMeta):
token_refresh_callback: TokenRefreshCallback | None = None
api_version: ApiVersion | None = None
_versions: dict[float, ApiVersion] | None = None
_userinfo: UserInfo | None = None
api_usage: ApiUsage | None = None
connection_name: str
DEFAULT_CONNECTION_NAME: ClassVar[str] = "default"
[docs]
def register(
self: _SCB,
api_version: ApiVersion | int | float | str | None = None,
connection_name: str = DEFAULT_CONNECTION_NAME,
):
if api_version is not None:
self.api_version = ApiVersion.lazy_build(api_version)
self.connection_name = connection_name
type(self).register_connection(connection_name, self)
def __init_subclass__(cls) -> None:
super().__init_subclass__()
cls._connections: dict[str, "SalesforceClientBase"] = {}
[docs]
def handle_token_refresh(self, token: SalesforceToken):
self._derive_base_url(token)
if self.token_refresh_callback:
self.token_refresh_callback(token)
[docs]
def set_token_refresh_callback(self, callback: TokenRefreshCallback):
self.token_refresh_callback = callback
def _derive_base_url(self, session: SalesforceToken):
self._base_url = self._enforce_trailing_slash(session.instance)
@property
def org_type(self) -> OrgType:
if not self._base_url:
raise ValueError("Base URL is not set on the client.")
if ".scratch." in self._base_url.host.lower():
return OrgType.SCRATCH
elif ".sandbox." in self._base_url.host.lower():
return OrgType.SANDBOX
elif self._base_url.host.lower().split(".", 1)[0].endswith("-dev-ed"):
return OrgType.DEVELOPER
else:
return OrgType.PRODUCTION
@property
def data_url(self):
if not self.api_version:
assert hasattr(self, "_versions") and self._versions, ""
self.api_version = self._versions[max(self._versions)]
return self.api_version.url
def _userinfo_request(self):
return self.build_request("GET", "/services/oauth2/userinfo")
def _versions_request(self):
return self.build_request("GET", "/services/data")
@property
def sobjects_url(self):
return f"{self.data_url}/sobjects"
[docs]
def composite_sobjects_url(self, sobject: str | None = None):
url = f"{self.data_url}/composite/sobjects"
if sobject:
url += "/" + sobject
return url
@property
def tooling_url(self):
return f"{self.data_url}/tooling"
@property
def tooling_sobjects_url(self):
return f"{self.data_url}/tooling"
@property
def metadata_url(self):
return f"{self.data_url}/metadata"
[docs]
@classmethod
def get_connection(cls: type[_SCB], name: str | None = None) -> _SCB:
return cls._connections[name or cls.DEFAULT_CONNECTION_NAME] # pyright: ignore[reportReturnType]
[docs]
@classmethod
def register_connection(cls: type[_SCB], connection_name: str, instance: _SCB):
if connection_name in cls._connections:
raise KeyError(
f"SalesforceClient connection '{connection_name}' has already been registered."
)
cls._connections[connection_name] = instance
[docs]
@classmethod
def unregister_connection(cls: type[_SCB], name_or_instance: str | _SCB):
if isinstance(name_or_instance, str):
names_to_unregister = [name_or_instance]
else:
names_to_unregister = [
name
for name, instance in cls._connections.items()
if instance is name_or_instance
]
for name in names_to_unregister:
if name in cls._connections:
del cls._connections[name]
[docs]
class AsyncSalesforceClient(AsyncClient, SalesforceClientBase):
_auth: SalesforceAuth
token_refresh_callback: TokenRefreshCallback | None
[docs]
def __init__(
self,
login: SalesforceLogin | None = None,
token: SalesforceToken | None = None,
token_refresh_callback: TokenRefreshCallback | None = None,
api_version: ApiVersion | int | float | str | None = None,
connection_name: str = SalesforceClientBase.DEFAULT_CONNECTION_NAME,
):
assert login or token, (
"Either auth or session parameters are required.\n"
"Both are permitted simultaneously."
)
super().__init__(
auth=SalesforceAuth(login, token, self.handle_token_refresh),
headers={"Accept": "application/json"},
)
self.register(api_version, connection_name)
if token:
self._derive_base_url(token)
self.token_refresh_callback = token_refresh_callback
@override
async def __aenter__(self):
_ = await super().__aenter__()
try:
self._userinfo = UserInfo(
**(await self.send(self._userinfo_request())).json()
)
if self.api_version:
self.api_version = (await self.versions())[self.api_version.version]
else:
self.api_version = (await self.versions())[max(await self.versions())]
LOGGER.info(
"Logged into %s as %s (%s)",
self.base_url,
self._userinfo.name,
self._userinfo.preferred_username,
)
except Exception as e:
await self.__aexit__(type(e), e, e.__traceback__)
raise
return self
[docs]
@override
async def aclose(self):
self.unregister_connection(self.connection_name)
self.unregister_connection(self)
return await super().aclose()
@override
async def __aexit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
self.unregister_connection(self.connection_name)
self.unregister_connection(self)
return await super().__aexit__(exc_type, exc_value, traceback)
[docs]
@override
async def request(
self, method: str, url: URL | str, resource_name: str = "", **kwargs: Any
) -> Response:
response = await super().request(method, url, **kwargs)
raise_for_status(response, resource_name)
if sforce_limit_info := response.headers.get("Sforce-Limit-Info"):
self.api_usage = parse_api_usage(sforce_limit_info)
return response
[docs]
async def versions(self) -> dict[float, ApiVersion]:
"""
Returns a dictionary of API versions available in the org asynchronously.
https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/dome_versions.htm
Returns:
dict[float, ApiVersion]: Dictionary of available API versions
"""
response = await self.request("GET", "/services/data")
versions_data = response.json()
return {
float(version["version"]): ApiVersion(
float(version["version"]), version["label"], version["url"]
)
for version in versions_data
}
[docs]
class SalesforceClient(Client, SalesforceClientBase):
token_refresh_callback: TokenRefreshCallback | None
connection_name: str
_auth: SalesforceAuth
[docs]
def __init__(
self,
connection_name: str = SalesforceClientBase.DEFAULT_CONNECTION_NAME,
login: SalesforceLogin | None = None,
token: SalesforceToken | None = None,
token_refresh_callback: TokenRefreshCallback | None = None,
api_version: ApiVersion | int | float | str | None = None,
**kwargs: Any,
):
assert login or token, (
"Either auth or session parameters are required.\n"
"Both are permitted simultaneously."
)
auth = SalesforceAuth(login, token, self.handle_token_refresh)
super().__init__(auth=auth, **kwargs)
self.register(connection_name=connection_name, api_version=api_version)
if token:
self._derive_base_url(token)
self.token_refresh_callback = token_refresh_callback
self.connection_name = connection_name
@override
def __str__(self):
if not (isinstance(self.auth, SalesforceAuth) and self.auth.token is not None):
return f"{type(self).__name__} ({self.connection_name})"
return (
f"{type(self).__name__} ({self.connection_name}) -> "
f"{self.auth.token.instance.host} as {(_ui := self._userinfo) and _ui.preferred_username}"
)
[docs]
def handle_async_clone_token_refresh(self, token: SalesforceToken):
self._auth.token = token
@override
def __enter__(self):
_ = Client.__enter__(self)
try:
self._userinfo = UserInfo(**self.send(self._userinfo_request()).json())
if _av := getattr(self, "api_version", None):
self.api_version = self.versions[_av.version]
else:
self.api_version = self.versions[max(self.versions)]
LOGGER.info(
"Logged into %s as %s (%s)",
self.base_url,
self._userinfo.name,
self._userinfo.preferred_username,
)
except Exception as e:
self.__exit__(type(e), e, e.__traceback__)
raise
return self
@override
def __exit__(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
):
self.unregister_connection(self.connection_name)
self.unregister_connection(self)
return super().__exit__(exc_type, exc_value, traceback)
[docs]
@override
def close(self):
self.unregister_connection(self.connection_name)
self.unregister_connection(self)
return super().close()
[docs]
@override
def request(
self,
method: str,
url: URL | str,
resource_name: str = "",
response_status_raise: bool = True,
**kwargs: Any,
) -> Response:
response = super().request(method, url, **kwargs)
if response_status_raise:
raise_for_status(response, resource_name)
sforce_limit_info: str | None = response.headers.get("Sforce-Limit-Info")
if sforce_limit_info:
self.api_usage = parse_api_usage(sforce_limit_info)
return response
@cached_property
def versions(self) -> dict[float, ApiVersion]:
"""
Returns a dictionary of API versions available in the org.
Returns:
list[ApiVersion]: List of available API versions
"""
response = self.request("GET", "/services/data")
versions_data = response.json()
return {
(f_ver := float(version["version"])): ApiVersion(
f_ver, version["label"], version["url"]
)
for version in versions_data
}
[docs]
def limits(self):
"""
Returns a dictionary of API versions available in the org.
Returns:
OrgLimits: dict-like object of available limits
"""
return OrgLimits(**self.get(self.data_url + "/limits/").json())
# resources for the client
@property
def tooling(self) -> "ToolingResource":
try:
return self._tooling
except AttributeError:
if "Tooling" not in globals():
global ToolingResource
from .resources.tooling import ToolingResource
self._tooling = ToolingResource(self)
return self._tooling
@property
def metadata(self) -> "MetadataResource":
try:
return self._metadata
except AttributeError:
if "MetadataResource" not in globals():
global MetadataResource
from .resources.metadata import MetadataResource
self._metadata = MetadataResource(self)
return self._metadata
@tooling.deleter
def tooling(self):
try:
del self._tooling
except AttributeError:
pass