from asyncio import Task, create_task
from pydoc import resolve
from typing import Any, AsyncIterator, Iterator, Literal, NamedTuple, TypeVar, Generic
from datetime import datetime, date
from ..client import AsyncSalesforceClient, SalesforceClient
from .fields import ListField, object_fields, query_fields
from .sobject import SObject, SObjectList
from ..formatting import quote_soql_value
from .._models import QueryResultJSON, SObjectRecordJSON
# Logical connectives a BooleanOperation may use.
BooleanOperator = Literal["AND", "OR", "NOT"]

# SOQL comparison operators a Comparison may use.
Comparator = Literal[
    "=",
    "!=",
    "<>",
    ">",
    ">=",
    "<",
    "<=",
    "LIKE",
    "INCLUDES",
    "IN",
    "NOT IN",
]

# Aggregate function names recognised as a ``FUNC__field`` prefix in keyword
# filters (see SoqlQuery.build_conditional).
AGGREGATE_FUNCTIONS = "AVG COUNT COUNT_DISTINCT MIN MAX SUM".split()
[docs]
class Comparison:
    """
    A single SOQL comparison such as ``Name = 'Acme'`` or ``Id IN (...)``.

    Attributes:
        prop: Field (or aggregate expression) on the left-hand side.
        comparator: The SOQL comparison operator.
        value: Right-hand side. A ``SoqlQuery`` is rendered in parentheses as a
            subquery. For ``IN`` / ``NOT IN`` a plain string is treated as an
            already-formatted value list and wrapped in parentheses verbatim.
            Anything else is rendered with ``quote_soql_value``.
    """

    prop: str
    comparator: Comparator
    value: "SoqlQuery[Any] | str | bool | datetime | date | None"

    def __init__(
        self,
        prop: str,
        cmp: Comparator,
        value: "SoqlQuery[Any] | str | bool | datetime | date | None",
    ):
        """
        Args:
            prop: Left-hand field or expression.
            cmp: Comparison operator.
            value: Right-hand value (see class docstring for rendering rules).
        """
        self.prop = prop
        self.comparator = cmp
        self.value = value

    def __str__(self):
        if isinstance(self.value, SoqlQuery):
            return f"{self.prop} {self.comparator} ({str(self.value)})"
        # NOT IN takes the same parenthesised value list as IN. A raw string
        # list must therefore be passed through unquoted for both operators;
        # quoting it would yield e.g. `Id NOT IN '...'`, which is invalid SOQL.
        if self.comparator in ("IN", "NOT IN") and isinstance(self.value, str):
            return f"{self.prop} {self.comparator} ({self.value})"
        return f"{self.prop} {self.comparator} {quote_soql_value(self.value)}"
[docs]
def EQ(prop: str, value):
    """Build the comparison ``prop = value``."""
    return Comparison(prop, cmp="=", value=value)


def NE(prop: str, value):
    """Build the comparison ``prop != value``."""
    return Comparison(prop, cmp="!=", value=value)


def GT(prop: str, value):
    """Build the comparison ``prop > value``."""
    return Comparison(prop, cmp=">", value=value)


def GE(prop: str, value):
    """Build the comparison ``prop >= value``."""
    return Comparison(prop, cmp=">=", value=value)


def LT(prop: str, value):
    """Build the comparison ``prop < value``."""
    return Comparison(prop, cmp="<", value=value)


def LE(prop: str, value):
    """Build the comparison ``prop <= value``."""
    return Comparison(prop, cmp="<=", value=value)


def LIKE(prop: str, value):
    """Build the pattern match ``prop LIKE value``."""
    return Comparison(prop, cmp="LIKE", value=value)


def INCLUDES(prop: str, value):
    """Build the multi-select picklist test ``prop INCLUDES value``."""
    return Comparison(prop, cmp="INCLUDES", value=value)


def IN(prop: str, value):
    """Build the membership test ``prop IN value``."""
    return Comparison(prop, cmp="IN", value=value)


def NOT_IN(prop: str, value):
    """Build the negated membership test ``prop NOT IN value``."""
    return Comparison(prop, cmp="NOT IN", value=value)
[docs]
class BooleanOperation:
    """
    Conditions joined by one boolean operator, e.g. ``A AND (B OR C)``.

    When rendered, plain ``Comparison`` conditions appear as-is. Anything else
    (nested operations, raw strings) is wrapped in parentheses so that its
    precedence is kept.
    """

    operator: BooleanOperator
    conditions: list["Comparison | BooleanOperation | str"]

    def __init__(
        self,
        operator: BooleanOperator,
        conditions: list["Comparison | BooleanOperation | str"],
    ):
        """
        Args:
            operator: Connective placed between the conditions.
            conditions: Operands, rendered in the given order.
        """
        self.operator = operator
        self.conditions = conditions

    def __str__(self):
        rendered: list[str] = []
        for condition in self.conditions:
            text = str(condition)
            if not isinstance(condition, Comparison):
                text = f"({text})"
            rendered.append(text)
        separator = f" {self.operator} "
        return separator.join(rendered)
[docs]
def OR(*conditions: "Comparison | BooleanOperation | str"):
    """Join ``conditions`` with ``OR``; at least one must hold."""
    return BooleanOperation(operator="OR", conditions=[*conditions])


def AND(*conditions: "Comparison | BooleanOperation | str"):
    """Join ``conditions`` with ``AND``; all of them must hold."""
    return BooleanOperation(operator="AND", conditions=[*conditions])
[docs]
class NOT(BooleanOperation):
    """Negation of a single condition, rendered as ``NOT (<condition>)``."""

    def __init__(self, condition: "Comparison | BooleanOperation | str"):
        """
        Args:
            condition: The condition to negate. It is stored as the only
                entry of ``conditions``.
        """
        super().__init__(operator="NOT", conditions=[condition])

    def __str__(self):
        negated = self.conditions[0]
        return "NOT (" + str(negated) + ")"
[docs]
class Order(NamedTuple):
    """One ``ORDER BY`` term: a field name and its sort direction."""

    field: str
    direction: Literal["ASC", "DESC"]

    def __str__(self):
        # Unpacks the two tuple fields in declaration order: field, direction.
        return "{} {}".format(*self)
# Type variables for the generic classes in this module:
# ``_SObjectJSON`` ranges over ``dict[str, Any]`` subtypes, while
# ``_SObject`` ranges over SObject subclasses and parametrises the query/result types.
_SObjectJSON = TypeVar("_SObjectJSON", bound=dict[str, Any])
_SObject = TypeVar("_SObject", bound=SObject)
[docs]
def resolve_client(
    sobject_type: type[_SObject], connection: str | SalesforceClient | None
):
    """
    Pick the synchronous client to use for ``sobject_type``.

    Args:
        sobject_type: SObject class whose ``attributes.connection`` is the
            fallback connection name.
        connection: A client instance (returned unchanged), a connection name
            (looked up with ``SalesforceClient.get_connection``), or ``None``.

    Returns:
        The resolved ``SalesforceClient``.
    """
    if isinstance(connection, SalesforceClient):
        return connection
    connection_name = (
        connection
        if isinstance(connection, str)
        else sobject_type.attributes.connection
    )
    return SalesforceClient.get_connection(connection_name)
[docs]
def resolve_async_client(
    sobject_type: type[_SObject], connection: str | AsyncSalesforceClient | None
):
    """
    Pick the asynchronous client to use for ``sobject_type``.

    Args:
        sobject_type: SObject class whose ``attributes.connection`` is the
            fallback connection name.
        connection: A client instance (returned unchanged), a connection name
            (looked up with ``AsyncSalesforceClient.get_connection``), or ``None``.

    Returns:
        The resolved ``AsyncSalesforceClient``.
    """
    if isinstance(connection, AsyncSalesforceClient):
        return connection
    connection_name = (
        connection
        if isinstance(connection, str)
        else sobject_type.attributes.connection
    )
    return AsyncSalesforceClient.get_connection(connection_name)
[docs]
class QueryResultBatch(Generic[_SObject]):
    """
    One batch ("page") of records returned by the Salesforce SOQL Query API.

    Attributes:
        done (bool): True when no further batches exist.
        totalSize (int): Number of records matching the whole query.
        records (list[_SObject]): Records contained in this batch.
        nextRecordsUrl (str, optional): URL of the following batch, if any.
        query_locator (str, optional): Locator id parsed from ``nextRecordsUrl``.
        batch_size (int, optional): Numeric suffix parsed from
            ``nextRecordsUrl``, i.e. the offset of the first record of the
            next batch (equal to the batch size for the first batch).
    """

    done: bool
    "Indicates whether all records have been retrieved (True) or if more batches exist (False)"
    totalSize: int
    "The total number of records that match the query criteria"
    records: list[_SObject]
    "The list of records returned by the query"
    nextRecordsUrl: str | None
    "URL to the next batch of records, if more exist"
    _connection: str | None
    _sobject_type: type[_SObject]
    "The SObject type this QueryResult contains records for"
    query_locator: str | None = None
    batch_size: int | None = None

    def __init__(
        self,
        sobject_type: type[_SObject],
        /,
        done: bool = True,
        totalSize: int = 0,
        records: list[SObjectRecordJSON] | None = None,
        nextRecordsUrl: str | None = None,
        connection_name: str | None = None,
        **_extra: Any,
    ):
        """
        Build a batch from the JSON body of a Salesforce query response.

        Args:
            sobject_type: SObject class used to instantiate each record.
            done: ``done`` flag of the response.
            totalSize: ``totalSize`` of the response.
            records: Raw record dicts of the response.
            nextRecordsUrl: ``nextRecordsUrl`` of the response, if any.
            connection_name: Name of the connection used for follow-up
                ``query_more`` requests.
            **_extra: Any other keys of the response body. They are ignored so
                that responses carrying extra metadata (the Tooling API, for
                example, adds keys such as ``size`` and ``entityTypeName``) can
                be splatted in with ``**`` without raising ``TypeError``.
        """
        self._connection = connection_name
        self._sobject_type = sobject_type
        self.done = done
        self.totalSize = totalSize
        self.records = SObjectList(
            [sobject_type(**record) for record in records]  # type: ignore
            if records
            else []
        )
        self.nextRecordsUrl = nextRecordsUrl
        if self.nextRecordsUrl:
            # nextRecordsUrl looks like this:
            # /services/data/v63.0/query/01gRO0000016PIAYA2-500
            # i.e. "<locator>-<offset of the next record>".
            self.query_locator, batch_size = self.nextRecordsUrl.rsplit(
                "/", maxsplit=1
            )[1].rsplit("-", maxsplit=1)
            self.batch_size = int(batch_size)

    def query_more(self) -> "QueryResultBatch[_SObject]":
        """
        Fetch the batch that follows this one.

        Raises:
            ValueError: If this batch has no ``nextRecordsUrl``.
        """
        if not self.nextRecordsUrl:
            raise ValueError("Cannot get more records without nextRecordsUrl")
        result: QueryResultJSON = (
            SalesforceClient.get_connection(self._connection)
            .get(self.nextRecordsUrl)
            .json()
        )
        return QueryResultBatch(
            self._sobject_type,
            connection_name=self._connection,
            **result,  # type: ignore
        )

    async def query_more_async(self) -> "QueryResultBatch[_SObject]":
        """
        Asynchronously fetch the batch that follows this one.

        Raises:
            ValueError: If this batch has no ``nextRecordsUrl``.
        """
        if not self.nextRecordsUrl:
            raise ValueError("Cannot get more records without nextRecordsUrl")
        result: QueryResultJSON = (
            await AsyncSalesforceClient.get_connection(self._connection).get(
                self.nextRecordsUrl
            )
        ).json()
        return QueryResultBatch(
            self._sobject_type,
            connection_name=self._connection,
            **result,  # type: ignore
        )
[docs]
class QueryResult(Generic[_SObject]):
    """
    Iterable over all records of a SOQL query, spanning every result batch.

    ``batches`` starts with the first batch returned by the query and grows as
    later batches are fetched. Iteration (sync or async) runs on a shallow
    copy. The ``batches`` list, and every batch fetched into it, is therefore
    shared with the object that was iterated, and re-iterating reuses them.
    """

    batches: list[QueryResultBatch[_SObject]]
    total_size: int
    batch_index: int = 0
    record_index: int = 0
    _async_tasks: list[Task[QueryResultBatch[_SObject]]] | None

    def __init__(
        self,
        batches: list[QueryResultBatch[_SObject]],
        _async_tasks: list[Task[QueryResultBatch[_SObject]]] | None = None,
    ):
        """
        Args:
            batches: Batches fetched so far. ``batches[0]`` must exist; it
                supplies ``total_size``.
            _async_tasks: Pending fetches of later batches (internal use).
        """
        self.batches = batches
        self.total_size = batches[0].totalSize
        self._async_tasks = _async_tasks

    def copy(self) -> "QueryResult[_SObject]":
        """Perform a shallow copy of the QueryResult object.

        The copy shares ``batches`` and the pending-task list, but its
        iteration position starts at the first record.
        """
        return QueryResult(self.batches, self._async_tasks)

    def __iter__(self) -> Iterator[_SObject]:
        return self.copy()

    def __aiter__(self) -> AsyncIterator[_SObject]:
        # Start fetching the remaining batches concurrently before handing out
        # the iterator. create_task needs the running loop that `async for`
        # provides.
        if not self.done:
            self.schedule_async_tasks()
        return self.copy()

    def __len__(self):
        """Total number of records matching the query, fetched or not."""
        return self.total_size

    @property
    def done(self):
        """Whether the batch at ``batch_index`` is the query's last batch."""
        return self.batches[self.batch_index].done

    def as_list(self) -> SObjectList[_SObject]:
        """Return the records as an ``SObjectList``."""
        # NOTE(review): this passes the SObject type's default connection rather
        # than the connection the query ran on (batches[0]._connection);
        # confirm this is intended.
        return SObjectList(
            self, connection=self.batches[0]._sobject_type.attributes.connection
        )

    async def as_list_async(self) -> SObjectList[_SObject]:
        """Asynchronously build an ``SObjectList`` from the records."""
        return await SObjectList.async_init(
            self, self.batches[0]._sobject_type.attributes.connection
        )

    async def _fetch_query_locator_batch(self, query_locator_url: str):
        """Fetch one batch by its query-locator URL on the query's connection."""
        connection = AsyncSalesforceClient.get_connection(self.batches[0]._connection)
        result: QueryResultJSON = (await connection.get(query_locator_url)).json()
        return QueryResultBatch(
            self.batches[0]._sobject_type,
            connection_name=self.batches[0]._connection,
            **result,  # type: ignore
        )

    def schedule_async_tasks(self):
        """
        Schedule concurrent fetches for every batch not yet in ``batches``.

        Offsets continue from the ``nextRecordsUrl`` of the last batch already
        held and step by the size of the first batch. Batches already held,
        including the first one, are therefore not requested again.

        Raises:
            AssertionError: If the first batch has no ``nextRecordsUrl``.
        """
        assert self.batches[0].nextRecordsUrl is not None, (
            "Cannot iterate with no query locator"
        )
        last_url = self.batches[-1].nextRecordsUrl
        if not last_url:
            # Every batch has already been fetched into the shared list.
            self._async_tasks = None
            return
        url_root, _ = self.batches[0].nextRecordsUrl.rsplit("-", maxsplit=1)
        # The URL suffix is the offset of the first record not yet held.
        next_offset = int(last_url.rsplit("-", maxsplit=1)[1])
        # NOTE(review): this assumes every batch has the first batch's size.
        # A shorter intermediate batch would leave a gap between offsets.
        batch_size = len(self.batches[0].records)
        self._async_tasks = [
            create_task(self._fetch_query_locator_batch(f"{url_root}-{offset}"))
            for offset in range(next_offset, len(self), batch_size)
        ]

    def __next__(self) -> _SObject:
        while True:
            batch = self.batches[self.batch_index]
            if self.record_index < len(batch.records):
                record = batch.records[self.record_index]
                self.record_index += 1
                return record
            if batch.done:
                raise StopIteration
            # Reuse a batch already present in the shared list. Only call the
            # API when positioned on the last batch held.
            if self.batch_index >= len(self.batches) - 1:
                self.batches.append(batch.query_more())
            self.batch_index += 1
            self.record_index = 0

    async def __anext__(self) -> _SObject:
        while True:
            batch = self.batches[self.batch_index]
            if self.record_index < len(batch.records):
                record = batch.records[self.record_index]
                self.record_index += 1
                return record
            if batch.done:
                raise StopAsyncIteration
            if self.batch_index >= len(self.batches) - 1:
                # The next batch is not held yet. Take the next prefetched one
                # if any; otherwise request it directly.
                if self._async_tasks:
                    self.batches.append(await self._async_tasks.pop(0))
                    if not self._async_tasks:
                        self._async_tasks = None
                else:
                    self.batches.append(await batch.query_more_async())
            self.batch_index += 1
            self.record_index = 0
[docs]
class SoqlQuery(Generic[_SObject]):
    """
    Fluent builder for a SOQL query on ``sobject_type``.

    The filter, grouping, ordering and paging methods store their arguments on
    the instance and return ``self``, so calls can be chained. The
    ``execute*`` methods send the query through the REST or Bulk API.
    """

    sobject_type: type[_SObject]
    _object_relationship_name: str | None = None
    _where: Comparison | BooleanOperation | str | None = None
    _grouping: list[str] | None = None
    _having: Comparison | BooleanOperation | str | None = None
    _limit: int | None = None
    _offset: int | None = None
    _order: list[Order | str] | None = None
    _subqueries: dict[str, "SoqlQuery[Any]"]
    _include_deleted: bool

    # Keyword-argument suffixes understood by build_conditional, paired with
    # the operator each one selects. They are checked in order and the first
    # match wins. No suffix in this list ends with another one.
    _SUFFIX_OPERATORS = (
        ("__not_in", "NOT IN"),
        ("__ne", "!="),
        ("__gt", ">"),
        ("__lt", "<"),
        ("__ge", ">="),
        ("__le", "<="),
        ("__in", "IN"),
        ("__like", "LIKE"),
        ("__includes", "INCLUDES"),
    )

    def __init__(self, sobject_type: type[_SObject], include_deleted: bool = False):
        """
        Args:
            sobject_type: SObject class to query.
            include_deleted: Use the ``queryAll`` endpoint so that deleted
                records are included.
        """
        self.sobject_type = sobject_type
        self._subqueries = {}
        self._include_deleted = include_deleted

    @property
    def fields(self):
        """
        Field list for the SELECT clause.

        Each field yielded by ``query_fields(sobject_type)`` is included. A
        ``ListField`` becomes a parenthesised subquery: the one registered
        through ``filter_subqueries`` if present, otherwise an unfiltered
        ``select`` of its nested type.
        """
        fields: list[str] = []
        obj_fields = object_fields(self.sobject_type)
        for field in query_fields(self.sobject_type):
            if isinstance(field_def := obj_fields.get(field), ListField):
                subquery = self._subqueries.get(field)
                if not subquery:
                    subquery = select(field_def._nested_type)
                    subquery._object_relationship_name = field
                fields.append(f"({str(subquery)})")
            else:
                fields.append(field)
        return fields

    def filter_subqueries(self, **subqueries: "SoqlQuery[Any]"):
        """
        Configure Parent-To-Child Relationship queries.

        By default, all records are returned in the subquery (no filtering).
        https://developer.salesforce.com/docs/atlas.en-us.soql_sosl.meta/soql_sosl/sforce_api_calls_soql_relationships_query_using.htm

        Args:
            **subqueries: Mapping of ListField names to SoqlQuery objects.

        Returns:
            self: The current SoqlQuery object.

        Raises:
            AssertionError: If a name is not a ListField of ``sobject_type``.
        """
        for field, subquery in subqueries.items():
            assert isinstance(object_fields(self.sobject_type).get(field), ListField), (
                f"Field '{field}' is not a ListField"
            )
            subquery._object_relationship_name = field
            self._subqueries[field] = subquery
        return self

    @property
    def sobject_name(self) -> str:
        """API name of the queried SObject type."""
        return self.sobject_type.attributes.type

    @classmethod
    def build_conditional(cls, arg: str, value) -> Comparison | NOT:
        """
        Turn one keyword filter into a comparison.

        ``arg`` is a field name with these optional decorations:

        * ``NOT__`` prefix: the comparison is wrapped in ``NOT``.
        * Operator suffix: ``__ne``, ``__gt``, ``__lt``, ``__ge``, ``__le``,
          ``__in``, ``__not_in``, ``__like`` or ``__includes``. The default
          operator is ``=``.
        * Aggregate prefix (e.g. ``COUNT__Id``): rendered as ``COUNT(Id)``.
        """
        op = "="
        negated = arg.startswith("NOT__")
        if negated:
            arg = arg.removeprefix("NOT__")
        for suffix, suffix_op in cls._SUFFIX_OPERATORS:
            if arg.endswith(suffix):
                arg = arg.removesuffix(suffix)
                op = suffix_op
                break
        if any(arg.startswith(f"{func}__") for func in AGGREGATE_FUNCTIONS):
            func, arg = arg.split("__", maxsplit=1)
            arg = f"{func}({arg})"
        comparison = Comparison(arg, op, value)
        return NOT(comparison) if negated else comparison

    @classmethod
    def build_conditional_clause(
        cls,
        kwargs: dict[str, Any],
        mode: Literal["any", "all"] = "all",
    ) -> Comparison | BooleanOperation:
        """
        Combine keyword filters into one condition.

        Args:
            kwargs: Keyword filters (see ``build_conditional``); must be non-empty.
            mode: ``"all"`` joins the filters with AND, ``"any"`` with OR.

        Raises:
            AssertionError: If ``kwargs`` is empty.
            ValueError: If ``mode`` is invalid and there are several filters.
        """
        assert len(kwargs) > 0
        if len(kwargs) == 1:
            arg, value = next(iter(kwargs.items()))
            return cls.build_conditional(arg, value)
        conditions = (
            cls.build_conditional(arg, value) for arg, value in kwargs.items()
        )
        if mode == "any":
            return OR(*conditions)
        elif mode == "all":
            return AND(*conditions)
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def _resolve_condition(
        self,
        _raw: Comparison | BooleanOperation | str | None,
        _mode: Literal["any", "all"],
        kwargs: dict[str, Any],
    ) -> Comparison | BooleanOperation | str:
        """Use ``_raw`` when it is truthy; otherwise build a clause from ``kwargs``."""
        if _raw:
            return _raw
        return self.build_conditional_clause(kwargs, _mode)

    def where(
        self: "SoqlQuery[_SObject]",
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ) -> "SoqlQuery[_SObject]":
        """Set (replace) the WHERE condition from ``_raw`` or keyword filters."""
        self._where = self._resolve_condition(_raw, _mode, kwargs)
        return self

    def and_where(
        self,
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ):
        """AND an extra condition onto the existing WHERE condition."""
        assert self._where is not None, "where() must be called before and_where()"
        self._where = AND(self._where, self._resolve_condition(_raw, _mode, kwargs))
        return self

    def or_where(
        self,
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ):
        """OR an extra condition onto the existing WHERE condition."""
        assert self._where is not None, "where() must be called before or_where()"
        self._where = OR(self._where, self._resolve_condition(_raw, _mode, kwargs))
        return self

    def group_by(self, *fields: str):
        """Set the GROUP BY fields."""
        self._grouping = list(fields)
        return self

    def having(
        self,
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ):
        """Set (replace) the HAVING condition from ``_raw`` or keyword filters."""
        self._having = self._resolve_condition(_raw, _mode, kwargs)
        return self

    def and_having(
        self,
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ):
        """AND an extra condition onto the existing HAVING condition."""
        assert self._having is not None, "having() must be called before and_having()"
        self._having = AND(self._having, self._resolve_condition(_raw, _mode, kwargs))
        return self

    def or_having(
        self,
        _raw: Comparison | BooleanOperation | str | None = None,
        _mode: Literal["any", "all"] = "all",
        **kwargs: Any,
    ):
        """OR an extra condition onto the existing HAVING condition."""
        assert self._having is not None, "having() must be called before or_having()"
        self._having = OR(self._having, self._resolve_condition(_raw, _mode, kwargs))
        return self

    def limit(self, limit: int):
        """Set the LIMIT."""
        self._limit = limit
        return self

    def offset(self, offset: int):
        """Set the OFFSET."""
        self._offset = offset
        return self

    def order_by(self, *orders: Order, **kw_orders: Literal["ASC", "DESC"]):
        """
        Set (replace) the ORDER BY terms.

        Positional ``Order`` terms come first, followed by ``field=direction``
        keyword terms in the order given.
        """
        self._order = list(orders)
        self._order.extend(
            Order(field, direction) for field, direction in kw_orders.items()
        )
        return self

    def __str__(self):
        # NOTE(review): `format` is not defined in the visible part of this
        # class, although execute/execute_async also call it. Confirm that it
        # exists.
        return self.format()

    def count(self, connection: SalesforceClient | str | None = None) -> int:
        """
        Executes a count query instead of fetching records.

        The query is executed with ``COUNT()`` as its only field.

        Returns:
            int: Number of records matching the query criteria
        """
        count_result = self.execute("COUNT()", connection=connection)
        # Salesforce reports a COUNT() result in the response's totalSize
        # (no records are returned); len() of a QueryResult returns totalSize.
        return len(count_result)

    async def count_async(
        self, connection: AsyncSalesforceClient | str | None = None
    ) -> int:
        """
        Executes a count query instead of fetching records.

        The query is executed with ``COUNT()`` as its only field.

        Returns:
            int: Number of records matching the query criteria
        """
        count_result = await self.execute_async("COUNT()", connection=connection)
        # Salesforce reports a COUNT() result in the response's totalSize
        # (no records are returned); len() of a QueryResult returns totalSize.
        return len(count_result)

    def _query_url(self, data_url: str) -> str:
        """
        Choose the REST query endpoint under ``data_url``.

        The Tooling query endpoint is used for tooling SObjects, ``queryAll``
        when deleted records are requested, and ``query`` otherwise.

        Raises:
            AssertionError: If a tooling SObject is combined with ``include_deleted``.
        """
        assert not (self.sobject_type.attributes.tooling and self._include_deleted), (
            "Tooling API does not support query deleted records (QueryAll)"
        )
        if self.sobject_type.attributes.tooling:
            return f"{data_url}/tooling/query/"
        if self._include_deleted:
            return f"{data_url}/queryAll/"
        return f"{data_url}/query/"

    async def execute_async(
        self,
        *_fields: str,
        connection: AsyncSalesforceClient | str | None = None,
        **callout_options,
    ) -> QueryResult[_SObject]:
        """
        Execute the SOQL query asynchronously.

        Args:
            *_fields: Explicit SELECT fields. Defaults to ``self.fields``.
            connection: Client or connection name (see ``resolve_async_client``).
            **callout_options: Extra keyword arguments forwarded to ``client.get``.

        Returns:
            A QueryResult that holds the first batch. Later batches are
            fetched while it is iterated.
        """
        fields = list(_fields) if _fields else self.fields
        client = resolve_async_client(self.sobject_type, connection)
        url = self._query_url(client.data_url)
        response = await client.get(
            url, params={"q": self.format(fields)}, **callout_options
        )
        result: QueryResultJSON = response.json()
        batch = QueryResultBatch(
            self.sobject_type, connection_name=client.connection_name, **result  # type: ignore
        )
        return QueryResult([batch])

    def execute(
        self,
        *_fields: str,
        connection: SalesforceClient | str | None = None,
        **callout_options: Any,
    ) -> QueryResult[_SObject]:
        """
        Execute the SOQL query.

        Args:
            *_fields: Explicit SELECT fields. Defaults to ``self.fields``.
            connection: Client or connection name (see ``resolve_client``).
            **callout_options: Extra keyword arguments forwarded to ``client.get``.

        Returns:
            A QueryResult that holds the first batch. Later batches are
            fetched while it is iterated.
        """
        fields = list(_fields) if _fields else self.fields
        client = resolve_client(self.sobject_type, connection)
        url = self._query_url(client.data_url)
        result: QueryResultJSON = client.get(
            url, params={"q": self.format(fields)}, **callout_options
        ).json()
        batch = QueryResultBatch(
            self.sobject_type, connection_name=client.connection_name, **result  # type: ignore
        )
        return QueryResult([batch])

    def execute_bulk(
        self, connection_name: SalesforceClient | str | None = None
    ) -> "BulkQueryResult[_SObject]":
        """
        Run the query as a Bulk API query job, wait for it to complete and
        return the job's result.

        Args:
            connection_name: Connection name or, as with
                ``execute_bulk_async``, a client instance. Defaults to the
                SObject type's connection.
        """
        # Lazy import that is published as module globals; presumably this
        # avoids a circular import with `.bulk` (TODO confirm).
        global BulkApiQueryJob, BulkQueryResult
        try:
            _ = BulkApiQueryJob
        except NameError:
            from .bulk import BulkApiQueryJob, BulkQueryResult
        # `or None` keeps the previous fallback to the default connection for
        # an empty name.
        connection = resolve_client(self.sobject_type, connection_name or None)
        bulk_job: BulkApiQueryJob[_SObject] = BulkApiQueryJob.init_job(
            self,
            connection,
            operation="queryAll" if self._include_deleted else "query",
        )
        _ = bulk_job.monitor_until_complete()
        return bulk_job.result

    async def execute_bulk_async(
        self, connection: AsyncSalesforceClient | str | None = None
    ) -> "BulkQueryResult[_SObject]":
        """
        Asynchronously run the query as a Bulk API query job, wait for it to
        complete and return the job's result.
        """
        global BulkApiQueryJob, BulkQueryResult
        try:
            _ = BulkApiQueryJob
        except NameError:
            from .bulk import BulkApiQueryJob, BulkQueryResult
        connection = resolve_async_client(self.sobject_type, connection)
        bulk_job: BulkApiQueryJob[_SObject] = await BulkApiQueryJob.init_job_async(
            self,
            connection,
            operation="queryAll" if self._include_deleted else "query",
        )
        _ = await bulk_job.monitor_until_complete_async()
        return bulk_job.result

    def __iter__(self):
        """Execute the query and iterate over its records."""
        return self.execute()
[docs]
def select(
    sobject_type: type[_SObject], include_deleted: bool = False
) -> SoqlQuery[_SObject]:
    """
    Start a SOQL query on ``sobject_type``.

    Args:
        sobject_type: SObject class to query.
        include_deleted: Query deleted records as well (``queryAll``).

    Returns:
        A new, unfiltered ``SoqlQuery``.
    """
    query = SoqlQuery(sobject_type, include_deleted=include_deleted)
    return query