gRPC: How to send an array of floats from Node.js using protobuf's repeated bytes and receive it in Python

I would like to send a list of float arrays from Node.js and receive it in Python using protobuf's repeated bytes type.
I tried with this configuration, and what I get on the Python side is not really what I expect:
tensors=[b'-TWW', b'-TWW', b'-TWW', b'-TWW']
Here is my test in Node.
Client:
const PROTO_PATH = __dirname + '/route_guide.proto';
const async = require('async');
const grpc = require('@grpc/grpc-js');
const protoLoader = require('@grpc/proto-loader');

const packageDefinition = protoLoader.loadSync(
  PROTO_PATH,
  {
    keepCase: true,
    longs: String,
    enums: String,
    defaults: true,
    oneofs: true
  });
const routeguide = grpc.loadPackageDefinition(packageDefinition).routeguide;
const client = new routeguide.RouteGuide('localhost:50051',
  grpc.credentials.createInsecure());

function runJoin(callback) {
  const call = client.join();
  call.on('data', function(receivedMessage) {
    console.log('Got message "' + JSON.stringify(receivedMessage));
  });
  call.on('end', callback);

  const messageToSend = {
    msg: 'parameters_res',
    parameters_res: {
      parameters: {
        tensors: [
          Buffer.from(new Float64Array([45.1]).buffer),
          Buffer.from(new Float64Array([45.1, 84.5, 87.9, 87.1]).buffer),
          Buffer.from(new Float64Array([45.1, 84.5, 87.9, 87.1]).buffer),
          Buffer.from(new Float64Array([45.1, 84.5, 87.9, 87.1]).buffer)
        ],
        tensor_type: 'numpy.ndarray'
      }
    }
  };

  console.log(messageToSend);
  console.log(messageToSend.parameters_res.parameters.tensors);
  call.write(messageToSend);
  call.end();
}

function main() {
  async.series([
    runJoin
  ]);
}

if (require.main === module) {
  main();
}

exports.runJoin = runJoin;
route_guide.proto:
syntax = "proto3";
option java_multiple_files = true;
option java_package = "io.grpc.examples.routeguide";
option java_outer_classname = "RouteGuideProto";
option objc_class_prefix = "RTG";
package routeguide;
service RouteGuide {
rpc Join(stream ClientMessage) returns (stream ClientMessage) {}
}
message RouteNote {
repeated bytes model = 1;
}
message ClientMessage {
message Disconnect { Reason reason = 1; }
message ParametersRes { Parameters parameters = 1; }
oneof msg {
Disconnect disconnect = 1;
ParametersRes parameters_res = 2;
}
}
message Parameters {
repeated bytes tensors = 1;
string tensor_type = 2;
}
enum Reason {
UNKNOWN = 0;
RECONNECT = 1;
POWER_DISCONNECTED = 2;
WIFI_UNAVAILABLE = 3;
ACK = 4;
}
Server:
const PROTO_PATH = __dirname + '/route_guide.proto';
const grpc = require('@grpc/grpc-js');
const protoLoader = require('@grpc/proto-loader');

const packageDefinition = protoLoader.loadSync(
  PROTO_PATH,
  {
    keepCase: true,
    longs: String,
    enums: String,
    defaults: true,
    oneofs: true
  });
const routeguide = grpc.loadPackageDefinition(packageDefinition).routeguide;

function join(call) {
  call.on('data', function(receivedMessage) {
    console.log("SERVER RECEIVE:");
    console.log(receivedMessage);
    console.log(receivedMessage.parameters_res.parameters.tensors);
    for (const element of receivedMessage.parameters_res.parameters.tensors) {
      console.log(element);
    }
    call.write(receivedMessage);
  });
  call.on('end', function() {
    call.end();
  });
}

function getServer() {
  const server = new grpc.Server();
  server.addService(routeguide.RouteGuide.service, {
    join: join
  });
  return server;
}

if (require.main === module) {
  const routeServer = getServer();
  routeServer.bindAsync('0.0.0.0:50051', grpc.ServerCredentials.createInsecure(), () => {
    routeServer.start();
  });
}

exports.getServer = getServer;
MyStrategy.py:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, cast
import numpy as np
import flwr as fl
from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
Parameters,
Scalar,
Weights,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.server.strategy import Strategy
from tensorflow import Tensor
DEPRECATION_WARNING = """
DEPRECATION WARNING: deprecated `eval_fn` return format
loss, accuracy
move to
loss, {"accuracy": accuracy}
instead. Note that compatibility with the deprecated return format will be
removed in a future release.
"""
DEPRECATION_WARNING_INITIAL_PARAMETERS = """
DEPRECATION WARNING: deprecated initial parameter type
flwr.common.Weights (i.e., List[np.ndarray])
will be removed in a future update, move to
flwr.common.Parameters
instead. Use
parameters = flwr.common.weights_to_parameters(weights)
to easily transform `Weights` to `Parameters`.
"""
class MyStrategy(Strategy):
    """Configurable FedAvg strategy implementation."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_fit_clients: int = 2,
        min_eval_clients: int = 2,
        min_available_clients: int = 2,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
    ) -> None:
        """Federated Averaging strategy.

        Implementation based on https://arxiv.org/abs/1602.05629

        Args:
            fraction_fit (float, optional): Fraction of clients used during
                training. Defaults to 0.1.
            fraction_eval (float, optional): Fraction of clients used during
                validation. Defaults to 0.1.
            min_fit_clients (int, optional): Minimum number of clients used
                during training. Defaults to 2.
            min_eval_clients (int, optional): Minimum number of clients used
                during validation. Defaults to 2.
            min_available_clients (int, optional): Minimum number of total
                clients in the system. Defaults to 2.
            eval_fn (Callable[[Weights], Optional[Tuple[float, float]]], optional):
                Function used for validation. Defaults to None.
            on_fit_config_fn (Callable[[int], Dict[str, Scalar]], optional):
                Function used to configure training. Defaults to None.
            on_evaluate_config_fn (Callable[[int], Dict[str, Scalar]], optional):
                Function used to configure validation. Defaults to None.
            accept_failures (bool, optional): Whether or not to accept rounds
                containing failures. Defaults to True.
            initial_parameters (Parameters, optional): Initial global model parameters.
        """
        super().__init__()
        self.min_fit_clients = min_fit_clients
        self.min_eval_clients = min_eval_clients
        self.fraction_fit = fraction_fit
        self.fraction_eval = fraction_eval
        self.min_available_clients = min_available_clients
        self.eval_fn = eval_fn
        self.on_fit_config_fn = on_fit_config_fn
        self.on_evaluate_config_fn = on_evaluate_config_fn
        self.accept_failures = accept_failures
        self.initial_parameters = initial_parameters

    def __repr__(self) -> str:
        rep = f"FedAvg(accept_failures={self.accept_failures})"
        return rep
    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return the sample size and the required number of available clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_eval)
        return max(num_clients, self.min_eval_clients), self.min_available_clients

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        initial_parameters = self.initial_parameters
        self.initial_parameters = None  # Don't keep initial parameters in memory
        if isinstance(initial_parameters, list):
            log(WARNING, DEPRECATION_WARNING_INITIAL_PARAMETERS)
            initial_parameters = self.weights_to_parameters(weights=initial_parameters)
        return initial_parameters

    def evaluate(
        self, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function."""
        if self.eval_fn is None:
            # No evaluation function provided
            return None
        weights = self.parameters_to_weights(parameters)
        eval_res = self.eval_fn(weights)
        if eval_res is None:
            return None
        loss, other = eval_res
        if isinstance(other, float):
            print(DEPRECATION_WARNING)
            metrics = {"accuracy": other}
        else:
            metrics = other
        return loss, metrics
    def configure_fit(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        config = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(rnd)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        # Do not configure federated evaluation if fraction_eval is 0
        if self.fraction_eval == 0.0:
            return []

        # Parameters and config
        config = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(rnd)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        if rnd >= 0:
            sample_size, min_num_clients = self.num_evaluation_clients(
                client_manager.num_available()
            )
            clients = client_manager.sample(
                num_clients=sample_size, min_num_clients=min_num_clients
            )
        else:
            clients = list(client_manager.all().values())

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]
    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        # Convert results
        print("\n\n aggregate_fit")
        print(results)
        weights_results = [
            (self.parameters_to_weights(fit_res.parameters), fit_res.num_examples)
            for client, fit_res in results
        ]
        print("weights_results")
        print(weights_results)
        return self.weights_to_parameters(aggregate(weights_results)), {}

    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        return loss_aggregated, {}
    def weights_to_parameters(self, weights: Weights) -> Parameters:
        """Convert NumPy weights to parameters object."""
        print('weights_to_parameters')
        print(weights)
        tensors = [self.ndarray_to_bytes(ndarray) for ndarray in weights]
        return Parameters(tensors=tensors, tensor_type="numpy.nda")

    def parameters_to_weights(self, parameters: Parameters) -> Weights:
        """Convert parameters object to NumPy weights."""
        print('parameters_to_weights')
        print(parameters)
        return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]

    # pylint: disable=R0201
    def ndarray_to_bytes(self, ndarray: np.ndarray) -> bytes:
        """Serialize NumPy array to bytes."""
        print('ndarray_to_bytes')
        print(ndarray)
        return None

    # pylint: disable=R0201
    def bytes_to_ndarray(self, tensor: bytes) -> np.ndarray:
        """Deserialize NumPy array from bytes."""
        print('bytes_to_ndarray')
        print(tensor)
        return None
# Start Flower server for two rounds of federated learning
fl.server.start_server(
    server_address='localhost:5006',
    config={"num_rounds": 2},
    strategy=MyStrategy()
)
Is Float64Array the right type to use on the Node side?
What should I use on the Python side to deserialize the data?
Note that I cannot modify the proto file.
Thank you in advance for your explanations.
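For illustration, here is a minimal sketch of one way the Python side could decode such values, assuming each bytes entry really is the raw little-endian float64 buffer of a Float64Array as produced by the client above (np.frombuffer and the dtype are assumptions, not something the proto dictates):

import numpy as np

# A bytes value shaped like what the Node client sends: the raw
# little-endian float64 buffer behind new Float64Array([...]).buffer.
tensor = np.array([45.1, 84.5, 87.9, 87.1], dtype="<f8").tobytes()

# Decode one "repeated bytes" entry back into a NumPy array.
arr = np.frombuffer(tensor, dtype="<f8")
print(arr)  # [45.1 84.5 87.9 87.1]

If the Node side used Float32Array instead, the matching dtype would be "<f4".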

Related

request params encryption - http API source airbyte connector

I am developing an HTTP API source using the Airbyte CDK.
In the request_params method I am returning the following:
return {'startIndex': 0, 'resultsPerPage': self.max_results_per_page,
        'pubStartDate': "2022-07-30T13:57:21:000 UTC%2B03:00",
        'pubEndDate': "2022-07-31T13:57:21:000 UTC%2B03:00"}
But for some reason those fields are being URL-encoded again, and the params that are actually passed to the API are:
pubStartDate: "2022-07-30T13%3A57%3A21%3A000+UTC%252B03%3A00"
pubEndDate: "2022-07-31T13%3A57%3A21%3A000+UTC%252B03%3A00"
Is there anything I am doing wrong?
How can I send those date-time strings correctly, without them being encoded a second time?
The API I am using: https://nvd.nist.gov/developers/vulnerabilities
Thanks in advance, any help will be much appreciated!!
The full connector code:
class NvdVulnerabilitiesStream(HttpStream, ABC):
    url_base = "https://services.nvd.nist.gov/"
    max_results_per_page = 2000

    def __init__(self, days_to_fetch: int):
        auth = NoAuth()
        super().__init__(authenticator=auth)
        self.days_to_fetch = days_to_fetch

    def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
        if response.json()['resultsPerPage'] < self.max_results_per_page:
            return None
        return {"startIndex": response.json()['startIndex'] + self.max_results_per_page,
                "resultsPerPage": self.max_results_per_page}

    def request_params(
        self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
    ) -> MutableMapping[str, Any]:
        if not next_page_token and self.days_to_fetch != -1:
            end_time = datetime.now()
            start_time = end_time - timedelta(days=self.days_to_fetch)
            if not next_page_token:
                return {'startIndex': 0, 'resultsPerPage': self.max_results_per_page,
                        'pubStartDate': start_time.strftime(DATETIME_SCHEME),
                        'pubEndDate': end_time.strftime(DATETIME_SCHEME)}
        if not next_page_token:
            return {"startIndex": 0, "resultsPerPage": self.max_results_per_page}
        return {"startIndex": next_page_token['startIndex'], "resultsPerPage": next_page_token['resultsPerPage']}

    def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
        for vulnerability in response.json()["result"]["CVE_Items"]:
            yield vulnerability
        print(response.json()["startIndex"])
        time.sleep(6)


class Vulnerabilities(NvdVulnerabilitiesStream):
    primary_key = None

    def __init__(self, days_to_fetch: int):
        super().__init__(days_to_fetch)

    def path(
        self, stream_state: Mapping[str, Any] = None, stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
    ) -> str:
        return "rest/json/cves/1.0/"

Access the contents of the row inserted into Dynamodb using Pynamodb save method

I have the below model for my DynamoDB table using PynamoDB:
from pynamodb.models import Model
from pynamodb.attributes import (
    UnicodeAttribute, UTCDateTimeAttribute, UnicodeSetAttribute, BooleanAttribute
)


class Reminders(Model):
    """Model class for the Reminders table."""

    # Information on global secondary index for the table
    # user_id (hash key) + reminder_id+reminder_title (sort key)
    class Meta:
        table_name = 'Reminders'
        region = 'eu-central-1'

    reminder_id = UnicodeAttribute(hash_key=True)
    user_id = UnicodeAttribute(range_key=True)
    reminder_title = UnicodeAttribute()
    reminder_tags = UnicodeSetAttribute()
    reminder_description = UnicodeAttribute()
    reminder_frequency = UnicodeAttribute(default='Only once')
    reminder_tasks = UnicodeSetAttribute(default=set())
    reminder_expiration_date_time = UTCDateTimeAttribute(null=True)
    reminder_title_reminder_id = UnicodeAttribute()
    next_reminder_date_time = UTCDateTimeAttribute()
    should_expire = BooleanAttribute()
When I want to create a new reminder I do it through the code below:
class DynamoBackend:
    @staticmethod
    def create_a_new_reminder(new_reminder: NewReminder) -> Dict[str, Any]:
        """Create a new reminder using pynamodb."""
        new_reminder = models.Reminders(**new_reminder.dict())
        return new_reminder.save()
In this case NewReminder is an instance of a pydantic base model, like so:
class NewReminder(pydantic.BaseModel):
    reminder_id: str
    user_id: str
    reminder_title: str
    reminder_description: str
    reminder_tags: Sequence[str]
    reminder_frequency: str
    should_expire: bool
    reminder_expiration_date_time: Optional[datetime.datetime]
    next_reminder_date_time: datetime.datetime
    reminder_title_reminder_id: str
When I call the save method on the model object I receive the below response:
{
    "ConsumedCapacity": {
        "CapacityUnits": 2.0,
        "TableName": "Reminders"
    }
}
Now, my question: the save method is called directly by a Lambda function, which is in turn invoked by an API Gateway POST endpoint, so ideally the response should be a 201 Created. Instead of returning the consumed capacity and table name, it would be great if it returned the item inserted into the database. Below is my route code:
def create_a_new_reminder():
    """Creates a new reminder in the database."""
    request_context = app.current_request.context
    request_body = json.loads(app.current_request.raw_body.decode())
    request_body["reminder_frequency"] = data_structures.ReminderFrequency[request_body["reminder_frequency"]]
    reminder_details = data_structures.ReminderDetailsFromRequest.parse_obj(request_body)
    user_details = data_structures.UserDetails(
        user_name=request_context["authorizer"]["claims"]["cognito:username"],
        user_email=request_context["authorizer"]["claims"]["email"]
    )
    reminder_id = str(uuid.uuid1())
    new_reminder = data_structures.NewReminder(
        reminder_id=reminder_id,
        user_id=user_details.user_name,
        reminder_title=reminder_details.reminder_title,
        reminder_description=reminder_details.reminder_description,
        reminder_tags=reminder_details.reminder_tags,
        reminder_frequency=reminder_details.reminder_frequency.value[0],
        should_expire=reminder_details.should_expire,
        reminder_expiration_date_time=reminder_details.reminder_expiration_date_time,
        next_reminder_date_time=reminder_details.next_reminder_date_time,
        reminder_title_reminder_id=f"{reminder_details.reminder_title}-{reminder_id}"
    )
    return DynamoBackend.create_a_new_reminder(new_reminder=new_reminder)
I am very new to REST API creation and best practices, so it would be great if someone could guide me here. Thanks in advance!
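One common pattern here (a sketch only, assuming the route above is a Chalice handler; chalice.Response and the elided setup are assumptions, not part of the original code) is to return the item you just built together with a 201 status, instead of PynamoDB's ConsumedCapacity payload:

from chalice import Response

def create_a_new_reminder():
    # ... build new_reminder exactly as in the route above ...
    DynamoBackend.create_a_new_reminder(new_reminder=new_reminder)
    # Return the item that was written rather than the ConsumedCapacity dict.
    return Response(
        body=new_reminder.json(),  # pydantic serializes the datetime fields
        status_code=201,
        headers={"Content-Type": "application/json"},
    )

Whether you echo the request payload back or re-read the item from DynamoDB is a design choice; the key point is that the HTTP layer, not PynamoDB's save() return value, decides the response.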

SQLAlchemy 1.4.0b1 AsyncSession issue

I'm using SQLAlchemy 1.4.0b1's AsyncSession to update a Postgres database with asyncpg 0.21.0. The code below aims to update objects and add new objects in response to various incoming Redis stream messages.
The save_revised coroutine (update) is working fine, and so is the session.add part of the td_move coroutine. However, the update part of td_move at the bottom of the function (starting from if this_train_id and msg.get('from') in finals[crossing]) only works intermittently: I'm getting some db updates, but only about a third of the log messages indicating that an update is wanted.
Can anyone suggest what the problem(s) could be, please?
async def main():
    logger.info(f"db_updater starting {datetime.now().strftime('%H:%M:%S')}")
    engine = create_async_engine(os.getenv('ASYNC_DB_URL'), future=True)
    async with AsyncSession(engine) as session:
        crossings, headcodes, lean_params, finals, active_trains, train_ids, berthtimes, hc_types = await get_db_data(logger)  # noqa: E501
        pool = await aioredis.create_redis_pool(('redis', 6379), db=0, password=os.getenv('REDIS_PW'), encoding='utf-8')
        last_id = '$'
        while True:
            all_msgs = await pool.xread(['del_hc_s', 'xing_revised', 'all_td', 'add_hc_s'], latest_ids=[last_id, last_id, last_id, last_id])  # noqa: E501
            for stream_name, msg_id, msg in all_msgs:
                message = dict(msg)
                crossing = message.get('crossing')
                if stream_name == 'all_td':
                    await td_move(message, train_ids, active_trains, finals, lean_params, session)
                elif stream_name == 'xing_revised':
                    await save_revised(message, lean_params[crossing], session)


async def save_revised(msg, params, session):
    train_id = msg.get('train_id')
    # today_class is a SQLA model class from declarative_base()
    today_class = params['today_class']
    rev_time = datetime.fromtimestamp(
        int(msg.get('revised')))
    stmt = update(today_class).where(today_class.train_id == train_id).\
        values(xing_revised=rev_time).\
        execution_options(synchronize_session="fetch")
    await session.execute(stmt)
    if msg.get('revised_ten') != 'X':
        stmt2 = update(today_class).where(today_class.train_id == train_id).\
            values(xing_revised_ten=rev_time).\
            execution_options(synchronize_session="fetch")
        await session.execute(stmt2)
    await session.commit()


async def td_move(msg, train_ids, active_trains, finals, params, session):
    crossing = msg.get('crossing')
    descr = msg.get('descr')
    if crossing:
        this_train_id = [s for s in train_ids[crossing] if descr in s]
        if this_train_id:
            this_train_id = this_train_id[0]
        else:
            return
        if this_train_id and active_trains[crossing].get(this_train_id) and (
                is_within_minutes(30, active_trains[crossing].get(this_train_id))):
            # Td_Ca_Cc is a SQLA model class from declarative_base()
            td = Td_Ca_Cc(
                msg_type=msg.get('msg_type'),
                descr=msg.get('descr'),
                traintype=active_trains[crossing].get(
                    this_train_id).get('train_type'),
                from_berth=msg.get('from'),
                to_berth=msg.get('to'),
                tdtime=dt_from_timestamp(msg.get('time')),
                seconds=0,
                area_id=msg.get('area_id'),
                updated=datetime.now(),
                crossing=crossing
            )
            session.add(td)
        if this_train_id and msg.get('from') in finals[crossing]:
            today_class = params[crossing]['today_class']
            stmt = update(today_class).where(today_class.train_id == this_train_id).\
                values(xing_actual=datetime.now(), cancel_time='XXX').\
                execution_options(synchronize_session="fetch")
            await session.execute(stmt)
            logger.info(f"{crossing} {msg.get('descr')} passed {datetime.now().strftime('%H:%M:%S')}")
        await session.commit()


if __name__ == '__main__':
    asyncio.run(main())
For future reference, the problem was a synchronous logging call. I removed it (and will add an async logging call instead); the modified code is now working fine.
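A sketch of one way to keep a blocking logger call from stalling the event loop, using asyncio.to_thread (Python 3.9+); the logger name and message here are illustrative only, not from the code above:

import asyncio
import logging

logger = logging.getLogger("db_updater")

async def log_async(message: str) -> None:
    # Run the blocking logging call in a worker thread so it cannot
    # block the loop that is driving the AsyncSession work.
    await asyncio.to_thread(logger.info, message)

async def demo() -> None:
    await log_async("example message")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(demo())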

How to create a custom transformer without any input column?

We have a requirement where we want to generate scores for our model using some random values between 0 and 1.
To do that we want a custom transformer that generates random numbers without any input fields.
So, can we create a transformer without input fields in MLeap?
Usually we create one as below:
import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types._

case class RandomNumberModel() extends Model {
  private val rnd = scala.util.Random

  def apply(): Double = rnd.nextFloat

  override def inputSchema: StructType = StructType("input" -> ScalarType.String).get
  override def outputSchema: StructType = StructType("output" -> ScalarType.Double).get
}
How can I make it so that no input schema is necessary?
I have never tried that, but here is how I managed to build a custom transformer with multiple input fields...
package org.apache.spark.ml.feature.mleap

import ml.combust.mleap.core.Model
import ml.combust.mleap.core.types._
import org.apache.spark.ml.linalg._

case class PropertyGroupAggregatorBaseModel(props: Array[String],
                                            aggFunc: String) extends Model {
  val outputSize = props.size

  // having multiple inputs, apply takes a Seq[Any] parameter
  def apply(features: Seq[Any]): Vector = {
    val properties = features(0).asInstanceOf[Seq[String]]
    val values = features(1).asInstanceOf[Seq[Double]]
    val mapping = properties.zip(values)
    val histogram = props.foldLeft(Array.empty[Double]) {
      (acc, property) =>
        val newValues = mapping.filter(x => x._1 == property).map(x => x._2)
        val newAggregate = aggFunc match {
          case "sum" => newValues.sum.toDouble
          case "count" => newValues.size.toDouble
          case "avg" => (newValues.sum / Math.max(newValues.size, 1)).toDouble
        }
        acc :+ newAggregate
    }
    Vectors.dense(histogram)
  }

  override def inputSchema: StructType = {
    // here you define the input
    val inputFields = Seq(
      StructField("input1" -> ListType(BasicType.String)),
      StructField("input2" -> ListType(BasicType.Double))
    )
    StructType(inputFields).get
  }

  override def outputSchema: StructType = StructType(StructField("output" -> TensorType.Double(outputSize))).get
}
My suggestion would be that the apply method might already work for you. I guess if you define inputSchema as follows, it might work:
override def inputSchema: StructType = {
  // here you define the input
  val inputFields = Seq.empty[StructField]
  StructType(inputFields).get
}

How can I create new map with new values but same keys from an existing map?

I have an existing map in Groovy.
I want to create a new map that has the same keys but different values in it.
Eg.:
def scores = ["vanilla":10, "chocolate":9, "papaya": 0]
//transformed into
def preference = ["vanilla":"love", "chocolate":"love", "papaya": "hate"]
Any way of doing it through some sort of closure like:
def preference = scores.collect {//something}
You can use collectEntries
scores.collectEntries { k, v ->
    [ k, 'new value' ]
}
An alternative to using a map for the ranges would be to use a switch
def grade = { score ->
    switch( score ) {
        case 10..9: return 'love'
        case 8..6:  return 'like'
        case 5..2:  return 'meh'
        case 1..0:  return 'hate'
        default:    return 'ERR'
    }
}
scores.collectEntries { k, v -> [ k, grade( v ) ] }
Nice, functional-style solution (including your ranges, and easy to modify):
def scores = [vanilla:10, chocolate:9, papaya: 0]
// Store somewhere
def map = [(10..9):"love", (8..6):"like", (5..2):"meh", (1..0):"hate"]
def preference = scores.collectEntries { key, score -> [key, map.find { score in it.key }.value] }
// Output: [vanilla:love, chocolate:love, papaya:hate]
def scores = ["vanilla":10, "chocolate":9, "papaya": 0]
def preference = scores.collectEntries {key, value -> ["$key":(value > 5 ? "like" : "hate")]}
Then the result would be
[vanilla:like, chocolate:like, papaya:hate]
EDIT: If you want a map, then you should use collectEntries like tim_yates said.
