Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.texera.auth

import org.apache.texera.common.config.AuthConfig
import org.apache.texera.dao.jooq.generated.enums.UserRoleEnum
import org.apache.texera.dao.jooq.generated.tables.pojos.User
import org.jose4j.jwt.NumericDate
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class JwtAuthSpec extends AnyFlatSpec with Matchers {

private def buildUser(): User = {
val user = new User()
user.setUid(42)
user.setName("alice")
user.setEmail("alice@example.com")
user.setGoogleId("g-123")
user.setGoogleAvatar("avatar-blob")
user.setRole(UserRoleEnum.ADMIN)
user
}

"JwtAuth.jwtClaims" should "map every User field onto the matching claim" in {
val claims = JwtAuth.jwtClaims(buildUser(), 7)
claims.getSubject shouldBe "alice"
claims.getClaimValueAsString("userId") shouldBe "42"
claims.getClaimValueAsString("googleId") shouldBe "g-123"
claims.getClaimValueAsString("email") shouldBe "alice@example.com"
claims.getClaimValueAsString("googleAvatar") shouldBe "avatar-blob"
claims.getClaimValueAsString("role") shouldBe UserRoleEnum.ADMIN.name
}

it should "derive the expiration from config, ignoring the expireInDays argument" in {
// two very different expireInDays values must yield the same config-derived expiry window
def expiryWindowMinutes(expireInDays: Int): Double = {
val claims = JwtAuth.jwtClaims(buildUser(), expireInDays)
claims.getExpirationTime should not be null
claims.getExpirationTime.getValue / 60.0 - NumericDate.now().getValue / 60.0
}
expiryWindowMinutes(1) shouldBe (AuthConfig.jwtExpirationMinutes.toDouble +- 2.0)
expiryWindowMinutes(100000) shouldBe (AuthConfig.jwtExpirationMinutes.toDouble +- 2.0)
}

it should "produce a token that round-trips back to the same user via JwtParser" in {
val token = JwtAuth.jwtToken(JwtAuth.jwtClaims(buildUser(), 1))
val parsed = JwtParser.parseToken(token)
parsed.isPresent shouldBe true
val user = parsed.get().getUser
user.getUid shouldBe 42
user.getName shouldBe "alice"
user.getEmail shouldBe "alice@example.com"
user.getGoogleId shouldBe "g-123"
user.getGoogleAvatar shouldBe "avatar-blob"
user.getRole shouldBe UserRoleEnum.ADMIN
}

it should "carry through null optional fields without error" in {
val user = new User()
user.setUid(7)
user.setName("bob")
user.setRole(UserRoleEnum.ADMIN)
val claims = JwtAuth.jwtClaims(user, 1)
claims.getSubject shouldBe "bob"
claims.getClaimValueAsString("email") shouldBe null
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.texera.auth

import ch.qos.logback.classic.{Level, Logger => LogbackLogger}
import jakarta.servlet.{DispatcherType, FilterChain}
import jakarta.servlet.http.{HttpServletRequest, HttpServletResponse}
import org.eclipse.jetty.servlet.{FilterHolder, ServletContextHandler}
import org.mockito.ArgumentMatchers.{any, eq => eqTo}
import org.mockito.Mockito
import org.mockito.Mockito.{mock, verify, when}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.slf4j.LoggerFactory

class RequestLoggingFilterSpec extends AnyFlatSpec with Matchers {

"RequestLoggingFilter.doFilter" should "delegate to the chain before logging the request" in {
val filter = new RequestLoggingFilter
val request = mock(classOf[HttpServletRequest])
val response = mock(classOf[HttpServletResponse])
val chain = mock(classOf[FilterChain])
when(request.getRemoteAddr).thenReturn("1.2.3.4")
when(request.getMethod).thenReturn("GET")
when(request.getRequestURI).thenReturn("/api/x")
when(request.getProtocol).thenReturn("HTTP/1.1")
when(response.getStatus).thenReturn(200)

// force the request-log logger to INFO so the log branch (and its getter reads) runs
val requestLog =
LoggerFactory.getLogger("org.eclipse.jetty.server.RequestLog").asInstanceOf[LogbackLogger]
val previousLevel = requestLog.getLevel
requestLog.setLevel(Level.INFO)
try {
filter.doFilter(request, response, chain)
} finally {
requestLog.setLevel(previousLevel)
}

// the chain is invoked, and only afterward are the request fields read for the log line
// (Mockito.inOrder, fully qualified to avoid ScalaTest Matchers' own inOrder DSL)
val ordered = Mockito.inOrder(chain, request)
ordered.verify(chain).doFilter(request, response)
ordered.verify(request).getRemoteAddr
verify(request).getMethod
verify(request).getRequestURI
verify(request).getProtocol
verify(response).getStatus
}

"RequestLoggingFilter.register" should "add the filter to the servlet context for all dispatch types" in {
val context = mock(classOf[ServletContextHandler])
RequestLoggingFilter.register(context)
verify(context).addFilter(
any(classOf[FilterHolder]),
eqTo("/*"),
eqTo(java.util.EnumSet.allOf(classOf[DispatcherType]))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,32 @@ class UserActivityTrackerSpec extends AnyFlatSpec with Matchers {
// Must not throw — the wrapper catches NonFatal from upsertFn.
noException should be thrownBy tracker.markActive(42)
}

it should "swallow exceptions thrown before the write is dispatched" in {
val recorder = new Recorder
// a clock that throws forces the failure in markActive before executor.execute
val tracker =
new UserActivityTracker(
Duration.ofMinutes(5),
recorder.upsert,
sameThread,
() => throw new RuntimeException("clock boom")
)

noException should be thrownBy tracker.markActive(7)
recorder.calls.size shouldBe 0 // the write was never dispatched
}

it should "swallow exceptions thrown by evictStale" in {
val recorder = new Recorder
val tracker =
new UserActivityTracker(
Duration.ofMinutes(5),
recorder.upsert,
sameThread,
() => throw new RuntimeException("clock boom")
)

noException should be thrownBy tracker.evictStale()
}
}
Loading