Skip to content

Error rework #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ private void generateOperationExecutor(PythonWriter writer) {

var transportRequest = context.applicationProtocol().requestType();
var transportResponse = context.applicationProtocol().responseType();
var errorSymbol = CodegenUtils.getServiceError(context.settings());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
var configSymbol = CodegenUtils.getConfigSymbol(context.settings());

Expand Down Expand Up @@ -302,46 +301,54 @@ def _classify_error(
}
writer.addStdlibImport("typing", "Any");
writer.addStdlibImport("asyncio", "iscoroutine");
writer.addImports("smithy_core.exceptions", Set.of("SmithyException", "CallException"));
writer.pushState();
writer.putContext("request", transportRequest);
writer.putContext("response", transportResponse);
writer.putContext("plugin", pluginSymbol);
writer.putContext("config", configSymbol);
writer.write(
"""
async def _execute_operation[Input: SerializeableShape, Output: DeserializeableShape](
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $5T], Awaitable[$2T]],
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
plugins: list[${plugin:T}],
serialize: Callable[[Input, ${config:T}], Awaitable[${request:T}]],
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Any, $2T]] | None = None,
response_future: Future[$3T] | None = None,
request_future: Future[RequestContext[Any, ${request:T}]] | None = None,
response_future: Future[${response:T}] | None = None,
) -> Output:
try:
return await self._handle_execution(
input, plugins, serialize, deserialize, config, operation,
request_future, response_future,
)
except Exception as e:
# Make sure every exception that we throw is an instance of SmithyException so
# customers can reliably catch everything we throw.
if not isinstance(e, SmithyException):
wrapped = CallException(str(e))
wrapped.__cause__ = e
e = wrapped

if request_future is not None and not request_future.done():
request_future.set_exception($4T(e))
request_future.set_exception(e)
if response_future is not None and not response_future.done():
response_future.set_exception($4T(e))

# Make sure every exception that we throw is an instance of $4T so
# customers can reliably catch everything we throw.
if not isinstance(e, $4T):
raise $4T(e) from e
response_future.set_exception(e)
raise

async def _handle_execution[Input: SerializeableShape, Output: DeserializeableShape](
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $5T], Awaitable[$2T]],
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
plugins: list[${plugin:T}],
serialize: Callable[[Input, ${config:T}], Awaitable[${request:T}]],
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Any, $2T]] | None,
response_future: Future[$3T] | None,
request_future: Future[RequestContext[Any, ${request:T}]] | None,
response_future: Future[${response:T}] | None,
) -> Output:
operation_name = operation.schema.id.name
logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input)
Expand All @@ -350,11 +357,16 @@ def _classify_error(
plugin(config)

input_context = InputContext(request=input, properties=TypedProperties({"config": config}))
transport_request: $2T | None = None
output_context: OutputContext[Input, Output, $2T | None, $3T | None] | None = None
transport_request: ${request:T} | None = None
output_context: OutputContext[
Input,
Output,
${request:T} | None,
${response:T} | None
] | None = None

client_interceptors = cast(
list[Interceptor[Input, Output, $2T, $3T]], list(config.interceptors)
list[Interceptor[Input, Output, ${request:T}, ${response:T}]], list(config.interceptors)
)
interceptor_chain = InterceptorChain(client_interceptors)

Expand Down Expand Up @@ -455,24 +467,20 @@ await sleep(retry_token.retry_delay)

async def _handle_attempt[Input: SerializeableShape, Output: DeserializeableShape](
self,
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
interceptor: Interceptor[Input, Output, $2T, $3T],
context: RequestContext[Input, $2T],
config: $5T,
deserialize: Callable[[${response:T}, ${config:T}], Awaitable[Output]],
interceptor: Interceptor[Input, Output, ${request:T}, ${response:T}],
context: RequestContext[Input, ${request:T}],
config: ${config:T},
operation: APIOperation[Input, Output],
request_future: Future[RequestContext[Input, $2T]] | None,
) -> OutputContext[Input, Output, $2T, $3T | None]:
transport_response: $3T | None = None
request_future: Future[RequestContext[Input, ${request:T}]] | None,
) -> OutputContext[Input, Output, ${request:T}, ${response:T} | None]:
transport_response: ${response:T} | None = None
try:
# Step 7a: Invoke read_before_attempt
interceptor.read_before_attempt(context)

""",
pluginSymbol,
transportRequest,
transportResponse,
errorSymbol,
configSymbol);
""");
writer.popState();

boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
Expand Down Expand Up @@ -873,8 +881,8 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
.orElse("The operation's input.");

writer.write("""
$L
""",docs);
$L
""", docs);
writer.write("");
writer.write(":param input: $L", inputDocs);
writer.write("");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ public static Symbol getPluginSymbol(PythonSettings settings) {
/**
* Gets the service error symbol.
*
* <p>This error is the top-level error for the client. Every error surfaced by
* the client MUST be a subclass of this so that customers can reliably catch all
* exceptions it raises. The client implementation will wrap any errors that aren't
* already subclasses.
* <p>This error is the top-level error for modeled client errors.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for the client's error class.
Expand All @@ -105,40 +102,6 @@ public static Symbol getServiceError(PythonSettings settings) {
.build();
}

/**
* Gets the service API error symbol.
*
* <p>This error is the parent class for all errors returned over the wire by the
* service, including unknown errors.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for the client's API error class.
*/
public static Symbol getApiError(PythonSettings settings) {
return Symbol.builder()
.name("ApiError")
.namespace(String.format("%s.models", settings.moduleName()), ".")
.definitionFile(String.format("./src/%s/models.py", settings.moduleName()))
.build();
}

/**
* Gets the unknown API error symbol.
*
* <p> This error is the parent class for all errors returned over the wire by
* the service which aren't in the model.
*
* @param settings The client settings, used to account for module configuration.
* @return Returns the symbol for unknown API errors.
*/
public static Symbol getUnknownApiError(PythonSettings settings) {
return Symbol.builder()
.name("UnknownApiError")
.namespace(String.format("%s.models", settings.moduleName()), ".")
.definitionFile(String.format("./src/%s/models.py", settings.moduleName()))
.build();
}

/**
* Gets the symbol for the http auth params.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.Locale;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider;
import software.amazon.smithy.codegen.core.ReservedWords;
import software.amazon.smithy.codegen.core.ReservedWordsBuilder;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
Expand Down Expand Up @@ -84,6 +83,10 @@ public PythonSymbolProvider(Model model, PythonSettings settings) {
var reservedMemberNamesBuilder = new ReservedWordsBuilder()
.loadWords(PythonSymbolProvider.class.getResource("reserved-member-names.txt"), this::escapeWord);

// Reserved words that only apply to error members.
var reservedErrorMembers = new ReservedWordsBuilder()
.loadWords(PythonSymbolProvider.class.getResource("reserved-error-member-names.txt"), this::escapeWord);

escaper = ReservedWordSymbolProvider.builder()
.nameReservedWords(reservedClassNames)
.memberReservedWords(reservedMemberNamesBuilder.build())
Expand All @@ -92,13 +95,8 @@ public PythonSymbolProvider(Model model, PythonSettings settings) {
.escapePredicate((shape, symbol) -> !StringUtils.isEmpty(symbol.getDefinitionFile()))
.buildEscaper();

// Reserved words that only apply to error members.
ReservedWords reservedErrorMembers = reservedMemberNamesBuilder
.put("code", "code_")
.build();

errorMemberEscaper = ReservedWordSymbolProvider.builder()
.memberReservedWords(reservedErrorMembers)
.memberReservedWords(reservedErrorMembers.build())
.escapePredicate((shape, symbol) -> !StringUtils.isEmpty(symbol.getDefinitionFile()))
.buildEscaper();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package software.amazon.smithy.python.codegen.generators;

import java.util.Set;
import software.amazon.smithy.codegen.core.WriterDelegator;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.PythonSettings;
Expand All @@ -30,38 +29,15 @@ public void run() {
var serviceError = CodegenUtils.getServiceError(settings);
writers.useFileWriter(serviceError.getDefinitionFile(), serviceError.getNamespace(), writer -> {
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.exceptions", "SmithyException");
writer.addImport("smithy_core.exceptions", "ModeledException");
writer.write("""
class $L(SmithyException):
""\"Base error for all errors in the service.""\"
pass
""", serviceError.getName());
});

var apiError = CodegenUtils.getApiError(settings);
writers.useFileWriter(apiError.getDefinitionFile(), apiError.getNamespace(), writer -> {
writer.addStdlibImports("typing", Set.of("Literal", "ClassVar"));
var unknownApiError = CodegenUtils.getUnknownApiError(settings);

writer.write("""
@dataclass
class $1L($2T):
""\"Base error for all API errors in the service.""\"
code: ClassVar[str]
fault: ClassVar[Literal["client", "server"]]
class $L(ModeledException):
""\"Base error for all errors in the service.

message: str

def __post_init__(self) -> None:
super().__init__(self.message)


@dataclass
class $3L($1L):
""\"Error representing any unknown api errors.""\"
code: ClassVar[str] = 'Unknown'
fault: ClassVar[Literal["client", "server"]] = "client"
""", apiError.getName(), serviceError, unknownApiError.getName());
Some exceptions do not extend from this class, including
synthetic, implicit, and shared exception types.
""\"
""", serviceError.getName());
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,6 @@ private static void writeIndexes(GenerationContext context, String projectName)
writeIndexFile(context, "docs/models/index.rst", "Models");
}


/**
* Write the readme in the docs folder describing instructions for generation
*
Expand All @@ -461,18 +460,18 @@ private static void writeDocsReadme(
GenerationContext context
) {
context.writerDelegator().useFileWriter("docs/README.md", writer -> {
writer.write("""
## Generating Documentation
Sphinx is used for documentation. You can generate HTML locally with the
following:
```
$$ uv pip install ".[docs]"
$$ cd docs
$$ make html
```
""");
writer.write("""
## Generating Documentation

Sphinx is used for documentation. You can generate HTML locally with the
following:

```
$$ uv pip install ".[docs]"
$$ cd docs
$$ make html
```
""");
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import software.amazon.smithy.model.traits.InputTrait;
import software.amazon.smithy.model.traits.OutputTrait;
import software.amazon.smithy.model.traits.RequiredTrait;
import software.amazon.smithy.model.traits.RetryableTrait;
import software.amazon.smithy.model.traits.SensitiveTrait;
import software.amazon.smithy.model.traits.StreamingTrait;
import software.amazon.smithy.python.codegen.CodegenUtils;
Expand Down Expand Up @@ -130,31 +131,40 @@ private void renderError() {
writer.addStdlibImports("typing", Set.of("Literal", "ClassVar"));
writer.addStdlibImport("dataclasses", "dataclass");

// TODO: Implement protocol-level customization of the error code
var fault = errorTrait.getValue();
var code = shape.getId().getName();
var symbol = symbolProvider.toSymbol(shape);
var apiError = CodegenUtils.getApiError(settings);
var baseError = CodegenUtils.getServiceError(settings);
writer.pushState(new ErrorSection(symbol));
writer.putContext("retryable", false);
writer.putContext("throttling", false);

var retryableTrait = shape.getTrait(RetryableTrait.class);
if (retryableTrait.isPresent()) {
writer.putContext("retryable", true);
writer.putContext("throttling", retryableTrait.get().getThrottling());
}
writer.write("""
@dataclass(kw_only=True)
class $1L($2T):
${5C|}
${4C|}

code: ClassVar[str] = $3S
fault: ClassVar[Literal["client", "server"]] = $4S
fault: Literal["client", "server"] | None = $3S
${?retryable}
is_retry_safe: bool | None = True
${?throttling}
is_throttle: bool = True
${/throttling}
${/retryable}

${5C|}

message: str
${6C|}

${7C|}

${8C|}

""",
symbol.getName(),
apiError,
code,
baseError,
fault,
writer.consumer(w -> writeClassDocs(true)),
writer.consumer(w -> writeProperties()),
Expand Down Expand Up @@ -325,7 +335,9 @@ private void writeMemberDocs(MemberShape member) {

String memberName = symbolProvider.toMemberName(member);
String docs = writer.formatDocs(String.format(":param %s: %s%s",
memberName, descriptionPrefix, trait.getValue()));
memberName,
descriptionPrefix,
trait.getValue()));
writer.write(docs);
});
}
Expand Down
Loading
Loading