from typing import Any, AsyncIterator, Iterator, Literal, NamedTuple, TypeVar, Generic
from datetime import datetime, date
from ..client import SalesforceClient
from .fields import ListField
from .sobject import SObject, SObjectList
from ..formatting import quote_soql_value
from .._models import QueryResultJSON, SObjectRecordJSON
BooleanOperator = Literal["AND", "OR", "NOT"]
Comparator = Literal["=", "!=", "<>", ">", ">=", "<", "<=", "LIKE", "INCLUDES", "IN"]
AGGREGATE_FUNCTIONS = ["AVG", "COUNT", "COUNT_DISTINCT", "MIN", "MAX", "SUM"]
[docs]
class Comparison:
prop: str
comparator: Comparator
value: "SoqlQuery | str | bool | datetime | date | None"
[docs]
def __init__(self, prop: str, op, value):
self.prop = prop
self.operator = op
self.value = value
def __str__(self):
if isinstance(self.value, SoqlQuery):
return f"{self.prop} {self.operator} ({str(self.value)})"
elif self.operator == "IN" and isinstance(self.value, str):
return f"{self.prop} {self.operator} ({self.value})"
return f"{self.prop} {self.operator} {quote_soql_value(self.value)}"
[docs]
def EQ(prop: str, value):
return Comparison(prop, "=", value)
[docs]
def NE(prop: str, value):
return Comparison(prop, "!=", value)
[docs]
def GT(prop: str, value):
return Comparison(prop, ">", value)
[docs]
def GE(prop: str, value):
return Comparison(prop, ">=", value)
[docs]
def LT(prop: str, value):
return Comparison(prop, "<", value)
[docs]
def LE(prop: str, value):
return Comparison(prop, "<=", value)
[docs]
def LIKE(prop: str, value):
return Comparison(prop, "LIKE", value)
[docs]
def INCLUDES(prop: str, value):
return Comparison(prop, "INCLUDES", value)
[docs]
def IN(prop: str, value):
return Comparison(prop, "IN", value)
[docs]
def NOT_IN(prop: str, value):
return Comparison(prop, "NOT IN", value)
[docs]
class BooleanOperation:
operator: BooleanOperator
conditions: list["Comparison | BooleanOperation | str"]
[docs]
def __init__(
self,
operator: BooleanOperator,
conditions: list["Comparison | BooleanOperation | str"],
):
self.operator = operator
self.conditions = conditions
def __str__(self):
formatted_conditions = [
str(condition)
if isinstance(condition, Comparison)
else "(" + str(condition) + ")"
for condition in self.conditions
]
return f" {self.operator} ".join(formatted_conditions)
[docs]
def OR(*conditions: "Comparison | BooleanOperation | str"):
return BooleanOperation("OR", list(conditions))
[docs]
def AND(*conditions: "Comparison | BooleanOperation | str"):
return BooleanOperation("AND", list(conditions))
[docs]
class NOT(BooleanOperation):
[docs]
def __init__(self, condition: "Comparison | BooleanOperation | str"):
super().__init__("NOT", [condition])
def __str__(self):
return f"NOT ({str(self.conditions[0])})"
[docs]
class Order(NamedTuple):
field: str
direction: Literal["ASC", "DESC"]
def __str__(self):
return f"{self.field} {self.direction}"
_SObject = TypeVar("_SObject", bound=SObject)
_SObjectJSON = TypeVar("_SObjectJSON", bound=dict[str, Any])
[docs]
class QueryResultBatch(Generic[_SObject]):
"""
A generic class to represent results returned by the Salesforce SOQL Query API.
Attributes:
done (bool):
totalSize (int):
records (list[T]):
nextRecordsUrl (str, optional):
"""
done: bool
"Indicates whether all records have been retrieved (True) or if more batches exist (False)"
totalSize: int
"The total number of records that match the query criteria"
records: list[_SObject]
"The list of records returned by the query"
nextRecordsUrl: str | None
"URL to the next batch of records, if more exist"
_connection: SalesforceClient
_sobject_type: type[_SObject]
"The SObject type this QueryResult contains records for"
query_locator: str | None = None
batch_size: int | None = None
[docs]
def __init__(
self,
sobject_type: type[_SObject],
/,
done: bool = True,
totalSize: int = 0,
records: list[SObjectRecordJSON] | None = None,
nextRecordsUrl: str | None = None,
connection: SalesforceClient | None = None,
):
"""
Initialize a QueryResult object from Salesforce API response data.
Args:
**kwargs: Key-value pairs from the Salesforce API response.
"""
self._connection = connection or SalesforceClient.get_connection(sobject_type.attributes.connection) # type: ignore
self._sobject_type = sobject_type
self.done = done
self.totalSize = totalSize
self.records = SObjectList(
[
sobject_type(**record)
for record in records # type: ignore
]
if records
else []
)
self.nextRecordsUrl = nextRecordsUrl
if self.nextRecordsUrl:
# nextRecordsUrl looks like this:
# /services/data/v63.0/query/01gRO0000016PIAYA2-500
self.query_locator, batch_size = self.nextRecordsUrl.rsplit(
"/", maxsplit=1
)[1].rsplit("-", maxsplit=1)
self.batch_size = int(batch_size)
[docs]
def query_more(self):
if not self.nextRecordsUrl:
raise ValueError("Cannot get more records without nextRecordsUrl")
result: QueryResultJSON = self._connection.get(self.nextRecordsUrl).json()
return QueryResultBatch(self._sobject_type, connection=self._connection, **result) # type: ignore
[docs]
async def query_more_async(self):
if not self.nextRecordsUrl:
raise ValueError("Cannot get more records without nextRecordsUrl")
result: QueryResultJSON = (await self._connection.as_async.get(self.nextRecordsUrl)).json()
return QueryResultBatch(self._sobject_type, connection=self._connection, **result) # type: ignore
[docs]
class QueryResult(Generic[_SObject]):
batches: list[QueryResultBatch[_SObject]]
total_size: int
batch_index: int = 0
record_index: int = 0
[docs]
def __init__(self, *batches: QueryResultBatch[_SObject]):
self.batches = [*batches]
self.total_size = batches[0].totalSize
def __len__(self):
return self.total_size
@property
def done(self):
return self.batches[-1].done
[docs]
def as_list(self) -> SObjectList[_SObject]:
return SObjectList(self, connection=self.batches[0]._sobject_type.attributes.connection)
[docs]
def copy(self) -> "QueryResult[_SObject]":
"""Perform a shallow copy of the QueryResult object."""
return QueryResult(*self.batches)
def __iter__(self) -> Iterator[_SObject]:
return self.copy()
def __aiter__(self) -> AsyncIterator[_SObject]:
return self.copy()
def __next__(self) -> _SObject:
try:
record = self.batches[self.batch_index].records[self.record_index]
except IndexError:
if self.batches[-1].done:
raise StopIteration
self.batches.append(self.batches[self.batch_index].query_more())
self.batch_index += 1
self.record_index = 0
record = self.batches[self.batch_index].records[self.record_index]
self.record_index += 1
return record
async def __anext__(self) -> _SObject:
try:
return self.batches[-1].records[self.batch_index]
except IndexError:
if self.batches[-1].done:
raise StopAsyncIteration
self.batches.append(await self.batches[-1].query_more_async())
self.batch_index = 0
return self.batches[-1].records[self.batch_index]
finally:
self.batch_index += 1
[docs]
class SoqlQuery(Generic[_SObject]):
sobject_type: type[_SObject]
_object_relationship_name: str | None = None
_where: Comparison | BooleanOperation | str | None = None
_grouping: list[str] | None = None
_having: Comparison | BooleanOperation | str | None = None
_limit: int | None = None
_offset: int | None = None
_order: list[Order | str] | None = None
_subqueries: dict[str, "SoqlQuery"]
_include_deleted: bool
[docs]
def __init__(
self,
sobject_type: type[_SObject],
include_deleted: bool = False
):
self.sobject_type = sobject_type
self._subqueries = {}
self._include_deleted = include_deleted
@property
def fields(self):
fields = []
for field in self.sobject_type.query_fields():
if isinstance(field_def := self.sobject_type._fields.get(field), ListField):
subquery = self._subqueries.get(field)
if not subquery:
subquery = field_def._nested_type.query()
subquery._object_relationship_name = field
fields.append(f"({str(subquery)})")
else:
fields.append(field)
return fields
[docs]
def filter_subqueries(self, **subqueries: "SoqlQuery"):
"""
Configure Parent-To-Child Relationship queries
By default, all records are returned in the subquery (no filtering).
https://developer.salesforce.com/docs/atlas.en-us.soql_sosl.meta/soql_sosl/sforce_api_calls_soql_relationships_query_using.htm
Args:
**subqueries: A dictionary of field names and SoqlQuery objects.
Returns:
self: The current SoqlQuery object.
"""
for field, subquery in subqueries.items():
assert isinstance(self.sobject_type._fields.get(field), ListField), (
f"Field '{field}' is not a ListField"
)
subquery._object_relationship_name = field
self._subqueries[field] = subquery
return self
@property
def sobject_name(self) -> str:
return self.sobject_type.attributes.type
def _sf_connection(self):
return self.sobject_type._client_connection()
[docs]
@classmethod
def build_conditional(cls, arg: str, value) -> Comparison | NOT:
op = "="
negated = arg.startswith("NOT__")
if negated:
arg = arg.removeprefix("NOT__")
if arg.endswith("__ne"):
arg = arg.removesuffix("__ne")
op = "!="
elif arg.endswith("__gt"):
arg = arg.removesuffix("__gt")
op = ">"
elif arg.endswith("__lt"):
arg = arg.removesuffix("__lt")
op = "<"
elif arg.endswith("__ge"):
arg = arg.removesuffix("__ge")
op = ">="
elif arg.endswith("__le"):
arg = arg.removesuffix("__le")
op = "<="
elif arg.endswith("__in"):
arg = arg.removesuffix("__in")
op = "IN"
elif arg.endswith("__like"):
arg = arg.removesuffix("__like")
op = "LIKE"
elif arg.endswith("__includes"):
arg = arg.removesuffix("__includes")
op = "INCLUDES"
if any(arg.startswith(f"{func}__") for func in AGGREGATE_FUNCTIONS):
func, arg = arg.split("__", maxsplit=1)
arg = f"{func}({arg})"
if negated:
return NOT(Comparison(arg, op, value))
else:
return Comparison(arg, op, value)
[docs]
@classmethod
def build_conditional_clause(
cls,
kwargs: dict[str, Any],
mode: Literal["any", "all"] = "all",
) -> Comparison | BooleanOperation:
assert len(kwargs) > 0
if len(kwargs) == 1:
arg, value = next(iter(kwargs.items()))
return cls.build_conditional(arg, value)
conditions = (
cls.build_conditional(arg, value) for arg, value in kwargs.items()
)
if mode == "any":
return OR(*conditions)
elif mode == "all":
return AND(*conditions)
else:
raise ValueError(f"Invalid mode: {mode}")
[docs]
def where(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs,
):
if _raw:
self._where = _raw
else:
self._where = self.build_conditional_clause(kwargs, _mode)
return self
[docs]
def and_where(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs: Any,
):
assert self._where is not None, "where() must be called before and_where()"
if _raw:
self._where = AND(self._where, _raw)
else:
self._where = AND(self._where, self.build_conditional_clause(kwargs, _mode))
return self
[docs]
def or_where(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs: Any,
):
assert self._where is not None, "where() must be called before or_where()"
if _raw:
self._where = OR(self._where, _raw)
else:
self._where = OR(self._where, self.build_conditional_clause(kwargs, _mode))
return self
[docs]
def group_by(self, *fields: str):
self._grouping = list(fields)
return self
[docs]
def having(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs,
):
if _raw:
self._having = _raw
else:
self._having = self.build_conditional_clause(kwargs, _mode)
return self
[docs]
def and_having(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs: Any,
):
assert self._having is not None, "having() must be called before and_having()"
if _raw:
self._having = AND(self._having, _raw)
else:
self._having = AND(
self._having, self.build_conditional_clause(kwargs, _mode)
)
return self
[docs]
def or_having(
self,
_raw: Comparison | BooleanOperation | str | None = None,
_mode: Literal["any", "all"] = "all",
**kwargs: Any,
):
assert self._having is not None, "having() must be called before or_having()"
if _raw:
self._having = OR(self._having, _raw)
else:
self._having = OR(
self._having, self.build_conditional_clause(kwargs, _mode)
)
return self
[docs]
def limit(self, limit: int):
self._limit = limit
return self
[docs]
def offset(self, offset: int):
self._offset = offset
return self
[docs]
def order_by(self, *orders: Order, **kw_orders: Literal["ASC", "DESC"]):
self._order = list(orders)
self._order.extend(
Order(field, direction) for field, direction in kw_orders.items()
)
return self
def __str__(self):
return self.format()
[docs]
def count(self) -> int:
"""
Executes a count query instead of fetching records.
Returns the count of records that match the query criteria.
Returns:
int: Number of records matching the query criteria
"""
# Execute the query
count_result = self.execute("COUNT()")
# Count query returns a list with a single record containing the count
return len(count_result)
[docs]
def execute(self, *_fields: str) -> QueryResult[_SObject]:
"""
Executes the SOQL query and returns the first batch of results (up to 2000 records).
"""
if _fields:
fields = list(_fields)
else:
fields = self.fields
client = self._sf_connection()
result: QueryResultJSON
assert not (self.sobject_type.attributes.tooling and self._include_deleted), \
"Tooling API does not support query deleted records (QueryAll)"
if self.sobject_type.attributes.tooling:
url = f"{client.data_url}/tooling/query/"
elif self._include_deleted:
url = f"{client.data_url}/queryAll/"
else:
url = f"{client.data_url}/query/"
result = client.get(url, params={"q": self.format(fields)}).json()
batch = QueryResultBatch(self.sobject_type, connection=client, **result) # type: ignore
return QueryResult(batch)