Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ language governing permissions and limitations under the License. -->
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-ssl-context-service-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-dbcp-service-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-distributed-cache-client-service</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.nifi.cdc.mysql.processors.ssl.ConnectionPropertiesProvider;
import org.apache.nifi.cdc.mysql.processors.ssl.StandardConnectionPropertiesProvider;
import org.apache.nifi.components.AllowableValue;
import org.apache.nifi.components.DescribedValue;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.PropertyValue;
import org.apache.nifi.components.ValidationContext;
Expand All @@ -72,6 +73,8 @@
import org.apache.nifi.components.state.Scope;
import org.apache.nifi.components.state.StateManager;
import org.apache.nifi.components.state.StateMap;
import org.apache.nifi.dbcp.api.DatabasePasswordProvider;
import org.apache.nifi.dbcp.api.DatabasePasswordRequestContext;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.logging.ComponentLog;
Expand Down Expand Up @@ -99,6 +102,7 @@
import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
Expand All @@ -109,6 +113,7 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import javax.net.ssl.SSLContext;
Expand Down Expand Up @@ -156,6 +161,8 @@ public class CaptureChangeMySQL extends AbstractSessionFactoryProcessor {

private static final int DEFAULT_MYSQL_PORT = 3306;

private static final String JDBC_URL_FORMAT = "jdbc:mysql://%s";

// A regular expression matching multiline comments, used when parsing DDL statements
private static final Pattern MULTI_COMMENT_PATTERN = Pattern.compile("/\\*.*?\\*/", Pattern.DOTALL);

Expand Down Expand Up @@ -258,13 +265,58 @@ public class CaptureChangeMySQL extends AbstractSessionFactoryProcessor {
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.build();

public enum PasswordSource implements DescribedValue {
PASSWORD("Password", "Use the configured Password property for database authentication."),
PASSWORD_PROVIDER("Password Provider", "Obtain database passwords from a configured Database Password Provider.");

private final String displayName;
private final String description;

PasswordSource(final String displayName, final String description) {
this.displayName = displayName;
this.description = description;
}

@Override
public String getDisplayName() {
return displayName;
}

@Override
public String getValue() {
return name();
}

@Override
public String getDescription() {
return description;
}
}

public static final PropertyDescriptor PASSWORD_SOURCE = new PropertyDescriptor.Builder()
.name("Password Source")
.description("Specifies whether to supply the database password directly or obtain it from a Database Password Provider.")
.allowableValues(PasswordSource.class)
.defaultValue(PasswordSource.PASSWORD)
.required(true)
.build();

public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder()
.name("Password")
.description("Password to access the MySQL cluster")
.required(false)
.sensitive(true)
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.dependsOn(PASSWORD_SOURCE, PasswordSource.PASSWORD)
.build();

public static final PropertyDescriptor DB_PASSWORD_PROVIDER = new PropertyDescriptor.Builder()
.name("Database Password Provider")
.description("Controller Service that supplies database passwords on demand. When configured, the Password property is ignored.")
.required(true)
.identifiesControllerService(DatabasePasswordProvider.class)
.dependsOn(PASSWORD_SOURCE, PasswordSource.PASSWORD_PROVIDER)
.build();

public static final PropertyDescriptor EVENTS_PER_FLOWFILE_STRATEGY = new PropertyDescriptor.Builder()
Expand Down Expand Up @@ -419,7 +471,9 @@ public class CaptureChangeMySQL extends AbstractSessionFactoryProcessor {
DRIVER_NAME,
DRIVER_LOCATION,
USERNAME,
PASSWORD_SOURCE,
PASSWORD,
DB_PASSWORD_PROVIDER,
EVENTS_PER_FLOWFILE_STRATEGY,
NUMBER_OF_EVENTS_PER_FLOWFILE,
SERVER_ID,
Expand All @@ -443,6 +497,10 @@ public class CaptureChangeMySQL extends AbstractSessionFactoryProcessor {
private volatile BinlogLifecycleListener lifecycleListener;
private volatile GtidSet gtidSet;

private volatile DatabasePasswordProvider passwordProvider;
private volatile DatabasePasswordRequestContext passwordRequestContext;
private volatile String password;

// Set queue capacity to avoid excessive memory consumption
private final BlockingQueue<RawBinlogEvent> queue = new LinkedBlockingQueue<>(1000);

Expand Down Expand Up @@ -635,17 +693,30 @@ public void setup(ProcessContext context) {
final SSLMode sslMode = SSLMode.valueOf(context.getProperty(SSL_MODE).getValue());
final SSLContextService sslContextService = sslMode == SSLMode.DISABLED ? null : context.getProperty(SSL_CONTEXT_SERVICE).asControllerService(SSLContextService.class);

final PasswordSource passwordSource = context.getProperty(PASSWORD_SOURCE).asAllowableValue(PasswordSource.class);
switch (passwordSource) {
case PASSWORD -> {
passwordProvider = null;
passwordRequestContext = null;
password = StringUtils.defaultString(context.getProperty(PASSWORD).evaluateAttributeExpressions().getValue());
}
case PASSWORD_PROVIDER -> {
password = null;
passwordProvider = context.getProperty(DB_PASSWORD_PROVIDER).asControllerService(DatabasePasswordProvider.class);
passwordRequestContext = DatabasePasswordRequestContext.builder()
.jdbcUrl(JDBC_URL_FORMAT.formatted(context.getProperty(HOSTS).evaluateAttributeExpressions().getValue()))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This request context jdbcUrl is built from the raw MySQL Nodes value, so with multiple hosts it becomes jdbc:mysql://h1:3306,h2:3306. Should we pass the actual connected host so providers that sign per host (RDS IAM) work without needing the optional endpoint override?

.driverClassName(context.getProperty(DRIVER_NAME).evaluateAttributeExpressions().getValue())
.databaseUser(context.getProperty(USERNAME).evaluateAttributeExpressions().getValue())
.build();
}
}

// Save off MySQL cluster and JDBC driver information, will be used to connect for event enrichment as well as for the binlog connector
try {
List<InetSocketAddress> hosts = getHosts(context.getProperty(HOSTS).evaluateAttributeExpressions().getValue());

String username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
String password = context.getProperty(PASSWORD).evaluateAttributeExpressions().getValue();

// BinaryLogClient expects a non-null password, so set it to the empty string if it is not provided
if (password == null) {
password = "";
}
String resolvedPassword = resolvePassword();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The binlog client password is resolved only once at setup, while the JDBC connection re-fetches it on each connection. For a provider that returns short lived tokens (like RDS IAM), should the binlog connection also refresh the password when it reconnects?


long connectTimeout = context.getProperty(CONNECT_TIMEOUT).evaluateAttributeExpressions().asTimePeriod(TimeUnit.MILLISECONDS);

Expand All @@ -654,7 +725,7 @@ public void setup(ProcessContext context) {

Long serverId = context.getProperty(SERVER_ID).evaluateAttributeExpressions().asLong();

connect(hosts, username, password, serverId, driverLocation, driverName, connectTimeout, sslContextService, sslMode);
connect(hosts, username, resolvedPassword, serverId, driverLocation, driverName, connectTimeout, sslContextService, sslMode);
} catch (IOException | IllegalStateException e) {
if (eventListener != null) {
eventListener.stop();
Expand Down Expand Up @@ -812,10 +883,7 @@ protected void connect(List<InetSocketAddress> hosts, String username, String pa
}

try {
if (connectTimeout == 0) {
connectTimeout = Long.MAX_VALUE;
}
binlogClient.connect(connectTimeout);
binlogClient.connect(connectTimeout == 0 ? Long.MAX_VALUE : connectTimeout);
binlogResourceInfo.setTransitUri("mysql://" + connectedHost.getHostString() + ":" + connectedHost.getPort());

} catch (IOException | TimeoutException te) {
Expand All @@ -842,11 +910,21 @@ protected void connect(List<InetSocketAddress> hosts, String username, String pa
final TlsConfiguration tlsConfiguration = sslContextService == null ? null : sslContextService.createTlsConfiguration();
final ConnectionPropertiesProvider connectionPropertiesProvider = new StandardConnectionPropertiesProvider(sslMode, tlsConfiguration);
final Map<String, String> jdbcConnectionProperties = connectionPropertiesProvider.getConnectionProperties();
jdbcConnectionHolder = new JDBCConnectionHolder(connectedHost, username, password, jdbcConnectionProperties, connectTimeout);

if (passwordProvider != null && passwordRequestContext != null) {
passwordRequestContext = DatabasePasswordRequestContext.builder()
.jdbcUrl(passwordRequestContext.getJdbcUrl())
.driverClassName(passwordRequestContext.getDriverClassName())
.databaseUser(passwordRequestContext.getDatabaseUser())
.connectionProperties(jdbcConnectionProperties)
.build();
}

jdbcConnectionHolder = new JDBCConnectionHolder(connectedHost, username, this::resolvePassword, jdbcConnectionProperties, connectTimeout);
try {
// Ensure connection can be created.
getJdbcConnection();
} catch (SQLException e) {
} catch (SQLException | ProcessException e) {
getLogger().error("Error creating binlog enrichment JDBC connection to any of the specified hosts", e);
if (eventListener != null) {
eventListener.stop();
Expand Down Expand Up @@ -1157,6 +1235,10 @@ public void stop() throws CDCException {
if (jdbcConnectionHolder != null) {
jdbcConnectionHolder.close();
}

password = null;
passwordProvider = null;
passwordRequestContext = null;
}
}

Expand Down Expand Up @@ -1211,6 +1293,19 @@ protected BinaryLogClient createBinlogClient(String hostname, int port, String u
return new BinaryLogClient(hostname, port, username, password);
}

private String resolvePassword() {
if (passwordProvider == null) {
return StringUtils.defaultString(password);
}
final char[] passwordChars = passwordProvider.getPassword(passwordRequestContext);
if (passwordChars == null || passwordChars.length == 0) {
throw new ProcessException("Database Password Provider returned an empty password");
}
final String resolvedPassword = new String(passwordChars);
Arrays.fill(passwordChars, '\0');
return resolvedPassword;
}

/**
* Retrieves the column information for the specified database and table. The column information can be used to enrich CDC events coming from the RDBMS.
*
Expand Down Expand Up @@ -1253,18 +1348,17 @@ protected Connection getJdbcConnection() throws SQLException {
private class JDBCConnectionHolder {
private final String connectionUrl;
private final Properties connectionProps = new Properties();
private final Supplier<String> passwordSupplier;
private final long connectionTimeoutMillis;

private Connection connection;

private JDBCConnectionHolder(InetSocketAddress host, String username, String password, Map<String, String> customProperties, long connectionTimeoutMillis) {
this.connectionUrl = "jdbc:mysql://" + host.getHostString() + ":" + host.getPort();
private JDBCConnectionHolder(InetSocketAddress host, String username, Supplier<String> passwordSupplier, Map<String, String> customProperties, long connectionTimeoutMillis) {
this.connectionUrl = JDBC_URL_FORMAT.formatted(host.getHostString() + ":" + host.getPort());
this.passwordSupplier = passwordSupplier;
connectionProps.putAll(customProperties);
if (username != null) {
connectionProps.put("user", username);
if (password != null) {
connectionProps.put("password", password);
}
}

this.connectionTimeoutMillis = connectionTimeoutMillis;
Expand All @@ -1280,7 +1374,13 @@ private Connection getConnection() throws SQLException {
close();

getLogger().trace("Creating a new JDBC connection.");
connection = DriverManager.getConnection(connectionUrl, connectionProps);
final Properties props = new Properties();
props.putAll(connectionProps);
final String password = passwordSupplier.get();
if (password != null) {
props.put("password", password);
}
connection = DriverManager.getConnection(connectionUrl, props);
return connection;
}

Expand Down
Loading
Loading