Skip to content

Commit

Permalink
Improve audience support in OAuth2 authenticator
Browse files Browse the repository at this point in the history
* Add validation of JWT's audience (`aud`) field
* Add support for multiple-valued audience field
* Send audience in the authentication request
  Some IdPs (ex. Auth0) require including audience in the authorization
  request in order to obtain an access token in JWT format
  • Loading branch information
lukasz-walkiewicz authored and kokosing committed Jan 12, 2021
1 parent 13910a8 commit 1b68565
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class OAuth2Config
private String jwksUrl;
private String clientId;
private String clientSecret;
private Optional<String> audience = Optional.empty();
private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
private Optional<String> userMappingPattern = Optional.empty();
private Optional<File> userMappingFile = Optional.empty();
Expand Down Expand Up @@ -120,6 +121,19 @@ public OAuth2Config setClientSecret(String clientSecret)
return this;
}

public Optional<String> getAudience()
{
return audience;
}

@Config("http-server.authentication.oauth2.audience")
@ConfigDescription("The required audience of a token")
public OAuth2Config setAudience(String audience)
{
this.audience = Optional.ofNullable(audience);
return this;
}

@MinDuration("1ms")
@NotNull
public Duration getChallengeTimeout()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,32 @@
import java.time.Instant;
import java.util.Optional;

import static com.github.scribejava.core.model.OAuthConstants.REDIRECT_URI;
import static com.github.scribejava.core.model.OAuthConstants.STATE;
import static java.util.Objects.requireNonNull;

public class ScribeJavaOAuth2Client
implements OAuth2Client
{
private final DynamicCallbackOAuth2Service service;
private final Optional<String> audience;

@Inject
public ScribeJavaOAuth2Client(OAuth2Config config)
{
requireNonNull(config, "config is null");
service = new DynamicCallbackOAuth2Service(config);
audience = config.getAudience();
}

@Override
public URI getAuthorizationUri(String state, URI callbackUri)
{
return URI.create(service.getAuthorizationUrl(ImmutableMap.<String, String>builder()
.put(OAuthConstants.REDIRECT_URI, callbackUri.toString())
.put(OAuthConstants.STATE, state)
.build()));
ImmutableMap.Builder<String, String> parameters = ImmutableMap.builder();
parameters.put(REDIRECT_URI, callbackUri.toString());
parameters.put(STATE, state);
audience.ifPresent(audience -> parameters.put("audience", audience));
return URI.create(service.getAuthorizationUrl(parameters.build()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
import javax.ws.rs.core.Response;

import java.net.URI;
import java.util.List;
import java.util.Optional;

import static com.google.common.base.MoreObjects.firstNonNull;
import static io.jsonwebtoken.Claims.AUDIENCE;
import static io.trino.server.ServletSecurityUtils.sendErrorMessage;
import static io.trino.server.ServletSecurityUtils.sendWwwAuthenticate;
import static io.trino.server.ServletSecurityUtils.setAuthenticatedIdentity;
Expand All @@ -50,13 +52,15 @@ public class OAuth2WebUiAuthenticationFilter

private final OAuth2Service service;
private final UserMapping userMapping;
private final Optional<String> validAudience;

@Inject
public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth2Config)
{
this.service = requireNonNull(service, "service is null");
requireNonNull(oauth2Config, "oauth2Config is null");
this.userMapping = UserMapping.createUserMapping(oauth2Config.getUserMappingPattern(), oauth2Config.getUserMappingFile());
this.validAudience = oauth2Config.getAudience();
}

@Override
Expand All @@ -78,15 +82,21 @@ public void filter(ContainerRequestContext request)
request.abortWith(Response.seeOther(DISABLED_LOCATION_URI).build());
return;
}

Optional<String> subject = getAccessToken(request).map(token -> token.getBody().getSubject());
if (subject.isEmpty()) {
Optional<Claims> claims = getAccessToken(request).map(Jws::getBody);
if (claims.isEmpty()) {
needAuthentication(request);
return;
}
Object audience = claims.get().get(AUDIENCE);
if (!hasValidAudience(audience)) {
LOG.debug("Invalid audience: %s. Expected audience to be equal to or contain: %s", audience, validAudience);
sendErrorMessage(request, UNAUTHORIZED, "Unauthorized");
return;
}
try {
setAuthenticatedIdentity(request, Identity.forUser(userMapping.mapUser(subject.get()))
.withPrincipal(new BasicPrincipal(subject.get()))
String subject = claims.get().getSubject();
setAuthenticatedIdentity(request, Identity.forUser(userMapping.mapUser(subject))
.withPrincipal(new BasicPrincipal(subject))
.build());
}
catch (UserMappingException e) {
Expand All @@ -113,4 +123,21 @@ private void needAuthentication(ContainerRequestContext request)
URI redirectLocation = service.startChallenge(request.getUriInfo().getBaseUri().resolve(CALLBACK_ENDPOINT));
request.abortWith(Response.seeOther(redirectLocation).build());
}

private boolean hasValidAudience(Object audience)
{
if (validAudience.isEmpty()) {
return true;
}
if (audience == null) {
return false;
}
if (audience instanceof String) {
return audience.equals(validAudience.get());
}
if (audience instanceof List) {
return ((List<?>) audience).contains(validAudience.get());
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public void testDefaults()
.setJwksUrl(null)
.setClientId(null)
.setClientSecret(null)
.setAudience(null)
.setChallengeTimeout(Duration.valueOf("15m"))
.setUserMappingPattern(null)
.setUserMappingFile(null));
Expand All @@ -55,6 +56,7 @@ public void testExplicitPropertyMappings()
.put("http-server.authentication.oauth2.jwks-url", "http://127.0.0.1:9000/.well-known/jwks.json")
.put("http-server.authentication.oauth2.client-id", "another-consumer")
.put("http-server.authentication.oauth2.client-secret", "consumer-secret")
.put("http-server.authentication.oauth2.audience", "https://127.0.0.1:8443")
.put("http-server.authentication.oauth2.challenge-timeout", "90s")
.put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)@something")
.put("http-server.authentication.oauth2.user-mapping.file", userMappingFile.toString())
Expand All @@ -67,6 +69,7 @@ public void testExplicitPropertyMappings()
.setJwksUrl("http://127.0.0.1:9000/.well-known/jwks.json")
.setClientId("another-consumer")
.setClientSecret("consumer-secret")
.setAudience("https://127.0.0.1:8443")
.setChallengeTimeout(Duration.valueOf("90s"))
.setUserMappingPattern("(.*)@something")
.setUserMappingFile(userMappingFile.toFile());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,30 @@
import static io.airlift.testing.Assertions.assertLessThan;
import static io.trino.client.OkHttpUtil.setupInsecureSsl;
import static io.trino.server.security.oauth2.TestingHydraService.TTL_ACCESS_TOKEN_IN_SECONDS;
import static io.trino.server.security.oauth2.TokenEndpointAuthMethod.CLIENT_SECRET_BASIC;
import static io.trino.server.security.oauth2.TokenEndpointAuthMethod.CLIENT_SECRET_POST;
import static io.trino.server.ui.OAuthWebUiCookie.OAUTH2_COOKIE;
import static java.lang.String.format;
import static javax.ws.rs.core.HttpHeaders.LOCATION;
import static javax.ws.rs.core.Response.Status.OK;
import static javax.ws.rs.core.Response.Status.SEE_OTHER;
import static javax.ws.rs.core.Response.Status.UNAUTHORIZED;
import static org.assertj.core.api.Assertions.assertThat;
import static org.openqa.selenium.support.ui.ExpectedConditions.elementToBeClickable;

@Test(singleThreaded = true)
public class TestOAuth2WebUiAuthenticationFilter
{
private static final int HTTPS_PORT = findAvailablePort();
private static final String EXPOSED_SERVER_URL = format("https://host.testcontainers.internal:%d", HTTPS_PORT);
private static final String TRINO_CLIENT_ID = "trino-client";
private static final String TRINO_CLIENT_SECRET = "trino-secret";
private static final String TRINO_AUDIENCE = EXPOSED_SERVER_URL + "/ui";
private static final String TRUSTED_CLIENT_ID = "trusted-client";
private static final String TRUSTED_CLIENT_SECRET = "trusted-secret";
private static final String UNTRUSTED_CLIENT_ID = "untrusted-client";
private static final String UNTRUSTED_CLIENT_SECRET = "untrusted-secret";
private static final String UNTRUSTED_CLIENT_AUDIENCE = "https://untrusted.com";

private final TestingHydraService testingHydraService = new TestingHydraService();
private final OkHttpClient httpClient;
Expand All @@ -99,7 +111,24 @@ public void setup()

Testcontainers.exposeHostPorts(HTTPS_PORT);
testingHydraService.start();
testingHydraService.createConsumer(format("https://host.testcontainers.internal:%d/oauth2/callback", HTTPS_PORT));
testingHydraService.createClient(
TRINO_CLIENT_ID,
TRINO_CLIENT_SECRET,
CLIENT_SECRET_BASIC,
TRINO_AUDIENCE,
EXPOSED_SERVER_URL + "/oauth2/callback");
testingHydraService.createClient(
TRUSTED_CLIENT_ID,
TRUSTED_CLIENT_SECRET,
CLIENT_SECRET_POST,
TRINO_AUDIENCE,
EXPOSED_SERVER_URL + "/oauth2/callback");
testingHydraService.createClient(
UNTRUSTED_CLIENT_ID,
UNTRUSTED_CLIENT_SECRET,
CLIENT_SECRET_POST,
UNTRUSTED_CLIENT_AUDIENCE,
"https://untrusted.com/callback");

server = TestingTrinoServer.builder()
.setCoordinator(true)
Expand All @@ -114,9 +143,10 @@ public void setup()
.put("http-server.authentication.oauth2.auth-url", "http://hydra:4444/oauth2/auth")
.put("http-server.authentication.oauth2.token-url", format("http://localhost:%s/oauth2/token", testingHydraService.getHydraPort()))
.put("http-server.authentication.oauth2.jwks-url", format("http://localhost:%s/.well-known/jwks.json", testingHydraService.getHydraPort()))
.put("http-server.authentication.oauth2.client-id", "another-consumer")
.put("http-server.authentication.oauth2.client-secret", "consumer-secret")
.put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)@.*")
.put("http-server.authentication.oauth2.client-id", TRINO_CLIENT_ID)
.put("http-server.authentication.oauth2.client-secret", TRINO_CLIENT_SECRET)
.put("http-server.authentication.oauth2.audience", format("https://host.testcontainers.internal:%d/ui", HTTPS_PORT))
.put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)(@.*)?")
.build())
.build();
server.waitForNodeRefresh(Duration.ofSeconds(10));
Expand All @@ -137,7 +167,7 @@ public void testUnauthorizedApiCall()
try (Response response = httpClient
.newCall(uiCall().build())
.execute()) {
assertUnauthorizedUICall(response);
assertRedirectResponse(response);
}
}

Expand All @@ -148,12 +178,12 @@ public void testUnauthorizedUICall()
try (Response response = httpClient
.newCall(uiCall().build())
.execute()) {
assertUnauthorizedUICall(response);
assertRedirectResponse(response);
}
}

@Test
public void testInvalidToken()
public void testUnsignedToken()
throws NoSuchAlgorithmException, IOException
{
KeyPairGenerator keyGenerator = KeyPairGenerator.getInstance("RSA");
Expand All @@ -168,7 +198,7 @@ public void testInvalidToken()
new DefaultClaims(
ImmutableMap.<String, Object>builder()
.put("aud", ImmutableList.of())
.put("client_id", "another-consumer")
.put("client_id", TRINO_CLIENT_ID)
.put("exp", now + 60L)
.put("iat", now)
.put("iss", "http://hydra:4444/")
Expand All @@ -183,10 +213,30 @@ public void testInvalidToken()
try (Response response = httpClientUsingCookie(new Cookie.Builder(OAUTH2_COOKIE, token).build())
.newCall(uiCall().build())
.execute()) {
assertUnauthorizedUICall(response);
assertRedirectResponse(response);
}
}

@Test
public void testTokenWithInvalidAudience()
throws IOException
{
String token = testingHydraService.getToken(UNTRUSTED_CLIENT_ID, UNTRUSTED_CLIENT_SECRET, UNTRUSTED_CLIENT_AUDIENCE);
try (Response response = httpClientUsingCookie(new Cookie.Builder(OAUTH2_COOKIE, token).build())
.newCall(uiCall().build())
.execute()) {
assertUnauthorizedResponse(response);
}
}

@Test
public void testTokenFromTrustedClient()
throws IOException
{
String token = testingHydraService.getToken(TRUSTED_CLIENT_ID, TRUSTED_CLIENT_SECRET, TRINO_AUDIENCE);
assertUICallWithCookie(new Cookie.Builder(OAUTH2_COOKIE, token).build());
}

@Test
@Flaky(issue = "https://github.com/trinodb/trino/issues/6223", match = OAUTH2_COOKIE + " is missing")
public void testSuccessfulFlow()
Expand All @@ -209,7 +259,7 @@ public void testExpiredAccessToken()
assertThat(cookie).withFailMessage(OAUTH2_COOKIE + " is missing").isNotNull();
Thread.sleep((TTL_ACCESS_TOKEN_IN_SECONDS + 1) * 1000L); // wait for the token expiration
try (Response response = httpClientUsingCookie(cookie).newCall(uiCall().build()).execute()) {
assertUnauthorizedUICall(response);
assertRedirectResponse(response);
}
}));
}
Expand Down Expand Up @@ -291,7 +341,7 @@ private void assertTrinoCookie(Cookie cookie)
private void assertAccessToken(Jws<Claims> jwt)
{
assertThat(jwt.getBody().getSubject()).isEqualTo("foo@bar.com");
assertThat(jwt.getBody().get("client_id")).isEqualTo("another-consumer");
assertThat(jwt.getBody().get("client_id")).isEqualTo(TRINO_CLIENT_ID);
assertThat(jwt.getBody().getIssuer()).isEqualTo("http://hydra:4444/");
}

Expand Down Expand Up @@ -344,13 +394,21 @@ private static int findAvailablePort()
}
}

private static void assertUnauthorizedUICall(Response response)
private static void assertRedirectResponse(Response response)
throws MalformedURLException
{
assertThat(response.code()).isEqualTo(SEE_OTHER.getStatusCode());
assertRedirectUrl(response.header(LOCATION));
}

private static void assertUnauthorizedResponse(Response response)
throws IOException
{
assertThat(response.code()).isEqualTo(UNAUTHORIZED.getStatusCode());
assertThat(response.body()).isNotNull();
assertThat(response.body().string()).isEqualTo("Unauthorized");
}

private static void assertRedirectUrl(String redirectUrl)
throws MalformedURLException
{
Expand All @@ -365,7 +423,7 @@ private static void assertRedirectUrl(String redirectUrl)
assertThat(url.queryParameterValues("response_type")).isEqualTo(ImmutableList.of("code"));
assertThat(url.queryParameterValues("scope")).isEqualTo(ImmutableList.of("openid"));
assertThat(url.queryParameterValues("redirect_uri")).isEqualTo(ImmutableList.of(format("https://127.0.0.1:%s/oauth2/callback", HTTPS_PORT)));
assertThat(url.queryParameterValues("client_id")).isEqualTo(ImmutableList.of("another-consumer"));
assertThat(url.queryParameterValues("client_id")).isEqualTo(ImmutableList.of(TRINO_CLIENT_ID));
assertThat(url.queryParameterValues("state")).isNotNull();
}

Expand Down
Loading

0 comments on commit 1b68565

Please sign in to comment.