fastapi's class as dependency override in unit testing - fastapi

I want to instantiate an object based on a parameter supplied in the request body. For instance, if the request body contains type=push, I want to instantiate it like this:
# Desired behavior: choose the channel implementation from the body's "type".
if channel_type.upper() == "PUSH":
    return NotificationService(Push())
I can achieve this by using a class as a FastAPI dependency, as follows:
def factory(channel_type):
    """Build a NotificationService for the requested channel.

    "push" and "email" (case-insensitive) select Push and Email; any other
    value falls back to Sms.
    """
    kind = channel_type.upper()
    if kind == "PUSH":
        channel = Push()
    elif kind == "EMAIL":
        channel = Email()
    else:
        channel = Sms()
    return NotificationService(channel)
class NotificationRequest:
    """Class dependency that builds a NotificationService from the request body.

    FastAPI reads ``type`` from the JSON body (default ``"push"``); the
    resulting service is exposed as ``self.service``.

    NOTE(review): ``factory`` is called directly here rather than through
    ``Depends``, so ``app.dependency_overrides[factory]`` cannot replace it;
    a test has to override ``NotificationRequest`` itself.
    """

    def __init__(self, type: str = Body("push")):
        built = factory(type)
        self.service = built
# Request-body model for the notification endpoint (title, text, extra data).
class MessageRequest(BaseModel):
    title: str  # required notification title
    message: str  # required notification text
    # Optional extra data forwarded to Message(payload=...).
    # NOTE(review): mutable default — pydantic is expected to copy it per
    # instance; confirm for the pydantic version in use.
    payload: Dict = {}
# Restored route decorator: the scrape had turned "@" into "#", which left the
# endpoint unregistered on the router.
@router.post("/", response_model=dict)
def notify(message: MessageRequest, request: NotificationRequest = Depends()) -> Any:
    # Send the message to every user through the channel selected by the
    # NotificationRequest dependency (built from the body's "type" field).
    request.service.send(
        users=get_users(),
        message=Message(
            title=message.title,
            message=message.message,
            payload=message.payload,
        ),
    )
    # Reached only after send() returned without raising.
    return {"success": True}
However, I want to override the NotificationRequest dependency in unit tests so that fake instances are used instead of the real classes:
def get_factory(channel_type):
    """Test double for ``factory``: same channel selection, fake channels.

    "push" and "email" (case-insensitive) select FakePush and FakeEmail;
    any other value falls back to FakeSms.
    """
    kind = channel_type.upper()
    if kind == "PUSH":
        channel = FakePush()
    elif kind == "EMAIL":
        channel = FakeEmail()
    else:
        channel = FakeSms()
    return NotificationService(channel)
I have tried the following:
# NOTE: this override has no effect. NotificationRequest.__init__ calls
# factory(type) directly instead of receiving it via Depends(), so FastAPI
# never resolves `factory` as a dependency and never consults this entry.
app.dependency_overrides[factory] = lambda: get_factory
but it does not work, because factory is called directly inside NotificationRequest rather than injected through Depends().
How can I read a parameter from the request body and inject a dependency so that a fake instance is produced during unit testing?

Related

Handling TimeOut Exception in AsyncIO

import asyncio
import aiohttp
from time import perf_counter
import csv
path = "*******************"
domains = []
total_count = 0
# newline='' is what the csv module expects, so quoted fields containing
# line breaks are parsed correctly.
with open(path, 'r', newline='') as file:
    csvreader = csv.reader(file)
    for row in csvreader:
        try:
            cell = row[4]
        except IndexError:
            # Short row without a fifth column: skip it (the old bare
            # `except:` also hid unrelated errors such as KeyboardInterrupt).
            continue
        # Column 5 is assumed to hold a URL; reduce it to a bare host name by
        # dropping the scheme, a leading "www." and any path.
        website = cell.split("//")[-1].split("www.")[-1].split('/')[0]
        if not website:
            continue
        domains.append(website)
# Only the first 50 domains are queried.
sample = domains[0:50]
async def fetch(s, body):
    """POST ``body`` to the enrichment API and print every Owner/CEO employee.

    Each match increments the module-level ``total_count``. A non-200
    response, a timeout or an aiohttp client error is reported and the request
    is skipped, so one failing domain does not abort the ``asyncio.gather`` in
    ``fetch_all``.
    """
    global total_count
    # Separator printed after each matching employee, keyed by job title.
    separators = {
        "Owner": "************************************************",
        "CEO": "***************************************************",
    }
    try:
        async with s.post('https://****************', json=body) as r:
            if r.status != 200:
                # The old `pass` was a no-op: error bodies were parsed anyway
                # and typically crashed on the missing 'employees' key.
                print("Skipping", body.get("domain"), "- HTTP status", r.status)
                return
            enrich_response = await r.json()
    except (asyncio.TimeoutError, aiohttp.ClientError) as exc:
        # Skip this domain instead of failing the whole gather().
        print("Skipping", body.get("domain"), "-", repr(exc))
        return
    for employee in enrich_response.get('employees', []):
        separator = separators.get(employee['job_title'])
        if separator is None:
            continue
        print(employee)
        print(separator)
        total_count += 1
        print("Total Count:", total_count)
async def fetch_all(s, bodies):
    """Run ``fetch`` for every body concurrently; return results in input order."""
    pending = [asyncio.create_task(fetch(s, payload)) for payload in bodies]
    return await asyncio.gather(*pending)
async def main():
    """Build one request body per sampled domain and fetch them in one session."""
    bodies = [
        {
            "api_key": "********************************",
            "domain": "{}".format(domain),
        }
        for domain in sample
    ]
    async with aiohttp.ClientSession() as session:
        data = await fetch_all(session, bodies)
        print(data[0])
if __name__ == '__main__':
    t0 = perf_counter()
    try:
        asyncio.run(main())
    except Exception as err:  # top-level boundary: report, then still show timing
        print(err)
    print("Time taken:", perf_counter() - t0)
Hi!
I'm trying to connect to a scraping service provider using asyncio instead of simple synchronous API calls.
But I get a timeout error. How can I use exception handling to wait a few seconds and then retry the request? Or simply skip that task if it fails?
Thank you in advance fellow coder!
I tried adding continue/pass in some places.
Try the asyncio.wait_for() function. It takes an awaitable and a timeout value. If the task isn't completed before the timeout elapses, it cancels the task and raises asyncio.TimeoutError, which you can handle however you want in an except clause.
A typical example (from the Python docs) is as follows:
async def eternity():
    """Sleep for one hour, then print 'yay!' — a task that outlives any short timeout."""
    one_hour = 60 * 60
    await asyncio.sleep(one_hour)
    print('yay!')
async def main():
    """Give eternity() at most one second, then report the timeout."""
    try:
        # wait_for cancels eternity() and raises once the timeout elapses.
        await asyncio.wait_for(eternity(), timeout=1.0)
    except asyncio.TimeoutError:
        # The builtin TimeoutError only aliases asyncio.TimeoutError on
        # Python 3.11+; catching asyncio.TimeoutError also works on 3.7-3.10.
        print('timeout!')


asyncio.run(main())
# Expected output:
#
# timeout!

This Starlette Code works but only with GET, and I need POST

Starting from a sample Starlette example, I built a piece of code that gets the Basic Auth username and password, reads a header, and grabs the JSON body. But it only works if I use "GET" instead of "POST", and I have not been able to figure out how to change the accepted method to POST. (The application I am trying to host for only uses POST.) Is it simple to get the POST method to work, or does this need a rewrite?
from starlette.applications import Starlette
from starlette.authentication import requires
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser
)
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import (PlainTextResponse, JSONResponse)
from starlette.routing import Route
import base64
import binascii
class BasicAuthBackend(AuthenticationBackend):
    """Authenticate requests that carry an HTTP Basic ``Authorization`` header.

    Requests without the header, or with a non-Basic scheme, are left
    unauthenticated (``None`` is returned). A malformed Basic header raises
    ``AuthenticationError``.
    """

    async def authenticate(self, conn):
        if "Authorization" not in conn.headers:
            return
        auth = conn.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            if scheme.lower() != 'basic':
                return
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            # Chain the cause so the decoding failure shows up in tracebacks.
            raise AuthenticationError('Invalid basic auth credentials') from exc
        username, _, password = decoded.partition(":")
        # NOTE(review): module-level globals are shared by every request the
        # process handles. homepage awaits request.json() before reading them,
        # so a concurrent request can overwrite these values in between. They
        # are kept only because homepage reads my_user / my_pass.
        global my_user
        global my_pass
        my_user = username
        my_pass = password
        # TODO: You'd want to verify the username and password here.
        return AuthCredentials(["authenticated"]), SimpleUser(username)
async def homepage(request):
    """Echo the Basic-auth credentials, the ``client_id`` header and the JSON body.

    Unauthenticated requests get a plain-text greeting instead.
    """
    if request.user.is_authenticated:
        body = await request.json()
        # JSONResponse's second positional argument is status_code, so the
        # parsed body must go inside the content rather than after it.
        # NOTE(review): echoing the password back is only acceptable for debugging.
        return JSONResponse({
            "user": my_user,
            "password": my_pass,
            "header": request.headers['client_id'],
            "body": body,
        })
    return PlainTextResponse('Hello, you')
# The client only sends POST, so the route has to accept it explicitly; without
# `methods`, a Starlette function endpoint only answers GET (and HEAD).
routes = [
    Route("/testpath", endpoint=homepage, methods=["POST"])
]
middleware = [
    Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
]
# NOTE(review): debug=True exposes tracebacks to clients; disable in production.
app = Starlette(debug=True, routes=routes, middleware=middleware)
You need to declare that your route accepts the POST method:
async def homepage(request):
    """Echo the Basic-auth credentials, the ``client_id`` header and the JSON body.

    Unauthenticated requests get a plain-text greeting instead.
    """
    if request.user.is_authenticated:
        body = await request.json()
        # The parsed body was previously awaited and then discarded; return it
        # inside the content (JSONResponse's 2nd positional arg is status_code).
        return JSONResponse({
            "user": my_user,
            "password": my_pass,
            "header": request.headers['client_id'],
            "body": body,
        })
    return PlainTextResponse('Hello, you')
# Accept POST only: the endpoint reads a JSON request body.
routes = [Route("/testpath", homepage, methods=["POST"])]

StoreModule.forRoot() - how to return object without additional key

I am wondering how I can return an object of the same type as the reducer function:
/**
 * Root reducer: applies LOAD_USER_THREADS_ACTION via handleLoadUserThreadsAction
 * and returns the current state unchanged for every other action.
 */
function storeReducer(
  state = INITIAL_APPLICATION_STATE,
  action: Actions
): ApplicationState {
  if (action.type === LOAD_USER_THREADS_ACTION) {
    return handleLoadUserThreadsAction(state, action);
  }
  return state;
}
I expect an object of type ApplicationState, but with this approach:
// Registers storeReducer under the "storeReducer" key of the root state.
StoreModule.forRoot({storeReducer})
I get the object nested under a key:
storeReducer:{ // object of type Application State}
I expect to get the object itself (without the additional storeReducer key):
{//object of type Application State}
I also tried StoreModule.forRoot(storeReducer), but then I get empty objects and it does not work.
The forRoot method on StoreModule expects an ActionReducerMap, not the result of your reducer.
I typically set mine up in a separate file like this:
// Root state shape: one property per state slice.
export interface IAppState {
  aPieceOfState: IAPieceOfState;
}

// Reducer map handed to StoreModule.forRoot: each key names a slice of
// IAppState and maps it to the reducer that produces that slice.
export const reducers: ActionReducerMap<IAppState> = {
  aPieceOfState: aPieceOfStateReducer
};
Then import it into app.module.ts and use it like this:
// Pass the reducer map itself; each key becomes a top-level slice of the state.
StoreModule.forRoot(reducers)
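Or you can use a type assertion: StoreModule.forRoot({storeReducer} as ActionReducerMap<IAppState>) (note that an assertion only satisfies the type checker; at runtime the state is still keyed by storeReducer).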

How to mock rest controller with ResponseEntity as return object?

I'm writing a unit test for a REST controller, and this is what it returns:
// Controller response: HTTP 200 with `result` as the response body.
return ResponseEntity.status(HttpStatus.OK).body(result);
I'm getting this error:
Required request body is missing
This is my current test:
def "Signup"() {
    given:
    UserDto userDto = new UserDto(id: 1, password: "password123", username: "username123")
    def personDto = new PersonDto(id: 1, user : userDto)

    when: "signup url is hit"
    // The endpoint requires a request body (hence "Required request body is
    // missing"), so the POST must carry JSON content and its content type.
    def response = mockMvc.perform(
            post('/person/signup')
                    .contentType(org.springframework.http.MediaType.APPLICATION_JSON)
                    .content(groovy.json.JsonOutput.toJson(userDto)))

    then:
    // The controller receives a freshly deserialized UserDto, not this
    // instance, so match on type rather than on equality with userDto.
    personService.signup(_ as UserDto) >> personDto
    response.andExpect(status().isOk())
}
Any idea how to mock .body, or how to add a body to the request? Thanks :)
Add another expectation like:
// Checks the *response* body text; it does not add a body to the request.
response.andExpect(content().string(containsString('blah')))
Reference:
MockMvcResultMatchers.content()
ContentResultMatchers.string(org.hamcrest.Matcher<? super String> matcher)
import static groovyx.net.http.ContentType.JSON
import groovyx.net.http.RESTClient
import groovy.util.slurpersupport.GPathResult
import static groovyx.net.http.ContentType.URLENC
// List the account's Basecamp projects over the JSON API with HTTP Basic auth.
def accountId = "yourAccountId" // this is the number after http://basecamp.com when logged into the basecamp website e.g. http://basecamp.com/1234567
def userName = "basecampUserName"
def password = "basecampPassword"
def basecamp = new RESTClient( "https://basecamp.com/${accountId}/api/v1/".toString() )
basecamp.auth.basic userName, password
def response = basecamp.get(
    path: "projects.json",
    // Identify the application and a contact e-mail address (the "@" had been
    // mangled into "#").
    headers: ["User-Agent": "My basecamp application (myemail@domain.com)"]
)
println response.data.toString(2) // or you can return this value and do whatever you want
// post with body
// Feature: POST with query parameters and a JSON body, expecting HTTP 200.
def 'test post method'(){
    given:
    restClient.headers.Accept = 'application/json'

    when:
    def resp = restClient.post(
        path: 'path(ex:/api/list/',
        query: [param1: 'param1value', param2: 'param2value'],
        body: 'your json',
        contentType: 'application/json'
    )

    then:
    resp.status == 200
}
}

Query graphite index.json for a specific sub-tree

I'm querying Graphite's index.json to get all the metrics. Is there an option to pass a root metric and get only a sub-tree? Something like:
http://<my.graphite>/metrics/index.json?query="my.metric.subtree"
That is not supported.
What you can do, however, is call /metrics/find recursively (calling it again for each branch encountered).
Something like this:
#!/usr/bin/python
from __future__ import print_function
import requests
import json
import argparse
# Python 2 ships the module as "Queue"; Python 3 renamed it to "queue".
# Catch only ImportError so unrelated failures are not masked.
try:
    from Queue import Queue
except ImportError:
    from queue import Queue
from threading import Thread, Lock
import sys
import unicodedata
outLock = Lock()


def output(msg):
    """Print ``msg`` on its own line and flush stdout while holding ``outLock``.

    Worker threads share stdout; holding the lock keeps their lines from
    interleaving.
    """
    with outLock:
        stream = sys.stdout
        print(msg, file=stream)
        stream.flush()
class Walker(Thread):
    """Worker thread that walks the Graphite metric tree via /metrics/find.

    Each queue item is a ``(prefix, depth)`` tuple. Leaf metrics are printed;
    branches are put back on the shared queue one level deeper. When
    ``depth`` is set, branches reaching that depth are printed instead of
    being expanded.
    """

    def __init__(self, queue, url, user=None, password=None, seriesFrom=None, depth=None):
        Thread.__init__(self)
        self.queue = queue
        self.url = url
        self.user = user
        self.password = password
        self.seriesFrom = seriesFrom
        self.depth = depth

    def run(self):
        """Process queue items forever (the driver starts workers as daemons)."""
        while True:
            branch = self.queue.get()
            try:
                self._process(branch)
            except Exception as e:
                # Report and keep going: previously the exception killed the
                # worker, leaving its remaining queue items unprocessed.
                with outLock:
                    sys.stderr.write('error walking branch %r: %s\n' % (branch[0], e))
            finally:
                # Always acknowledge the item; a skipped task_done() makes
                # queue.join() in the main thread block forever.
                self.queue.task_done()

    def _process(self, branch):
        """Print or expand one ``(prefix, depth)`` queue item."""
        prefix, depth = branch
        try:
            prefix.encode('ascii')
        except UnicodeError:
            with outLock:
                sys.stderr.write('found branch with invalid characters: ')
                # Encode to ASCII with XML character references, then decode,
                # so text (not bytes) is written to stderr on Python 3.
                sys.stderr.write(
                    unicodedata.normalize('NFKD', prefix)
                    .encode('ascii', 'xmlcharrefreplace')
                    .decode('ascii'))
                sys.stderr.write('\n')
            return
        if self.depth is not None and depth == self.depth:
            output(prefix)
        else:
            self.walk(prefix, depth)

    def walk(self, prefix, depth):
        """Query /metrics/find for the children of ``prefix`` (the root when empty).

        Leaves are printed; non-leaf children are queued at ``depth + 1``.

        Raises:
            Exception: if Graphite answers with a non-200 status.
        """
        payload = {
            "query": (prefix + ".*") if prefix else '*',
            "format": "treejson"
        }
        if self.seriesFrom:
            payload['from'] = self.seriesFrom
        auth = None
        if self.user is not None:
            auth = (self.user, self.password)
        r = requests.get(
            self.url + '/metrics/find',
            params=payload,
            auth=auth,
        )
        if r.status_code != 200:
            sys.stderr.write(r.text + '\n')
            # decode() keeps the branch readable instead of embedding b'...'.
            safe_branch = unicodedata.normalize('NFKD', prefix).encode('ascii', 'replace').decode('ascii')
            raise Exception(
                'Error walking finding series: branch={branch} reason={reason}'
                .format(branch=safe_branch, reason=r.reason)
            )
        metrics = r.json()
        for metric in metrics:
            try:
                if metric['leaf']:
                    output(metric['id'])
                else:
                    self.queue.put((metric['id'], depth + 1))
            except Exception:
                output(metric)
                # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
                raise
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", help="Graphite URL", required=True)
    parser.add_argument("--prefix", help="Metrics prefix", required=False, default='')
    parser.add_argument("--user", help="Basic Auth username", required=False)
    parser.add_argument("--password", help="Basic Auth password", required=False)
    parser.add_argument("--concurrency", help="concurrency", default=8, required=False, type=int)
    parser.add_argument("--from", dest='seriesFrom', help="only get series that have been active since this time", required=False)
    parser.add_argument("--depth", type=int, help="maximum depth to traverse. If set, the branches at the depth will be printed", required=False)
    args = parser.parse_args()

    # One shared work queue; every worker pulls (prefix, depth) items from it.
    queue = Queue()
    for _ in range(args.concurrency):
        worker = Walker(queue, args.url, args.user, args.password, args.seriesFrom, args.depth)
        worker.daemon = True  # lets the process exit once queue.join() returns
        worker.start()

    # Seed the walk at the requested prefix (depth 0) and block until every
    # queued branch has been acknowledged with task_done().
    queue.put((args.prefix, 0))
    queue.join()
Note: this code comes from: https://github.com/grafana/cloud-graphite-scripts/blob/master/query/walk_metrics.py

Resources